mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
Реализация и документирование метода fit() для обучения GPT
Основные изменения: 1. Реализация метода fit(): - Добавлен полный цикл обучения (forward/backward pass) - Поддержка обучения на CPU/GPU - Расчет и сохранение метрик (train_loss, validation_loss) - Интеграция с оптимизатором Adam 2. Документация: - Подробное описание метода в gpt_documentation_ru.md - Примеры использования в README.md - Параметры и требования к данным 3. Тестирование: - Тесты базовой функциональности - Проверка изменения весов - Тесты для разных устройств (CPU/CUDA) - Обработка edge-cases 4. Примеры: - train_gpt_example.py с полным workflow - Генерация синтетических данных - Сохранение/загрузка моделей
This commit is contained in:
78
example/train_gpt_example.py
Normal file
78
example/train_gpt_example.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
Пример обучения модели GPT на синтетических данных
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from simple_llm.transformer.gpt import GPT
|
||||
|
||||
class SyntheticDataset(Dataset):
|
||||
"""Синтетический датасет для демонстрации обучения GPT"""
|
||||
def __init__(self, vocab_size=100, seq_len=20, num_samples=1000):
|
||||
self.data = torch.randint(0, vocab_size, (num_samples, seq_len))
|
||||
self.targets = torch.roll(self.data, -1, dims=1) # Сдвигаем на 1 для предсказания следующего токена
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.data[idx], self.targets[idx]
|
||||
|
||||
def train_gpt():
|
||||
# Параметры модели
|
||||
VOCAB_SIZE = 100
|
||||
SEQ_LEN = 20
|
||||
EMB_SIZE = 128
|
||||
N_HEADS = 4
|
||||
HEAD_SIZE = 32
|
||||
N_LAYERS = 3
|
||||
|
||||
# Инициализация модели
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
model = GPT(
|
||||
vocab_size=VOCAB_SIZE,
|
||||
max_seq_len=SEQ_LEN,
|
||||
emb_size=EMB_SIZE,
|
||||
num_heads=N_HEADS,
|
||||
head_size=HEAD_SIZE,
|
||||
num_layers=N_LAYERS,
|
||||
device=device
|
||||
)
|
||||
|
||||
# Создание датасета и загрузчика
|
||||
dataset = SyntheticDataset(vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN)
|
||||
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
|
||||
|
||||
print(f"Начало обучения на {device}...")
|
||||
print(f"Размер словаря: {VOCAB_SIZE}")
|
||||
print(f"Длина последовательности: {SEQ_LEN}")
|
||||
print(f"Размер батча: 32")
|
||||
print(f"Количество эпох: 5\n")
|
||||
|
||||
# Обучение модели
|
||||
model.fit(
|
||||
train_loader=train_loader,
|
||||
valid_loader=None, # Можно добавить валидационный набор
|
||||
num_epoch=5,
|
||||
learning_rate=0.001
|
||||
)
|
||||
|
||||
# Сохранение модели
|
||||
model_path = "trained_gpt_model.pt"
|
||||
model.save(model_path)
|
||||
print(f"\nМодель сохранена в {model_path}")
|
||||
|
||||
# Пример генерации текста
|
||||
print("\nПример генерации:")
|
||||
input_seq = torch.randint(0, VOCAB_SIZE, (1, 10)).to(device)
|
||||
generated = model.generate(
|
||||
x=input_seq,
|
||||
max_new_tokens=10,
|
||||
do_sample=True,
|
||||
temperature=0.7
|
||||
)
|
||||
print(f"Вход: {input_seq.tolist()[0]}")
|
||||
print(f"Сгенерировано: {generated.tolist()[0]}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_gpt()
|
||||
Reference in New Issue
Block a user