feat(gpt2): add GPT2 architecture with universal FeedForward, CachedDecoder, and refactored components. Core modules now shared; add train and generate scripts for GPT2-BPE.

This commit is contained in:
Sergey Penkovsky
2025-10-05 19:11:20 +03:00
parent f866ed7ac7
commit c39e68d71a
10 changed files with 833 additions and 44 deletions

View File

@@ -0,0 +1,61 @@
# llm/src/llm/core/cached_decoder.py
import torch
from torch import nn
from .feed_forward import FeedForward
from .multi_head_attention import MultiHeadAttention
class CachedDecoder(nn.Module):
"""
Универсальный декодер с поддержкой кэша для autoregressive использования (GPT, LLAMA и пр).
- Поддерживает использование past_key_values для быстрого генеративного инференса.
"""
def __init__(
self,
num_heads: int,
emb_size: int,
head_size: int,
max_seq_len: int,
dropout: float = 0.1,
activation: str = "gelu",
):
super().__init__()
self._heads = MultiHeadAttention(
num_heads=num_heads,
emb_size=emb_size,
head_size=head_size,
max_seq_len=max_seq_len,
dropout=dropout,
)
self._ff = FeedForward(emb_size=emb_size, dropout=dropout, activation=activation)
self._norm1 = nn.LayerNorm(emb_size)
self._norm2 = nn.LayerNorm(emb_size)
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor = None,
use_cache: bool = True,
cache: list = None,
):
"""
x: [batch, seq_len, emb_size]
mask: (optional)
use_cache: использовать ли кэширование KV-слоев (инкрементальный генератив, GPT-style)
cache: список кэшей для голов (или None)
Возвращает: (output, new_cache) если use_cache=True, иначе (output, None)
"""
norm1_out = self._norm1(x)
# Передаём все cache/use_cache дальше в attention
attention, kv_caches = self._heads(
norm1_out, mask=mask, use_cache=use_cache, cache=cache
)
out = attention + x
norm2_out = self._norm2(out)
ffn_out = self._ff(norm2_out)
result = ffn_out + out
if use_cache:
return (result, kv_caches)
else:
return (result, None)

View File

@@ -88,7 +88,7 @@ class Decoder(nn.Module):
4. Добавляем residual connection и LayerNorm
"""
# Self-Attention блок
attention = self._heads(x, mask)
attention, _ = self._heads(x, mask, use_cache=False, cache=None)
out = self._norm1(attention + x)
# FeedForward блок

View File

@@ -1,5 +1,16 @@
from torch import nn
import torch
import math
class GELU(nn.Module):
def __init__(self):
super().__init__()
self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return 0.5 * x * (1 + torch.tanh(
self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))
))
class FeedForward(nn.Module):
"""
@@ -37,7 +48,7 @@ class FeedForward(nn.Module):
>>> output_double = ff(x_double)
>>> print(output_double.dtype) # torch.float64
"""
def __init__(self, emb_size: int, dropout: float = 0.1):
def __init__(self, emb_size: int, dropout: float = 0.1, activation: str = "relu"):
"""
Инициализация слоя Feed Forward Network.
@@ -49,7 +60,14 @@ class FeedForward(nn.Module):
# Первый линейный слой (расширение размерности)
self._layer1 = nn.Linear(emb_size, emb_size * 4)
# ReLU активация
self._relu = nn.ReLU()
if activation == "relu":
self._activation = nn.ReLU()
elif activation == "gelu":
self._activation = nn.GELU()
elif activation == "gelu_exact":
self._activation = GELU()
else:
raise ValueError(f"Unknown activation: {activation}")
# Второй линейный слой (сжатие обратно)
self._layer2 = nn.Linear(emb_size * 4, emb_size)
# Dropout
@@ -75,6 +93,6 @@ class FeedForward(nn.Module):
# Пропустим тензор x по очереди через все созданные слои
x = self._layer1(x)
x = self._relu(x)
x = self._activation(x)
x = self._layer2(x)
return self._dropout(x)

View File

@@ -45,7 +45,7 @@ class HeadAttention(nn.Module):
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: tuple = None) -> tuple:
"""
Прямой проход через слой внимания.
@@ -69,16 +69,24 @@ class HeadAttention(nn.Module):
if seq_len > self._max_seq_len:
raise ValueError(f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}")
# 1. Линейные преобразования
k = self._k(x) # [B, T, hs]
q = self._q(x) # [B, T, hs]
v = self._v(x) # [B, T, hs]
if cache is not None:
k_cache, v_cache = cache
k = torch.cat([k_cache, k], dim=1) # [B, cache_len + T, hs]
v = torch.cat([v_cache, v], dim=1) # [B, cache_len + T, hs]
# 2. Вычисление scores
scores = q @ k.transpose(-2, -1) / sqrt(self._head_size)
# 3. Применение causal маски
scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))
if cache is None:
scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))
# 4. Softmax и умножение на V
weights = F.softmax(scores, dim=-1)
return weights @ self._v(x)
x_out = weights @ v # [B, T, hs]
if use_cache is True:
return (x_out, (k, v))
else:
return (x_out, None)

View File

@@ -58,7 +58,7 @@ class MultiHeadAttention(nn.Module):
self._layer = nn.Linear(head_size * num_heads, emb_size)
self._dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None):
"""
Прямой проход через слой многоголового внимания.
@@ -90,7 +90,15 @@ class MultiHeadAttention(nn.Module):
-> Dropout: [4, 100, 512]
"""
# 1. Вычисляем attention для каждой головы
attention_outputs = [head(x) for head in self._heads]
attention_results = []
for i, head in enumerate(self._heads):
head_cache = cache[i] if cache is not None else None
result = head(x, use_cache=use_cache, cache=head_cache)
attention_results.append(result)
outputs, caches = zip(*attention_results)
attention_outputs = list(outputs)
kv_caches = list(caches)
# 2. Объединяем результаты всех голов
concatenated_attention = torch.cat(attention_outputs, dim=-1)
@@ -101,4 +109,7 @@ class MultiHeadAttention(nn.Module):
# 4. Применяем dropout для регуляризации
final_output = self._dropout(projected_output)
return final_output
if use_cache is True:
return (final_output, kv_caches)
else:
return (final_output, None)

View File

@@ -40,7 +40,7 @@ class PositionalEmbeddings(nn.Module):
embedding_dim=emb_size
)
def forward(self, seq_len: int) -> Tensor:
def forward(self, seq_len: int, start_pos: int = 0) -> Tensor:
"""
Возвращает позиционные эмбеддинги для заданной длины последовательности.
@@ -59,32 +59,8 @@ class PositionalEmbeddings(nn.Module):
"""
if seq_len < 1 or seq_len > self.max_seq_len:
raise IndexError(f"Длина {seq_len} должна быть от 1 до {self.max_seq_len}")
positions = torch.arange(seq_len, device=self.embedding.weight.device)
if start_pos == 0:
positions = torch.arange(seq_len, device=self.embedding.weight.device)
else:
positions = torch.arange(start=start_pos, end=start_pos + seq_len, device=self.embedding.weight.device)
return self.embedding(positions)
if __name__ == "__main__":
# Демонстрация работы
print("Пример использования PositionalEmbeddings:")
pos_emb = PositionalEmbeddings(max_seq_len=50, emb_size=128)
# Пример 1: Базовое использование
print("\n1. Базовый пример:")
emb = pos_emb(10)
print(f"Форма выходного тензора: {emb.shape}")
print(f"Среднее значение: {emb.mean().item():.4f}")
# Пример 2: Интеграция с моделью
print("\n2. Пример интеграции с моделью:")
class DemoModel(nn.Module):
def __init__(self):
super().__init__()
self.pos_emb = PositionalEmbeddings(50, 128)
def forward(self, x):
pos = self.pos_emb(x.size(1))
return x + pos # Добавляем позиционную информацию
model = DemoModel()
input_tensor = torch.randn(2, 10, 128) # [batch, seq, features]
output = model(input_tensor)
print(f"Вход: {input_tensor.shape}, Выход: {output.shape}")

View File

@@ -1,3 +1,4 @@
from .gpt import GPT
from .gpt2 import GPT2
__all__ = ["GPT"]
__all__ = ["GPT", "GPT2"]

View File

@@ -0,0 +1,170 @@
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from llm.core.base_model import BaseModel
from llm.core.token_embeddings import TokenEmbeddings
from llm.core.positional_embeddings import PositionalEmbeddings
from llm.core.cached_decoder import CachedDecoder
class GPT2(BaseModel):
def __init__(self, config):
super().__init__(config)
# Инициализация слоев
self._max_seq_len = config["max_position_embeddings"]
self._token_embeddings = TokenEmbeddings(
vocab_size=config["vocab_size"],
emb_size=config["embed_dim"]
)
self._position_embeddings = PositionalEmbeddings(
max_seq_len=config["max_position_embeddings"],
emb_size=config["embed_dim"]
)
self._dropout = nn.Dropout(config["dropout"])
# head_size = emb_size // num_heads
self._decoders = nn.ModuleList([CachedDecoder(
num_heads=config["num_heads"],
emb_size=config["embed_dim"],
head_size=config["embed_dim"] // config["num_heads"],
max_seq_len=config["max_position_embeddings"],
dropout=config["dropout"]
) for _ in range(config["num_layers"])])
self._norm = nn.LayerNorm(config["embed_dim"])
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
# Проверка длины последовательности (только при отсутствии кэша)
if cache is None and x.size(1) > self._max_seq_len:
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
# Вычисление start_pos из кэша (если кэш передан)
if cache is not None:
# При кэше обрабатываем только один токен (последний)
seq_len = 1
# Вычисляем start_pos из самого нижнего уровня кэша
if cache and cache[0] and cache[0][0]:
key_cache, _ = cache[0][0] # Первый декодер, первая голова
start_pos = key_cache.size(1) # cache_len
else:
start_pos = 0
else:
# Без кэша работаем как раньше
start_pos = 0
seq_len = x.size(1)
# Эмбеддинги токенов и позиций
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
pos_out = self._position_embeddings(seq_len, start_pos=start_pos) # [seq_len, emb_size]
# Комбинирование
out = self._dropout(tok_out + pos_out.unsqueeze(0)) # [batch, seq_len, emb_size]
# Стек декодеров с передачей кэша
new_cache = []
for i, decoder in enumerate(self._decoders):
decoder_cache = cache[i] if cache is not None else None
decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)
# Извлекаем результат из кортежа
if use_cache:
out, decoder_new_cache = decoder_result
new_cache.append(decoder_new_cache)
else:
out = decoder_result[0]
out = self._norm(out)
logits = self._linear(out)
# Возвращаем результат с учетом use_cache
if use_cache:
return (logits, new_cache)
else:
return (logits, None)
def generate(self,
x: torch.Tensor,
max_new_tokens: int,
do_sample: bool,
temperature: float = 1.0,
top_k: int = None,
top_p: float = None,
use_cache: bool = True
) -> torch.Tensor:
cache = None
for _ in range(max_new_tokens):
if use_cache and cache is not None:
# Используем кэш - передаем только последний токен
x_input = x[:, -1:] # [batch_size, 1]
else:
# Первая итерация или кэш отключен - передаем всю последовательность
x_input = x
# Прямой проход с кэшем
logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)
# Обновляем кэш для следующей итерации
if use_cache:
cache = new_cache
last_logits = logits[:, -1, :] # [batch_size, vocab_size]
# Масштабируем логиты температурой
if temperature > 0:
logits_scaled = last_logits / temperature
else:
logits_scaled = last_logits
if do_sample == True and top_k != None:
_, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)
# # Заменим все НЕ top-k логиты на -inf
masked_logits = logits_scaled.clone()
vocab_size = logits_scaled.size(-1)
# создаём маску: 1, если токен НЕ в topk_indices
mask = torch.ones_like(logits_scaled, dtype=torch.uint8)
mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы
masked_logits[mask.byte()] = float('-inf')
logits_scaled = masked_logits
if do_sample == True and top_p != None:
# 1. Применим softmax, чтобы получить вероятности:
probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]
# 2. Отсортируем токены по убыванию вероятностей:
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
# 3. Посчитаем кумулятивную сумму вероятностей:
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
# 4. Определим маску: оставить токены, пока сумма < top_p
sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size]
# Гарантируем, что хотя бы первый токен останется
sorted_mask[:, 0] = 1
# 5. Преобразуем маску обратно в оригинальный порядок:
# Создаём полную маску из 0
mask = torch.zeros_like(probs, dtype=torch.uint8)
# Устанавливаем 1 в местах нужных токенов
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
# 6. Зануляем логиты токенов вне топ-p:
logits_scaled[~mask] = float('-inf')
# 4. Применяем Softmax
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
if do_sample == True:
# 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial
next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]
else:
# 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
# 6. Добавляем его к последовательности
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
return x
@property
def max_seq_len(self) -> int:
return self._max_seq_len