mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
Реализация и документирование метода fit() для обучения GPT
Основные изменения: 1. Реализация метода fit(): - Добавлен полный цикл обучения (forward/backward pass) - Поддержка обучения на CPU/GPU - Расчет и сохранение метрик (train_loss, validation_loss) - Интеграция с оптимизатором Adam 2. Документация: - Подробное описание метода в gpt_documentation_ru.md - Примеры использования в README.md - Параметры и требования к данным 3. Тестирование: - Тесты базовой функциональности - Проверка изменения весов - Тесты для разных устройств (CPU/CUDA) - Обработка edge-cases 4. Примеры: - train_gpt_example.py с полным workflow - Генерация синтетических данных - Сохранение/загрузка моделей
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
from simple_llm.embedding.token_embeddings import TokenEmbeddings
|
||||
from simple_llm.embedding.positional_embeddings import PositionalEmbeddings
|
||||
from simple_llm.transformer.decoder import Decoder
|
||||
@@ -37,6 +38,9 @@ class GPT(nn.Module):
|
||||
self._num_layers = num_layers
|
||||
self._dropout = dropout
|
||||
self._device = device
|
||||
|
||||
self.train_loss = None
|
||||
self.validation_loss = None
|
||||
|
||||
# Инициализация слоев
|
||||
self._token_embeddings = TokenEmbeddings(
|
||||
@@ -206,7 +210,7 @@ class GPT(nn.Module):
|
||||
masked_logits = logits_scaled.clone()
|
||||
vocab_size = logits_scaled.size(-1)
|
||||
|
||||
# создаём маску: 1, если токен НЕ в topk_indices
|
||||
# создаём маску: True, если токен НЕ в topk_indices
|
||||
mask = torch.ones_like(logits_scaled, dtype=torch.bool)
|
||||
mask.scatter_(1, topk_indices, False) # False там, где top-k индексы
|
||||
masked_logits[mask] = float('-inf')
|
||||
@@ -225,9 +229,9 @@ class GPT(nn.Module):
|
||||
# Гарантируем, что хотя бы первый токен останется
|
||||
sorted_mask[:, 0] = True
|
||||
# 5. Преобразуем маску обратно в оригинальный порядок:
|
||||
# Создаём полную маску из 0
|
||||
# Создаём полную маску из False
|
||||
mask = torch.zeros_like(probs, dtype=torch.bool)
|
||||
# Устанавливаем 1 в местах нужных токенов
|
||||
# Устанавливаем True в местах нужных токенов
|
||||
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
||||
# 6. Зануляем логиты токенов вне топ-p:
|
||||
logits_scaled[~mask] = float('-inf')
|
||||
@@ -276,4 +280,108 @@ class GPT(nn.Module):
|
||||
@property
|
||||
def max_seq_len(self) -> int:
|
||||
"""Возвращает максимальную длину последовательности"""
|
||||
return self._max_seq_len
|
||||
return self._max_seq_len
|
||||
|
||||
|
||||
def fit(self,
|
||||
train_loader: DataLoader,
|
||||
valid_loader: DataLoader = None,
|
||||
num_epoch: int = 1,
|
||||
learning_rate: float = 0.001
|
||||
):
|
||||
"""Обучает модель GPT на предоставленных данных.
|
||||
|
||||
Процесс обучения включает:
|
||||
- Прямой проход для получения предсказаний
|
||||
- Вычисление потерь с помощью кросс-энтропии
|
||||
- Обратное распространение ошибки
|
||||
- Обновление весов через оптимизатор Adam
|
||||
|
||||
Args:
|
||||
train_loader (DataLoader): Загрузчик обучающих данных, возвращающий пары (inputs, targets).
|
||||
inputs - тензор индексов токенов формы [batch_size, seq_len]
|
||||
targets - тензор индексов следующих токенов формы [batch_size, seq_len]
|
||||
valid_loader (DataLoader, optional): Загрузчик валидационных данных. Если None,
|
||||
валидация не выполняется. По умолчанию None.
|
||||
num_epoch (int, optional): Количество эпох обучения. По умолчанию 1.
|
||||
learning_rate (float, optional): Скорость обучения для оптимизатора. По умолчанию 0.001.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
ValueError: Если train_loader равен None
|
||||
ValueError: Если num_epoch ≤ 0
|
||||
ValueError: Если learning_rate ≤ 0
|
||||
|
||||
Side Effects:
|
||||
- Обновляет параметры модели (self.parameters())
|
||||
- Устанавливает атрибуты:
|
||||
self.train_loss: Средние потери на обучении за последнюю эпоху
|
||||
self.validation_loss: Средние потери на валидации за последнюю эпоху (если valid_loader передан)
|
||||
|
||||
Examples:
|
||||
>>> from torch.utils.data import DataLoader, TensorDataset
|
||||
>>> # Создаем тестовые данные
|
||||
>>> inputs = torch.randint(0, 100, (100, 10))
|
||||
>>> targets = torch.randint(0, 100, (100, 10))
|
||||
>>> dataset = TensorDataset(inputs, targets)
|
||||
>>> loader = DataLoader(dataset, batch_size=32)
|
||||
>>> # Инициализируем модель
|
||||
>>> model = GPT(vocab_size=100, max_seq_len=20, emb_size=64,
|
||||
... num_heads=4, head_size=16, num_layers=2)
|
||||
>>> # Обучаем модель
|
||||
>>> model.fit(loader, num_epoch=5, learning_rate=0.001)
|
||||
"""
|
||||
if train_loader is None:
|
||||
raise ValueError("train_loader не может быть None")
|
||||
if num_epoch <= 0:
|
||||
raise ValueError("num_epoch должен быть > 0")
|
||||
if learning_rate <= 0:
|
||||
raise ValueError("learning_rate должен быть > 0")
|
||||
|
||||
device = torch.device(self._device)
|
||||
self.to(device)
|
||||
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
|
||||
|
||||
for epoch in range(num_epoch):
|
||||
self.train()
|
||||
epoch_loss = 0.0
|
||||
|
||||
#for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epoch}"):
|
||||
for inputs, targets in train_loader:
|
||||
inputs = inputs.to(device)
|
||||
targets = targets.to(device)
|
||||
|
||||
logits = self(inputs)
|
||||
logits = logits.view(-1, logits.size(-1))
|
||||
targets = targets.view(-1)
|
||||
|
||||
loss = F.cross_entropy(logits, targets)
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
epoch_loss += loss.item()
|
||||
|
||||
self.train_loss = epoch_loss / len(train_loader)
|
||||
#print(f"[{epoch+1}/{num_epoch}] Train Loss: {self.train_loss:.4f}", end='')
|
||||
|
||||
if valid_loader is not None:
|
||||
self.eval()
|
||||
valid_loss = 0.0
|
||||
with torch.no_grad():
|
||||
for inputs, targets in valid_loader:
|
||||
inputs = inputs.to(device)
|
||||
targets = targets.to(device)
|
||||
|
||||
logits = self(inputs)
|
||||
logits = logits.view(-1, logits.size(-1))
|
||||
targets = targets.view(-1)
|
||||
|
||||
loss = F.cross_entropy(logits, targets)
|
||||
valid_loss += loss.item()
|
||||
|
||||
self.validation_loss = valid_loss / len(valid_loader)
|
||||
#print(f" | Val Loss: {self.validation_loss:.4f}")
|
||||
Reference in New Issue
Block a user