Реализация и документирование метода fit() для обучения GPT

Основные изменения:
1. Реализация метода fit():
- Добавлен полный цикл обучения (forward/backward pass)
- Поддержка обучения на CPU/GPU
- Расчет и сохранение метрик (train_loss, validation_loss)
- Интеграция с оптимизатором Adam

2. Документация:
- Подробное описание метода в gpt_documentation_ru.md
- Примеры использования в README.md
- Параметры и требования к данным

3. Тестирование:
- Тесты базовой функциональности
- Проверка изменения весов
- Тесты для разных устройств (CPU/CUDA)
- Обработка edge-cases

4. Примеры:
- train_gpt_example.py с полным workflow
- Генерация синтетических данных
- Сохранение/загрузка моделей
This commit is contained in:
Sergey Penkovsky
2025-07-23 12:26:49 +03:00
parent c56a3e80c9
commit 8b0dd9c504
6 changed files with 391 additions and 6 deletions

View File

@@ -1,6 +1,7 @@
from torch import nn
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from simple_llm.embedding.token_embeddings import TokenEmbeddings
from simple_llm.embedding.positional_embeddings import PositionalEmbeddings
from simple_llm.transformer.decoder import Decoder
@@ -37,6 +38,9 @@ class GPT(nn.Module):
self._num_layers = num_layers
self._dropout = dropout
self._device = device
self.train_loss = None
self.validation_loss = None
# Инициализация слоев
self._token_embeddings = TokenEmbeddings(
@@ -206,7 +210,7 @@ class GPT(nn.Module):
masked_logits = logits_scaled.clone()
vocab_size = logits_scaled.size(-1)
# создаём маску: 1, если токен НЕ в topk_indices
# создаём маску: True, если токен НЕ в topk_indices
mask = torch.ones_like(logits_scaled, dtype=torch.bool)
mask.scatter_(1, topk_indices, False) # False там, где top-k индексы
masked_logits[mask] = float('-inf')
@@ -225,9 +229,9 @@ class GPT(nn.Module):
# Гарантируем, что хотя бы первый токен останется
sorted_mask[:, 0] = True
# 5. Преобразуем маску обратно в оригинальный порядок:
# Создаём полную маску из 0
# Создаём полную маску из False
mask = torch.zeros_like(probs, dtype=torch.bool)
# Устанавливаем 1 в местах нужных токенов
# Устанавливаем True в местах нужных токенов
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
# 6. Зануляем логиты токенов вне топ-p:
logits_scaled[~mask] = float('-inf')
@@ -276,4 +280,108 @@ class GPT(nn.Module):
@property
def max_seq_len(self) -> int:
"""Возвращает максимальную длину последовательности"""
return self._max_seq_len
return self._max_seq_len
def fit(self,
train_loader: DataLoader,
valid_loader: DataLoader = None,
num_epoch: int = 1,
learning_rate: float = 0.001
):
"""Обучает модель GPT на предоставленных данных.
Процесс обучения включает:
- Прямой проход для получения предсказаний
- Вычисление потерь с помощью кросс-энтропии
- Обратное распространение ошибки
- Обновление весов через оптимизатор Adam
Args:
train_loader (DataLoader): Загрузчик обучающих данных, возвращающий пары (inputs, targets).
inputs - тензор индексов токенов формы [batch_size, seq_len]
targets - тензор индексов следующих токенов формы [batch_size, seq_len]
valid_loader (DataLoader, optional): Загрузчик валидационных данных. Если None,
валидация не выполняется. По умолчанию None.
num_epoch (int, optional): Количество эпох обучения. По умолчанию 1.
learning_rate (float, optional): Скорость обучения для оптимизатора. По умолчанию 0.001.
Returns:
None
Raises:
ValueError: Если train_loader равен None
ValueError: Если num_epoch ≤ 0
ValueError: Если learning_rate ≤ 0
Side Effects:
- Обновляет параметры модели (self.parameters())
- Устанавливает атрибуты:
self.train_loss: Средние потери на обучении за последнюю эпоху
self.validation_loss: Средние потери на валидации за последнюю эпоху (если valid_loader передан)
Examples:
>>> from torch.utils.data import DataLoader, TensorDataset
>>> # Создаем тестовые данные
>>> inputs = torch.randint(0, 100, (100, 10))
>>> targets = torch.randint(0, 100, (100, 10))
>>> dataset = TensorDataset(inputs, targets)
>>> loader = DataLoader(dataset, batch_size=32)
>>> # Инициализируем модель
>>> model = GPT(vocab_size=100, max_seq_len=20, emb_size=64,
... num_heads=4, head_size=16, num_layers=2)
>>> # Обучаем модель
>>> model.fit(loader, num_epoch=5, learning_rate=0.001)
"""
if train_loader is None:
raise ValueError("train_loader не может быть None")
if num_epoch <= 0:
raise ValueError("num_epoch должен быть > 0")
if learning_rate <= 0:
raise ValueError("learning_rate должен быть > 0")
device = torch.device(self._device)
self.to(device)
optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
for epoch in range(num_epoch):
self.train()
epoch_loss = 0.0
#for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epoch}"):
for inputs, targets in train_loader:
inputs = inputs.to(device)
targets = targets.to(device)
logits = self(inputs)
logits = logits.view(-1, logits.size(-1))
targets = targets.view(-1)
loss = F.cross_entropy(logits, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch_loss += loss.item()
self.train_loss = epoch_loss / len(train_loader)
#print(f"[{epoch+1}/{num_epoch}] Train Loss: {self.train_loss:.4f}", end='')
if valid_loader is not None:
self.eval()
valid_loss = 0.0
with torch.no_grad():
for inputs, targets in valid_loader:
inputs = inputs.to(device)
targets = targets.to(device)
logits = self(inputs)
logits = logits.view(-1, logits.size(-1))
targets = targets.view(-1)
loss = F.cross_entropy(logits, targets)
valid_loss += loss.item()
self.validation_loss = valid_loss / len(valid_loader)
#print(f" | Val Loss: {self.validation_loss:.4f}")