mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
152 lines
6.5 KiB
Python
152 lines
6.5 KiB
Python
|
|
from torch import nn
|
|||
|
|
import torch
|
|||
|
|
import torch.nn.functional as F
|
|||
|
|
from simple_llm.embedding.token_embeddings import TokenEmbeddings
|
|||
|
|
from simple_llm.embedding.positional_embeddings import PositionalEmbeddings
|
|||
|
|
from simple_llm.transformer.decoder import Decoder
|
|||
|
|
|
|||
|
|
class GPT(nn.Module):
|
|||
|
|
"""GPT-like трансформер для генерации текста
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
vocab_size: Размер словаря
|
|||
|
|
max_seq_len: Макс. длина последовательности
|
|||
|
|
emb_size: Размерность эмбеддингов
|
|||
|
|
num_heads: Количество голов внимания
|
|||
|
|
head_size: Размерность голов внимания
|
|||
|
|
num_layers: Количество слоёв декодера
|
|||
|
|
dropout: Вероятность dropout (default=0.1)
|
|||
|
|
device: Устройство (default='cpu')
|
|||
|
|
"""
|
|||
|
|
def __init__(self,
|
|||
|
|
vocab_size: int,
|
|||
|
|
max_seq_len: int,
|
|||
|
|
emb_size: int,
|
|||
|
|
num_heads: int,
|
|||
|
|
head_size: int,
|
|||
|
|
num_layers: int,
|
|||
|
|
dropout: float = 0.1,
|
|||
|
|
device: str = 'cpu'
|
|||
|
|
):
|
|||
|
|
super().__init__()
|
|||
|
|
self._vocab_size = vocab_size
|
|||
|
|
self._max_seq_len = max_seq_len
|
|||
|
|
self._emb_size = emb_size
|
|||
|
|
self._num_heads = num_heads
|
|||
|
|
self._head_size = head_size
|
|||
|
|
self._num_layers = num_layers
|
|||
|
|
self._dropout = dropout
|
|||
|
|
self._device = device
|
|||
|
|
|
|||
|
|
# Инициализация слоев
|
|||
|
|
self._token_embeddings = TokenEmbeddings(
|
|||
|
|
vocab_size=vocab_size,
|
|||
|
|
emb_size=emb_size
|
|||
|
|
)
|
|||
|
|
self._position_embeddings = PositionalEmbeddings(
|
|||
|
|
max_seq_len=max_seq_len,
|
|||
|
|
emb_size=emb_size
|
|||
|
|
)
|
|||
|
|
self._dropout = nn.Dropout(dropout)
|
|||
|
|
self._decoders = nn.ModuleList([Decoder(
|
|||
|
|
num_heads=num_heads,
|
|||
|
|
emb_size=emb_size,
|
|||
|
|
head_size=head_size,
|
|||
|
|
max_seq_len=max_seq_len,
|
|||
|
|
dropout=dropout
|
|||
|
|
) for _ in range(num_layers)])
|
|||
|
|
self._linear = nn.Linear(emb_size, vocab_size)
|
|||
|
|
|
|||
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|||
|
|
"""Прямой проход через GPT
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
x: Входной тензор [batch_size, seq_len]
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
Тензор логитов [batch_size, seq_len, vocab_size]
|
|||
|
|
"""
|
|||
|
|
# Проверка длины последовательности
|
|||
|
|
if x.size(1) > self._max_seq_len:
|
|||
|
|
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
|
|||
|
|
|
|||
|
|
# Эмбеддинги токенов и позиций
|
|||
|
|
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
|
|||
|
|
pos_out = self._position_embeddings(x.size(1)) # [seq_len, emb_size]
|
|||
|
|
|
|||
|
|
# Комбинирование
|
|||
|
|
out = self._dropout(tok_out + pos_out.unsqueeze(0)) # [batch, seq_len, emb_size]
|
|||
|
|
|
|||
|
|
# Стек декодеров
|
|||
|
|
for decoder in self._decoders:
|
|||
|
|
out = decoder(out)
|
|||
|
|
|
|||
|
|
return self._linear(out) # [batch, seq_len, vocab_size]
|
|||
|
|
|
|||
|
|
def generate(self, x: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
|
|||
|
|
"""Авторегрессивная генерация текста
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
x: Входной тензор с индексами токенов [batch_size, seq_len]
|
|||
|
|
max_new_tokens: Максимальное количество новых токенов для генерации
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
Тензор с расширенной последовательностью токенов [batch_size, seq_len + max_new_tokens]
|
|||
|
|
|
|||
|
|
Алгоритм работы:
|
|||
|
|
1. На каждом шаге берется последний фрагмент последовательности (не длиннее max_seq_len)
|
|||
|
|
2. Вычисляются логиты для следующего токена
|
|||
|
|
3. Выбирается токен с максимальной вероятностью (жадный алгоритм)
|
|||
|
|
4. Токен добавляется к последовательности
|
|||
|
|
5. Процесс повторяется пока не сгенерируется max_new_tokens токенов
|
|||
|
|
"""
|
|||
|
|
for _ in range(max_new_tokens):
|
|||
|
|
# 1. Обрезаем вход, если последовательность слишком длинная
|
|||
|
|
x_cond = x[:, -self.max_seq_len:]
|
|||
|
|
|
|||
|
|
# 2. Передаем последовательность в метод forward класса GPT и полуаем логиты.
|
|||
|
|
logits = self.forward(x_cond)
|
|||
|
|
|
|||
|
|
# 3. Берем логиты для последнего токена
|
|||
|
|
last_logits = logits[:, -1, :] # [batch_size, vocab_size]
|
|||
|
|
|
|||
|
|
# 4. Применяем Softmax
|
|||
|
|
probs = F.softmax(last_logits, dim=-1) # [batch_size, vocab_size]
|
|||
|
|
|
|||
|
|
# 5. Выбираем токен с максимальной вероятностью
|
|||
|
|
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
|
|||
|
|
|
|||
|
|
# 6. Добавляем его к последовательности
|
|||
|
|
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
|
|||
|
|
return x
|
|||
|
|
|
|||
|
|
def save(self, path):
|
|||
|
|
torch.save({
|
|||
|
|
'model_state_dict': self.state_dict(),
|
|||
|
|
'vocab_size': self._vocab_size,
|
|||
|
|
'max_seq_len': self._max_seq_len,
|
|||
|
|
'emb_size': self._emb_size,
|
|||
|
|
'num_heads': self._num_heads,
|
|||
|
|
'head_size': self._head_size,
|
|||
|
|
'num_layers': self._num_layers
|
|||
|
|
}, path)
|
|||
|
|
|
|||
|
|
@classmethod
|
|||
|
|
def load(cls, path, device):
|
|||
|
|
checkpoint = torch.load(path, map_location=device)
|
|||
|
|
model = cls(
|
|||
|
|
vocab_size=checkpoint['vocab_size'],
|
|||
|
|
max_seq_len=checkpoint['max_seq_len'],
|
|||
|
|
emb_size=checkpoint['emb_size'],
|
|||
|
|
num_heads=checkpoint['num_heads'],
|
|||
|
|
head_size=checkpoint['head_size'],
|
|||
|
|
num_layers=checkpoint['num_layers']
|
|||
|
|
)
|
|||
|
|
model.load_state_dict(checkpoint['model_state_dict'])
|
|||
|
|
model.to(device)
|
|||
|
|
return model
|
|||
|
|
|
|||
|
|
@property
|
|||
|
|
def max_seq_len(self) -> int:
|
|||
|
|
"""Возвращает максимальную длину последовательности"""
|
|||
|
|
return self._max_seq_len
|