Files
simple-llm/example/train_gpt_example.py
Sergey Penkovsky 8b0dd9c504 Реализация и документирование метода fit() для обучения GPT
Основные изменения:
1. Реализация метода fit():
- Добавлен полный цикл обучения (forward/backward pass)
- Поддержка обучения на CPU/GPU
- Расчет и сохранение метрик (train_loss, validation_loss)
- Интеграция с оптимизатором Adam

2. Документация:
- Подробное описание метода в gpt_documentation_ru.md
- Примеры использования в README.md
- Параметры и требования к данным

3. Тестирование:
- Тесты базовой функциональности
- Проверка изменения весов
- Тесты для разных устройств (CPU/CUDA)
- Обработка edge-cases

4. Примеры:
- train_gpt_example.py с полным workflow
- Генерация синтетических данных
- Сохранение/загрузка моделей
2025-07-23 12:38:39 +03:00

79 lines
2.6 KiB
Python

"""
Пример обучения модели GPT на синтетических данных
"""
import torch
from torch.utils.data import Dataset, DataLoader
from simple_llm.transformer.gpt import GPT
class SyntheticDataset(Dataset):
"""Синтетический датасет для демонстрации обучения GPT"""
def __init__(self, vocab_size=100, seq_len=20, num_samples=1000):
self.data = torch.randint(0, vocab_size, (num_samples, seq_len))
self.targets = torch.roll(self.data, -1, dims=1) # Сдвигаем на 1 для предсказания следующего токена
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx], self.targets[idx]
def train_gpt():
# Параметры модели
VOCAB_SIZE = 100
SEQ_LEN = 20
EMB_SIZE = 128
N_HEADS = 4
HEAD_SIZE = 32
N_LAYERS = 3
# Инициализация модели
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = GPT(
vocab_size=VOCAB_SIZE,
max_seq_len=SEQ_LEN,
emb_size=EMB_SIZE,
num_heads=N_HEADS,
head_size=HEAD_SIZE,
num_layers=N_LAYERS,
device=device
)
# Создание датасета и загрузчика
dataset = SyntheticDataset(vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
print(f"Начало обучения на {device}...")
print(f"Размер словаря: {VOCAB_SIZE}")
print(f"Длина последовательности: {SEQ_LEN}")
print(f"Размер батча: 32")
print(f"Количество эпох: 5\n")
# Обучение модели
model.fit(
train_loader=train_loader,
valid_loader=None, # Можно добавить валидационный набор
num_epoch=5,
learning_rate=0.001
)
# Сохранение модели
model_path = "trained_gpt_model.pt"
model.save(model_path)
print(f"\nМодель сохранена в {model_path}")
# Пример генерации текста
print("\nПример генерации:")
input_seq = torch.randint(0, VOCAB_SIZE, (1, 10)).to(device)
generated = model.generate(
x=input_seq,
max_new_tokens=10,
do_sample=True,
temperature=0.7
)
print(f"Вход: {input_seq.tolist()[0]}")
print(f"Сгенерировано: {generated.tolist()[0]}")
if __name__ == "__main__":
train_gpt()