mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
feat(gpt2): add GPT2 architecture with universal FeedForward, CachedDecoder, and refactored components. Core modules now shared; add train and generate scripts for GPT2-BPE.
This commit is contained in:
61
llm/src/llm/core/cached_decoder.py
Normal file
61
llm/src/llm/core/cached_decoder.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# llm/src/llm/core/cached_decoder.py
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from .feed_forward import FeedForward
|
||||
from .multi_head_attention import MultiHeadAttention
|
||||
|
||||
class CachedDecoder(nn.Module):
|
||||
"""
|
||||
Универсальный декодер с поддержкой кэша для autoregressive использования (GPT, LLAMA и пр).
|
||||
- Поддерживает использование past_key_values для быстрого генеративного инференса.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
emb_size: int,
|
||||
head_size: int,
|
||||
max_seq_len: int,
|
||||
dropout: float = 0.1,
|
||||
activation: str = "gelu",
|
||||
):
|
||||
super().__init__()
|
||||
self._heads = MultiHeadAttention(
|
||||
num_heads=num_heads,
|
||||
emb_size=emb_size,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len,
|
||||
dropout=dropout,
|
||||
)
|
||||
self._ff = FeedForward(emb_size=emb_size, dropout=dropout, activation=activation)
|
||||
self._norm1 = nn.LayerNorm(emb_size)
|
||||
self._norm2 = nn.LayerNorm(emb_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: torch.Tensor = None,
|
||||
use_cache: bool = True,
|
||||
cache: list = None,
|
||||
):
|
||||
"""
|
||||
x: [batch, seq_len, emb_size]
|
||||
mask: (optional)
|
||||
use_cache: использовать ли кэширование KV-слоев (инкрементальный генератив, GPT-style)
|
||||
cache: список кэшей для голов (или None)
|
||||
Возвращает: (output, new_cache) если use_cache=True, иначе (output, None)
|
||||
"""
|
||||
norm1_out = self._norm1(x)
|
||||
# Передаём все cache/use_cache дальше в attention
|
||||
attention, kv_caches = self._heads(
|
||||
norm1_out, mask=mask, use_cache=use_cache, cache=cache
|
||||
)
|
||||
out = attention + x
|
||||
norm2_out = self._norm2(out)
|
||||
ffn_out = self._ff(norm2_out)
|
||||
result = ffn_out + out
|
||||
|
||||
if use_cache:
|
||||
return (result, kv_caches)
|
||||
else:
|
||||
return (result, None)
|
||||
@@ -88,7 +88,7 @@ class Decoder(nn.Module):
|
||||
4. Добавляем residual connection и LayerNorm
|
||||
"""
|
||||
# Self-Attention блок
|
||||
attention = self._heads(x, mask)
|
||||
attention, _ = self._heads(x, mask, use_cache=False, cache=None)
|
||||
out = self._norm1(attention + x)
|
||||
|
||||
# FeedForward блок
|
||||
|
||||
@@ -1,5 +1,16 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import math
|
||||
|
||||
class GELU(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return 0.5 * x * (1 + torch.tanh(
|
||||
self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))
|
||||
))
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
"""
|
||||
@@ -37,7 +48,7 @@ class FeedForward(nn.Module):
|
||||
>>> output_double = ff(x_double)
|
||||
>>> print(output_double.dtype) # torch.float64
|
||||
"""
|
||||
def __init__(self, emb_size: int, dropout: float = 0.1):
|
||||
def __init__(self, emb_size: int, dropout: float = 0.1, activation: str = "relu"):
|
||||
"""
|
||||
Инициализация слоя Feed Forward Network.
|
||||
|
||||
@@ -49,7 +60,14 @@ class FeedForward(nn.Module):
|
||||
# Первый линейный слой (расширение размерности)
|
||||
self._layer1 = nn.Linear(emb_size, emb_size * 4)
|
||||
# ReLU активация
|
||||
self._relu = nn.ReLU()
|
||||
if activation == "relu":
|
||||
self._activation = nn.ReLU()
|
||||
elif activation == "gelu":
|
||||
self._activation = nn.GELU()
|
||||
elif activation == "gelu_exact":
|
||||
self._activation = GELU()
|
||||
else:
|
||||
raise ValueError(f"Unknown activation: {activation}")
|
||||
# Второй линейный слой (сжатие обратно)
|
||||
self._layer2 = nn.Linear(emb_size * 4, emb_size)
|
||||
# Dropout
|
||||
@@ -75,6 +93,6 @@ class FeedForward(nn.Module):
|
||||
|
||||
# Пропустим тензор x по очереди через все созданные слои
|
||||
x = self._layer1(x)
|
||||
x = self._relu(x)
|
||||
x = self._activation(x)
|
||||
x = self._layer2(x)
|
||||
return self._dropout(x)
|
||||
@@ -45,7 +45,7 @@ class HeadAttention(nn.Module):
|
||||
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
|
||||
self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: tuple = None) -> tuple:
|
||||
"""
|
||||
Прямой проход через слой внимания.
|
||||
|
||||
@@ -69,16 +69,24 @@ class HeadAttention(nn.Module):
|
||||
if seq_len > self._max_seq_len:
|
||||
raise ValueError(f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}")
|
||||
|
||||
# 1. Линейные преобразования
|
||||
k = self._k(x) # [B, T, hs]
|
||||
q = self._q(x) # [B, T, hs]
|
||||
v = self._v(x) # [B, T, hs]
|
||||
|
||||
if cache is not None:
|
||||
k_cache, v_cache = cache
|
||||
k = torch.cat([k_cache, k], dim=1) # [B, cache_len + T, hs]
|
||||
v = torch.cat([v_cache, v], dim=1) # [B, cache_len + T, hs]
|
||||
|
||||
# 2. Вычисление scores
|
||||
scores = q @ k.transpose(-2, -1) / sqrt(self._head_size)
|
||||
|
||||
# 3. Применение causal маски
|
||||
scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))
|
||||
if cache is None:
|
||||
scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))
|
||||
|
||||
# 4. Softmax и умножение на V
|
||||
weights = F.softmax(scores, dim=-1)
|
||||
return weights @ self._v(x)
|
||||
x_out = weights @ v # [B, T, hs]
|
||||
|
||||
if use_cache is True:
|
||||
return (x_out, (k, v))
|
||||
else:
|
||||
return (x_out, None)
|
||||
@@ -58,7 +58,7 @@ class MultiHeadAttention(nn.Module):
|
||||
self._layer = nn.Linear(head_size * num_heads, emb_size)
|
||||
self._dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None):
|
||||
"""
|
||||
Прямой проход через слой многоголового внимания.
|
||||
|
||||
@@ -90,7 +90,15 @@ class MultiHeadAttention(nn.Module):
|
||||
-> Dropout: [4, 100, 512]
|
||||
"""
|
||||
# 1. Вычисляем attention для каждой головы
|
||||
attention_outputs = [head(x) for head in self._heads]
|
||||
attention_results = []
|
||||
for i, head in enumerate(self._heads):
|
||||
head_cache = cache[i] if cache is not None else None
|
||||
result = head(x, use_cache=use_cache, cache=head_cache)
|
||||
attention_results.append(result)
|
||||
|
||||
outputs, caches = zip(*attention_results)
|
||||
attention_outputs = list(outputs)
|
||||
kv_caches = list(caches)
|
||||
|
||||
# 2. Объединяем результаты всех голов
|
||||
concatenated_attention = torch.cat(attention_outputs, dim=-1)
|
||||
@@ -101,4 +109,7 @@ class MultiHeadAttention(nn.Module):
|
||||
# 4. Применяем dropout для регуляризации
|
||||
final_output = self._dropout(projected_output)
|
||||
|
||||
return final_output
|
||||
if use_cache is True:
|
||||
return (final_output, kv_caches)
|
||||
else:
|
||||
return (final_output, None)
|
||||
|
||||
@@ -40,7 +40,7 @@ class PositionalEmbeddings(nn.Module):
|
||||
embedding_dim=emb_size
|
||||
)
|
||||
|
||||
def forward(self, seq_len: int) -> Tensor:
|
||||
def forward(self, seq_len: int, start_pos: int = 0) -> Tensor:
|
||||
"""
|
||||
Возвращает позиционные эмбеддинги для заданной длины последовательности.
|
||||
|
||||
@@ -59,32 +59,8 @@ class PositionalEmbeddings(nn.Module):
|
||||
"""
|
||||
if seq_len < 1 or seq_len > self.max_seq_len:
|
||||
raise IndexError(f"Длина {seq_len} должна быть от 1 до {self.max_seq_len}")
|
||||
positions = torch.arange(seq_len, device=self.embedding.weight.device)
|
||||
if start_pos == 0:
|
||||
positions = torch.arange(seq_len, device=self.embedding.weight.device)
|
||||
else:
|
||||
positions = torch.arange(start=start_pos, end=start_pos + seq_len, device=self.embedding.weight.device)
|
||||
return self.embedding(positions)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Демонстрация работы
|
||||
print("Пример использования PositionalEmbeddings:")
|
||||
pos_emb = PositionalEmbeddings(max_seq_len=50, emb_size=128)
|
||||
|
||||
# Пример 1: Базовое использование
|
||||
print("\n1. Базовый пример:")
|
||||
emb = pos_emb(10)
|
||||
print(f"Форма выходного тензора: {emb.shape}")
|
||||
print(f"Среднее значение: {emb.mean().item():.4f}")
|
||||
|
||||
# Пример 2: Интеграция с моделью
|
||||
print("\n2. Пример интеграции с моделью:")
|
||||
class DemoModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.pos_emb = PositionalEmbeddings(50, 128)
|
||||
|
||||
def forward(self, x):
|
||||
pos = self.pos_emb(x.size(1))
|
||||
return x + pos # Добавляем позиционную информацию
|
||||
|
||||
model = DemoModel()
|
||||
input_tensor = torch.randn(2, 10, 128) # [batch, seq, features]
|
||||
output = model(input_tensor)
|
||||
print(f"Вход: {input_tensor.shape}, Выход: {output.shape}")
|
||||
@@ -1,3 +1,4 @@
|
||||
from .gpt import GPT
|
||||
from .gpt2 import GPT2
|
||||
|
||||
__all__ = ["GPT"]
|
||||
__all__ = ["GPT", "GPT2"]
|
||||
|
||||
170
llm/src/llm/models/gpt/gpt2.py
Normal file
170
llm/src/llm/models/gpt/gpt2.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
import torch.nn.functional as F
|
||||
from llm.core.base_model import BaseModel
|
||||
from llm.core.token_embeddings import TokenEmbeddings
|
||||
from llm.core.positional_embeddings import PositionalEmbeddings
|
||||
from llm.core.cached_decoder import CachedDecoder
|
||||
|
||||
|
||||
class GPT2(BaseModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
# Инициализация слоев
|
||||
self._max_seq_len = config["max_position_embeddings"]
|
||||
self._token_embeddings = TokenEmbeddings(
|
||||
vocab_size=config["vocab_size"],
|
||||
emb_size=config["embed_dim"]
|
||||
)
|
||||
self._position_embeddings = PositionalEmbeddings(
|
||||
max_seq_len=config["max_position_embeddings"],
|
||||
emb_size=config["embed_dim"]
|
||||
)
|
||||
self._dropout = nn.Dropout(config["dropout"])
|
||||
# head_size = emb_size // num_heads
|
||||
self._decoders = nn.ModuleList([CachedDecoder(
|
||||
num_heads=config["num_heads"],
|
||||
emb_size=config["embed_dim"],
|
||||
head_size=config["embed_dim"] // config["num_heads"],
|
||||
max_seq_len=config["max_position_embeddings"],
|
||||
dropout=config["dropout"]
|
||||
) for _ in range(config["num_layers"])])
|
||||
self._norm = nn.LayerNorm(config["embed_dim"])
|
||||
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
|
||||
|
||||
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
|
||||
# Проверка длины последовательности (только при отсутствии кэша)
|
||||
if cache is None and x.size(1) > self._max_seq_len:
|
||||
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
|
||||
|
||||
|
||||
# Вычисление start_pos из кэша (если кэш передан)
|
||||
if cache is not None:
|
||||
# При кэше обрабатываем только один токен (последний)
|
||||
seq_len = 1
|
||||
# Вычисляем start_pos из самого нижнего уровня кэша
|
||||
if cache and cache[0] and cache[0][0]:
|
||||
key_cache, _ = cache[0][0] # Первый декодер, первая голова
|
||||
start_pos = key_cache.size(1) # cache_len
|
||||
else:
|
||||
start_pos = 0
|
||||
else:
|
||||
# Без кэша работаем как раньше
|
||||
start_pos = 0
|
||||
seq_len = x.size(1)
|
||||
|
||||
# Эмбеддинги токенов и позиций
|
||||
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
|
||||
pos_out = self._position_embeddings(seq_len, start_pos=start_pos) # [seq_len, emb_size]
|
||||
|
||||
# Комбинирование
|
||||
out = self._dropout(tok_out + pos_out.unsqueeze(0)) # [batch, seq_len, emb_size]
|
||||
|
||||
# Стек декодеров с передачей кэша
|
||||
new_cache = []
|
||||
for i, decoder in enumerate(self._decoders):
|
||||
decoder_cache = cache[i] if cache is not None else None
|
||||
decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)
|
||||
|
||||
# Извлекаем результат из кортежа
|
||||
if use_cache:
|
||||
out, decoder_new_cache = decoder_result
|
||||
new_cache.append(decoder_new_cache)
|
||||
else:
|
||||
out = decoder_result[0]
|
||||
|
||||
out = self._norm(out)
|
||||
logits = self._linear(out)
|
||||
|
||||
# Возвращаем результат с учетом use_cache
|
||||
if use_cache:
|
||||
return (logits, new_cache)
|
||||
else:
|
||||
return (logits, None)
|
||||
|
||||
def generate(self,
|
||||
x: torch.Tensor,
|
||||
max_new_tokens: int,
|
||||
do_sample: bool,
|
||||
temperature: float = 1.0,
|
||||
top_k: int = None,
|
||||
top_p: float = None,
|
||||
use_cache: bool = True
|
||||
) -> torch.Tensor:
|
||||
cache = None
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
if use_cache and cache is not None:
|
||||
# Используем кэш - передаем только последний токен
|
||||
x_input = x[:, -1:] # [batch_size, 1]
|
||||
else:
|
||||
# Первая итерация или кэш отключен - передаем всю последовательность
|
||||
x_input = x
|
||||
|
||||
# Прямой проход с кэшем
|
||||
logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)
|
||||
|
||||
# Обновляем кэш для следующей итерации
|
||||
if use_cache:
|
||||
cache = new_cache
|
||||
|
||||
last_logits = logits[:, -1, :] # [batch_size, vocab_size]
|
||||
|
||||
# Масштабируем логиты температурой
|
||||
if temperature > 0:
|
||||
logits_scaled = last_logits / temperature
|
||||
else:
|
||||
logits_scaled = last_logits
|
||||
|
||||
if do_sample == True and top_k != None:
|
||||
_, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)
|
||||
|
||||
# # Заменим все НЕ top-k логиты на -inf
|
||||
masked_logits = logits_scaled.clone()
|
||||
vocab_size = logits_scaled.size(-1)
|
||||
|
||||
# создаём маску: 1, если токен НЕ в topk_indices
|
||||
mask = torch.ones_like(logits_scaled, dtype=torch.uint8)
|
||||
mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы
|
||||
masked_logits[mask.byte()] = float('-inf')
|
||||
|
||||
logits_scaled = masked_logits
|
||||
|
||||
if do_sample == True and top_p != None:
|
||||
# 1. Применим softmax, чтобы получить вероятности:
|
||||
probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]
|
||||
# 2. Отсортируем токены по убыванию вероятностей:
|
||||
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
|
||||
# 3. Посчитаем кумулятивную сумму вероятностей:
|
||||
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
|
||||
# 4. Определим маску: оставить токены, пока сумма < top_p
|
||||
sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size]
|
||||
# Гарантируем, что хотя бы первый токен останется
|
||||
sorted_mask[:, 0] = 1
|
||||
# 5. Преобразуем маску обратно в оригинальный порядок:
|
||||
# Создаём полную маску из 0
|
||||
mask = torch.zeros_like(probs, dtype=torch.uint8)
|
||||
# Устанавливаем 1 в местах нужных токенов
|
||||
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
||||
# 6. Зануляем логиты токенов вне топ-p:
|
||||
logits_scaled[~mask] = float('-inf')
|
||||
|
||||
# 4. Применяем Softmax
|
||||
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
|
||||
|
||||
|
||||
if do_sample == True:
|
||||
# 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial
|
||||
next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]
|
||||
else:
|
||||
# 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью
|
||||
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
|
||||
|
||||
# 6. Добавляем его к последовательности
|
||||
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
|
||||
return x
|
||||
|
||||
@property
|
||||
def max_seq_len(self) -> int:
|
||||
return self._max_seq_len
|
||||
Reference in New Issue
Block a user