mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 13:00:54 +00:00
feat: add LLaMA model implementation with RoPE positional encoding
- Added LLaMA model architecture with RMSNorm and SwiGLU activation - Implemented Rotary Positional Embeddings (RoPE) for better positional encoding - Created training script for LLaMA with BPE tokenizer - Fixed matplotlib dependency version in uv.lock - Added LLaMA module initialization The implementation includes: - TokenEmbeddings, HeadAttention, MultiHeadAttention with RoPE support - RMSNorm normalization layer - SwiGLU feed-forward activation - Cached decoder implementation for efficient generation
This commit is contained in:
231
experiments/llm_only/train_llama_bpe.py
Normal file
231
experiments/llm_only/train_llama_bpe.py
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Experiment: train_gpt_bpe.py
|
||||||
|
Description: Обучение GPT модели с собственным BPE токенизатором.
|
||||||
|
Использует только библиотеку llm без зависимостей от HuggingFace.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Добавляем путь к shared модулям
|
||||||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
from llm.models.llama import Llama
|
||||||
|
from llm.tokenizers import BPETokenizer
|
||||||
|
from llm.training.dataset import TextDataset
|
||||||
|
from llm.training.trainer import Trainer
|
||||||
|
|
||||||
|
from shared.configs import (
|
||||||
|
TRAIN_TEXTS, BASE_GPT_CONFIG, BPE_CONFIG,
|
||||||
|
TRAINING_CONFIG, PATHS, TEST_PROMPTS
|
||||||
|
)
|
||||||
|
from shared.data import (
|
||||||
|
load_training_data, ensure_directories,
|
||||||
|
print_experiment_info, ExperimentLogger
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def train_bpe_tokenizer(texts: list, config: dict) -> BPETokenizer:
|
||||||
|
"""
|
||||||
|
Обучает BPE токенизатор на текстах.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: Список текстов для обучения
|
||||||
|
config: Конфигурация токенизатора
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BPETokenizer: Обученный токенизатор
|
||||||
|
"""
|
||||||
|
print("🔧 Обучение BPE токенизатора...")
|
||||||
|
|
||||||
|
tokenizer = BPETokenizer()
|
||||||
|
tokenizer.train(
|
||||||
|
texts=texts,
|
||||||
|
vocab_size=config["vocab_size"],
|
||||||
|
special_tokens=config["special_tokens"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Сохраняем токенизатор
|
||||||
|
os.makedirs(os.path.dirname(PATHS["bpe_tokenizer"]), exist_ok=True)
|
||||||
|
tokenizer.save(PATHS["bpe_tokenizer"])
|
||||||
|
|
||||||
|
print(f"✅ BPE токенизатор обучен и сохранен: {PATHS['bpe_tokenizer']}")
|
||||||
|
print(f"📊 Размер словаря: {tokenizer.get_vocab_size()}")
|
||||||
|
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def test_tokenizer(tokenizer: BPETokenizer, texts: list):
|
||||||
|
"""
|
||||||
|
Тестирует токенизатор на примерах.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokenizer: Обученный токенизатор
|
||||||
|
texts: Список тестовых текстов
|
||||||
|
"""
|
||||||
|
print("\n🧪 Тестирование токенизатора:")
|
||||||
|
|
||||||
|
for i, text in enumerate(texts[:3]):
|
||||||
|
print(f"\nПример {i+1}:")
|
||||||
|
print(f" Исходный текст: '{text}'")
|
||||||
|
|
||||||
|
# Кодирование
|
||||||
|
tokens = tokenizer.encode(text)
|
||||||
|
token_strings = tokenizer.tokenize(text)
|
||||||
|
|
||||||
|
print(f" Токены (ID): {tokens}")
|
||||||
|
print(f" Токены (текст): {token_strings}")
|
||||||
|
print(f" Количество токенов: {len(tokens)}")
|
||||||
|
|
||||||
|
# Декодирование
|
||||||
|
decoded = tokenizer.decode(tokens)
|
||||||
|
print(f" Декодированный: '{decoded}'")
|
||||||
|
|
||||||
|
if text == decoded:
|
||||||
|
print(" ✅ Кодирование/декодирование корректно")
|
||||||
|
else:
|
||||||
|
print(" ⚠️ Небольшие расхождения")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Основная функция эксперимента."""
|
||||||
|
# === Настройка эксперимента ===
|
||||||
|
experiment_name = "Обучение Llama с BPE токенизатором (только llm)"
|
||||||
|
experiment_config = {
|
||||||
|
"model": "Llama",
|
||||||
|
"tokenizer": "BPE",
|
||||||
|
"vocab_size": BPE_CONFIG["vocab_size"],
|
||||||
|
"training_epochs": TRAINING_CONFIG["num_epochs"],
|
||||||
|
"batch_size": TRAINING_CONFIG["batch_size"],
|
||||||
|
"learning_rate": TRAINING_CONFIG["learning_rate"]
|
||||||
|
}
|
||||||
|
|
||||||
|
print_experiment_info(experiment_name, experiment_config)
|
||||||
|
ensure_directories()
|
||||||
|
logger = ExperimentLogger(experiment_name)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# === Подготовка данных ===
|
||||||
|
train_texts, val_texts = load_training_data()
|
||||||
|
print(f"📊 Данные: {len(train_texts)} train, {len(val_texts)} validation")
|
||||||
|
|
||||||
|
# === Обучение токенизатора ===
|
||||||
|
if os.path.exists(PATHS["bpe_tokenizer"]):
|
||||||
|
print("📝 Загрузка предварительно обученного токенизатора...")
|
||||||
|
tokenizer = BPETokenizer.load(PATHS["bpe_tokenizer"])
|
||||||
|
print(f"✅ Токенизатор загружен (vocab_size={tokenizer.get_vocab_size()})")
|
||||||
|
else:
|
||||||
|
tokenizer = train_bpe_tokenizer(TRAIN_TEXTS, BPE_CONFIG)
|
||||||
|
|
||||||
|
# Тестируем токенизатор
|
||||||
|
test_tokenizer(tokenizer, TEST_PROMPTS[:3])
|
||||||
|
|
||||||
|
# === Инициализация модели ===
|
||||||
|
model_config = BASE_GPT_CONFIG.copy()
|
||||||
|
model_config["vocab_size"] = tokenizer.get_vocab_size()
|
||||||
|
|
||||||
|
print(f"\n🔧 Инициализация Llama модели...")
|
||||||
|
print(f" Размер словаря: {model_config['vocab_size']}")
|
||||||
|
print(f" Размер эмбеддингов: {model_config['embed_dim']}")
|
||||||
|
print(f" Количество слоев: {model_config['num_layers']}")
|
||||||
|
print(f" Количество голов внимания: {model_config['num_heads']}")
|
||||||
|
|
||||||
|
model = Llama(model_config)
|
||||||
|
|
||||||
|
# === Подготовка датасета ===
|
||||||
|
print(f"\n📊 Подготовка датасета...")
|
||||||
|
train_dataset = TextDataset(
|
||||||
|
train_texts,
|
||||||
|
tokenizer,
|
||||||
|
block_size=model_config["max_position_embeddings"]
|
||||||
|
)
|
||||||
|
print(f" Размер train датасета: {len(train_dataset)} примеров")
|
||||||
|
|
||||||
|
# === Обучение модели ===
|
||||||
|
print(f"\n🎯 Начало обучения Llama модели...")
|
||||||
|
|
||||||
|
trainer = Trainer(
|
||||||
|
model=model,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
lr=TRAINING_CONFIG["learning_rate"],
|
||||||
|
batch_size=TRAINING_CONFIG["batch_size"],
|
||||||
|
num_epochs=TRAINING_CONFIG["num_epochs"],
|
||||||
|
warmup_steps=TRAINING_CONFIG["warmup_steps"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Запускаем обучение
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# === Сохранение модели ===
|
||||||
|
print(f"\n💾 Сохранение модели...")
|
||||||
|
os.makedirs(os.path.dirname(PATHS["gpt_bpe_model"]), exist_ok=True)
|
||||||
|
|
||||||
|
# Сохраняем модель
|
||||||
|
torch.save(model.state_dict(), PATHS["gpt_bpe_model"])
|
||||||
|
|
||||||
|
# Сохраняем конфигурацию
|
||||||
|
import json
|
||||||
|
with open(PATHS["gpt_bpe_config"], 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(model_config, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
print(f"✅ Модель сохранена:")
|
||||||
|
print(f" - {PATHS['gpt_bpe_model']}: веса модели")
|
||||||
|
print(f" - {PATHS['gpt_bpe_config']}: конфигурация модели")
|
||||||
|
print(f" - {PATHS['bpe_tokenizer']}: токенизатор")
|
||||||
|
|
||||||
|
# === Тестирование генерации ===
|
||||||
|
print(f"\n🧪 Тестирование генерации текста...")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
for prompt in TEST_PROMPTS[:3]:
|
||||||
|
print(f"\n🔤 Промпт: '{prompt}'")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Кодируем промпт
|
||||||
|
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
|
||||||
|
input_tensor = torch.tensor([input_ids], dtype=torch.long)
|
||||||
|
|
||||||
|
# Генерируем текст
|
||||||
|
with torch.no_grad():
|
||||||
|
generated_ids = model.generate(
|
||||||
|
x=input_tensor,
|
||||||
|
max_new_tokens=20,
|
||||||
|
do_sample=True,
|
||||||
|
temperature=0.8
|
||||||
|
)
|
||||||
|
|
||||||
|
# Декодируем результат
|
||||||
|
generated_text = tokenizer.decode(generated_ids[0].tolist())
|
||||||
|
generated_part = generated_text[len(prompt):]
|
||||||
|
|
||||||
|
print(f"🎯 Сгенерировано: '{generated_part}'")
|
||||||
|
print(f"📄 Полный текст: '{generated_text}'")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Ошибка генерации: {e}")
|
||||||
|
|
||||||
|
# === Сохранение результатов ===
|
||||||
|
results = {
|
||||||
|
"experiment": experiment_name,
|
||||||
|
"model_config": model_config,
|
||||||
|
"training_config": TRAINING_CONFIG,
|
||||||
|
"tokenizer_vocab_size": tokenizer.get_vocab_size(),
|
||||||
|
"final_loss": "см. логи обучения" # В реальном эксперименте можно сохранить final loss
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.save_logs("checkpoints/llm_only_training_logs.json")
|
||||||
|
|
||||||
|
print(f"\n🎉 Эксперимент завершен успешно!")
|
||||||
|
print(f"\n💡 Для использования обученной модели:")
|
||||||
|
print(f" uv run python experiments/llm_only/generate_gpt_bpe.py")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Ошибка в эксперименте: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
3
llm/src/llm/models/llama/__init__.py
Normal file
3
llm/src/llm/models/llama/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .llama import Llama
|
||||||
|
|
||||||
|
__all__ = ["Llama"]
|
||||||
391
llm/src/llm/models/llama/llama.py
Normal file
391
llm/src/llm/models/llama/llama.py
Normal file
@@ -0,0 +1,391 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch import Tensor
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from math import sqrt
|
||||||
|
|
||||||
|
from llm.core.base_model import BaseModel
|
||||||
|
from llm.core.token_embeddings import TokenEmbeddings
|
||||||
|
|
||||||
|
|
||||||
|
class SiLU(nn.Module):
|
||||||
|
def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]
|
||||||
|
return torch.sigmoid(x) * x
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
|
||||||
|
def __init__(self, dim: int, eps: float = 1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self._eps = eps
|
||||||
|
self._w = nn.Parameter(torch.ones(dim))
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]
|
||||||
|
rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5
|
||||||
|
norm_x = x / rms
|
||||||
|
return self._w * norm_x
|
||||||
|
|
||||||
|
class SwiGLU(nn.Module):
|
||||||
|
def __init__(self, emb_size: int, dropout: float = 0.1):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self._gate = nn.Linear(emb_size, 4 * emb_size)
|
||||||
|
self._up = nn.Linear(emb_size, 4 * emb_size)
|
||||||
|
self._down = nn.Linear(4 * emb_size, emb_size)
|
||||||
|
self._activation = SiLU()
|
||||||
|
self._dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size].
|
||||||
|
gate_out = self._gate(x) # [batch, seq, 4*emb]
|
||||||
|
activation_out = self._activation(gate_out) # [batch, seq, 4*emb]
|
||||||
|
up_out = self._up(x) # [batch, seq, 4*emb]
|
||||||
|
out = up_out * activation_out # поэлементное!
|
||||||
|
out = self._down(out) # [batch, seq, emb]
|
||||||
|
return self._dropout(out)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class RoPE(nn.Module):
|
||||||
|
def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):
|
||||||
|
super().__init__()
|
||||||
|
assert head_size % 2 == 0, "head_size должен быть четным"
|
||||||
|
|
||||||
|
# Обратные частоты
|
||||||
|
freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size))
|
||||||
|
|
||||||
|
# Позиции
|
||||||
|
positions = torch.arange(max_seq_len).float()
|
||||||
|
|
||||||
|
# Матрица частот (внешнее произведение)
|
||||||
|
#freq_matrix = torch.outer(positions, freqs)
|
||||||
|
freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0)
|
||||||
|
|
||||||
|
# Матрицы косинусов и синусов
|
||||||
|
self.register_buffer('cos_matrix', torch.cos(freq_matrix))
|
||||||
|
self.register_buffer('sin_matrix', torch.sin(freq_matrix))
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor): # Получает на вход тензор x (тип float) размером [batch_size × seq_len × head_size]
|
||||||
|
seq_len = x.size(1)
|
||||||
|
# Берем нужную часть матриц и приводим к типу x
|
||||||
|
cos = self.cos_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]
|
||||||
|
sin = self.sin_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]
|
||||||
|
|
||||||
|
|
||||||
|
# Разделяем на четные и нечетные
|
||||||
|
x_even = x[:, :, 0::2] # [batch_size, seq_len, head_size//2]
|
||||||
|
x_odd = x[:, :, 1::2] # [batch_size, seq_len, head_size//2]
|
||||||
|
|
||||||
|
# Применяем поворот
|
||||||
|
x_rotated_even = x_even * cos - x_odd * sin
|
||||||
|
x_rotated_odd = x_even * sin + x_odd * cos
|
||||||
|
|
||||||
|
|
||||||
|
# Объединяем обратно
|
||||||
|
x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)
|
||||||
|
x_rotated = x_rotated.flatten(-2) # [batch_size, seq_len, head_size]
|
||||||
|
|
||||||
|
return x_rotated
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class HeadAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE):
|
||||||
|
super().__init__()
|
||||||
|
self._emb_size = emb_size
|
||||||
|
self._head_size = head_size
|
||||||
|
self._max_seq_len = max_seq_len
|
||||||
|
self._rope = rope
|
||||||
|
|
||||||
|
self._k = nn.Linear(emb_size, head_size)
|
||||||
|
self._q = nn.Linear(emb_size, head_size)
|
||||||
|
self._v = nn.Linear(emb_size, head_size)
|
||||||
|
|
||||||
|
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
|
||||||
|
self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: tuple = None) -> tuple:
|
||||||
|
seq_len = x.shape[1]
|
||||||
|
if seq_len > self._max_seq_len:
|
||||||
|
raise ValueError(f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}")
|
||||||
|
|
||||||
|
k = self._k(x) # [B, T, hs]
|
||||||
|
q = self._q(x) # [B, T, hs]
|
||||||
|
v = self._v(x) # [B, T, hs]
|
||||||
|
|
||||||
|
# ✅ Применяем RoPE к Q и K (НЕ к V!)
|
||||||
|
q = self._rope(q) # [B, T, hs]
|
||||||
|
k = self._rope(k) # [B, T, hs]
|
||||||
|
|
||||||
|
if cache is not None:
|
||||||
|
k_cache, v_cache = cache
|
||||||
|
k = torch.cat([k_cache, k], dim=1) # [B, cache_len + T, hs]
|
||||||
|
v = torch.cat([v_cache, v], dim=1) # [B, cache_len + T, hs]
|
||||||
|
|
||||||
|
scores = q @ k.transpose(-2, -1) / sqrt(self._head_size)
|
||||||
|
|
||||||
|
if cache is None:
|
||||||
|
scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))
|
||||||
|
|
||||||
|
weights = F.softmax(scores, dim=-1)
|
||||||
|
x_out = weights @ v # [B, T, hs]
|
||||||
|
|
||||||
|
if use_cache is True:
|
||||||
|
return (x_out, (k, v))
|
||||||
|
else:
|
||||||
|
return (x_out, None)
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
|
||||||
|
class MultiHeadAttention(nn.Module):
|
||||||
|
def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE, dropout: float = 0.1):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self._heads = nn.ModuleList([
|
||||||
|
HeadAttention(
|
||||||
|
emb_size=emb_size,
|
||||||
|
head_size=head_size,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
rope=rope,
|
||||||
|
) for _ in range(num_heads)
|
||||||
|
])
|
||||||
|
self._layer = nn.Linear(head_size * num_heads, emb_size)
|
||||||
|
self._dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None):
|
||||||
|
|
||||||
|
attention_results = []
|
||||||
|
for i, head in enumerate(self._heads):
|
||||||
|
head_cache = cache[i] if cache is not None else None
|
||||||
|
result = head(x, use_cache=use_cache, cache=head_cache)
|
||||||
|
attention_results.append(result)
|
||||||
|
|
||||||
|
outputs, caches = zip(*attention_results)
|
||||||
|
attention_outputs = list(outputs)
|
||||||
|
kv_caches = list(caches)
|
||||||
|
|
||||||
|
concatenated_attention = torch.cat(attention_outputs, dim=-1)
|
||||||
|
|
||||||
|
projected_output = self._layer(concatenated_attention)
|
||||||
|
|
||||||
|
final_output = self._dropout(projected_output)
|
||||||
|
|
||||||
|
if use_cache is True:
|
||||||
|
return (final_output, kv_caches)
|
||||||
|
else:
|
||||||
|
return (final_output, None)
|
||||||
|
|
||||||
|
|
||||||
|
class GELU(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return 0.5 * x * (1 + torch.tanh(
|
||||||
|
self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
num_heads: int,
|
||||||
|
emb_size: int,
|
||||||
|
head_size: int,
|
||||||
|
max_seq_len: int,
|
||||||
|
rope: RoPE,
|
||||||
|
dropout: float = 0.1
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self._heads = MultiHeadAttention(
|
||||||
|
num_heads=num_heads,
|
||||||
|
emb_size=emb_size,
|
||||||
|
head_size=head_size,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
rope=rope,
|
||||||
|
dropout=dropout
|
||||||
|
)
|
||||||
|
self._ff = SwiGLU(emb_size=emb_size, dropout=dropout)
|
||||||
|
self._norm1 = RMSNorm(emb_size)
|
||||||
|
self._norm2 = RMSNorm(emb_size)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
|
||||||
|
norm1_out = self._norm1(x)
|
||||||
|
attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
|
||||||
|
out = attention + x
|
||||||
|
|
||||||
|
norm2_out = self._norm2(out)
|
||||||
|
ffn_out = self._ff(norm2_out)
|
||||||
|
|
||||||
|
if use_cache is True:
|
||||||
|
return (ffn_out + out, kv_caches)
|
||||||
|
else:
|
||||||
|
return (ffn_out + out, None)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Llama(BaseModel):
|
||||||
|
def __init__(self,config):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
# Инициализация слоев
|
||||||
|
self._max_seq_len = config["max_position_embeddings"]
|
||||||
|
self._token_embeddings = TokenEmbeddings(
|
||||||
|
vocab_size=config["vocab_size"],
|
||||||
|
emb_size=config["embed_dim"]
|
||||||
|
)
|
||||||
|
self._position_embeddings = RoPE(
|
||||||
|
head_size=config["embed_dim"] // config["num_heads"],
|
||||||
|
max_seq_len=config["max_position_embeddings"]
|
||||||
|
)
|
||||||
|
|
||||||
|
self._dropout = nn.Dropout(config["dropout"])
|
||||||
|
self._decoders = nn.ModuleList([Decoder(
|
||||||
|
num_heads=config["num_heads"],
|
||||||
|
emb_size=config["embed_dim"],
|
||||||
|
head_size=config["embed_dim"] // config["num_heads"],
|
||||||
|
max_seq_len=config["max_position_embeddings"],
|
||||||
|
rope=self._position_embeddings,
|
||||||
|
dropout=config["dropout"],
|
||||||
|
) for _ in range(config["num_layers"])])
|
||||||
|
self._norm = RMSNorm(config["embed_dim"])
|
||||||
|
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
|
||||||
|
# Проверка длины последовательности (только при отсутствии кэша)
|
||||||
|
if cache is None and x.size(1) > self._max_seq_len:
|
||||||
|
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
|
||||||
|
|
||||||
|
|
||||||
|
# Вычисление start_pos из кэша (если кэш передан)
|
||||||
|
#if cache is not None:
|
||||||
|
# # При кэше обрабатываем только один токен (последний)
|
||||||
|
# seq_len = 1
|
||||||
|
# # Вычисляем start_pos из самого нижнего уровня кэша
|
||||||
|
# if cache and cache[0] and cache[0][0]:
|
||||||
|
# key_cache, _ = cache[0][0] # Первый декодер, первая голова
|
||||||
|
# start_pos = key_cache.size(1) # cache_len
|
||||||
|
# else:
|
||||||
|
# start_pos = 0
|
||||||
|
#else:
|
||||||
|
# # Без кэша работаем как раньше
|
||||||
|
# start_pos = 0
|
||||||
|
# seq_len = x.size(1)
|
||||||
|
|
||||||
|
# Эмбеддинги токенов и позиций
|
||||||
|
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
|
||||||
|
#pos_out = self._position_embeddings(x) # [batch, seq_len, emb_size]
|
||||||
|
|
||||||
|
# Комбинирование
|
||||||
|
out = self._dropout(tok_out) # [batch, seq_len, emb_size]
|
||||||
|
|
||||||
|
# Стек декодеров с передачей кэша
|
||||||
|
new_cache = []
|
||||||
|
for i, decoder in enumerate(self._decoders):
|
||||||
|
decoder_cache = cache[i] if cache is not None else None
|
||||||
|
decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)
|
||||||
|
|
||||||
|
# Извлекаем результат из кортежа
|
||||||
|
if use_cache:
|
||||||
|
out, decoder_new_cache = decoder_result
|
||||||
|
new_cache.append(decoder_new_cache)
|
||||||
|
else:
|
||||||
|
out = decoder_result[0]
|
||||||
|
|
||||||
|
out = self._norm(out)
|
||||||
|
logits = self._linear(out)
|
||||||
|
|
||||||
|
# Возвращаем результат с учетом use_cache
|
||||||
|
if use_cache:
|
||||||
|
return (logits, new_cache)
|
||||||
|
else:
|
||||||
|
return (logits, None)
|
||||||
|
|
||||||
|
def generate(self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
max_new_tokens: int,
|
||||||
|
do_sample: bool,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
top_k: int = None,
|
||||||
|
top_p: float = None,
|
||||||
|
use_cache: bool = True
|
||||||
|
) -> torch.Tensor:
|
||||||
|
cache = None
|
||||||
|
|
||||||
|
for _ in range(max_new_tokens):
|
||||||
|
if use_cache and cache is not None:
|
||||||
|
# Используем кэш - передаем только последний токен
|
||||||
|
x_input = x[:, -1:] # [batch_size, 1]
|
||||||
|
else:
|
||||||
|
# Первая итерация или кэш отключен - передаем всю последовательность
|
||||||
|
x_input = x
|
||||||
|
|
||||||
|
# Прямой проход с кэшем
|
||||||
|
logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)
|
||||||
|
|
||||||
|
# Обновляем кэш для следующей итерации
|
||||||
|
if use_cache:
|
||||||
|
cache = new_cache
|
||||||
|
|
||||||
|
last_logits = logits[:, -1, :] # [batch_size, vocab_size]
|
||||||
|
|
||||||
|
# Масштабируем логиты температурой
|
||||||
|
if temperature > 0:
|
||||||
|
logits_scaled = last_logits / temperature
|
||||||
|
else:
|
||||||
|
logits_scaled = last_logits
|
||||||
|
|
||||||
|
if do_sample == True and top_k != None:
|
||||||
|
_, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)
|
||||||
|
|
||||||
|
# # Заменим все НЕ top-k логиты на -inf
|
||||||
|
masked_logits = logits_scaled.clone()
|
||||||
|
vocab_size = logits_scaled.size(-1)
|
||||||
|
|
||||||
|
# создаём маску: 1, если токен НЕ в topk_indices
|
||||||
|
mask = torch.ones_like(logits_scaled, dtype=torch.uint8)
|
||||||
|
mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы
|
||||||
|
masked_logits[mask.byte()] = float('-inf')
|
||||||
|
|
||||||
|
logits_scaled = masked_logits
|
||||||
|
|
||||||
|
if do_sample == True and top_p != None:
|
||||||
|
# 1. Применим softmax, чтобы получить вероятности:
|
||||||
|
probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]
|
||||||
|
# 2. Отсортируем токены по убыванию вероятностей:
|
||||||
|
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
|
||||||
|
# 3. Посчитаем кумулятивную сумму вероятностей:
|
||||||
|
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
|
||||||
|
# 4. Определим маску: оставить токены, пока сумма < top_p
|
||||||
|
sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size]
|
||||||
|
# Гарантируем, что хотя бы первый токен останется
|
||||||
|
sorted_mask[:, 0] = 1
|
||||||
|
# 5. Преобразуем маску обратно в оригинальный порядок:
|
||||||
|
# Создаём полную маску из 0
|
||||||
|
mask = torch.zeros_like(probs, dtype=torch.uint8)
|
||||||
|
# Устанавливаем 1 в местах нужных токенов
|
||||||
|
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
||||||
|
# 6. Зануляем логиты токенов вне топ-p:
|
||||||
|
logits_scaled[~mask] = float('-inf')
|
||||||
|
|
||||||
|
# 4. Применяем Softmax
|
||||||
|
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
|
||||||
|
|
||||||
|
|
||||||
|
if do_sample == True:
|
||||||
|
# 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial
|
||||||
|
next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]
|
||||||
|
else:
|
||||||
|
# 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью
|
||||||
|
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
|
||||||
|
|
||||||
|
# 6. Добавляем его к последовательности
|
||||||
|
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_seq_len(self) -> int:
|
||||||
|
return self._max_seq_len
|
||||||
4
uv.lock
generated
4
uv.lock
generated
@@ -1759,7 +1759,6 @@ dependencies = [
|
|||||||
dev = [
|
dev = [
|
||||||
{ name = "black" },
|
{ name = "black" },
|
||||||
{ name = "jupyter" },
|
{ name = "jupyter" },
|
||||||
{ name = "matplotlib" },
|
|
||||||
{ name = "mypy" },
|
{ name = "mypy" },
|
||||||
{ name = "pytest" },
|
{ name = "pytest" },
|
||||||
{ name = "ruff" },
|
{ name = "ruff" },
|
||||||
@@ -1777,8 +1776,7 @@ requires-dist = [
|
|||||||
{ name = "ipykernel" },
|
{ name = "ipykernel" },
|
||||||
{ name = "jupyter", marker = "extra == 'dev'", specifier = ">=1.0.0" },
|
{ name = "jupyter", marker = "extra == 'dev'", specifier = ">=1.0.0" },
|
||||||
{ name = "llm", editable = "llm" },
|
{ name = "llm", editable = "llm" },
|
||||||
{ name = "matplotlib" },
|
{ name = "matplotlib", specifier = "==3.10.6" },
|
||||||
{ name = "matplotlib", marker = "extra == 'dev'", specifier = ">=1.0.0" },
|
|
||||||
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8.0" },
|
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8.0" },
|
||||||
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },
|
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },
|
||||||
{ name = "pytest", marker = "extra == 'test'", specifier = ">=8.0.0" },
|
{ name = "pytest", marker = "extra == 'test'", specifier = ">=8.0.0" },
|
||||||
|
|||||||
Reference in New Issue
Block a user