Files
simple-llm/tests/test_gpt_training.py
Sergey Penkovsky 8b0dd9c504 Реализация и документирование метода fit() для обучения GPT
Основные изменения:
1. Реализация метода fit():
- Добавлен полный цикл обучения (forward/backward pass)
- Поддержка обучения на CPU/GPU
- Расчет и сохранение метрик (train_loss, validation_loss)
- Интеграция с оптимизатором Adam

2. Документация:
- Подробное описание метода в gpt_documentation_ru.md
- Примеры использования в README.md
- Параметры и требования к данным

3. Тестирование:
- Тесты базовой функциональности
- Проверка изменения весов
- Тесты для разных устройств (CPU/CUDA)
- Обработка edge-cases

4. Примеры:
- train_gpt_example.py с полным workflow
- Генерация синтетических данных
- Сохранение/загрузка моделей
2025-07-23 12:38:39 +03:00

131 lines
4.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pytest
import torch
from torch.utils.data import Dataset, DataLoader
from simple_llm.transformer.gpt import GPT
class DummyDataset(Dataset):
"""Тестовый датасет для проверки обучения"""
def __init__(self, vocab_size=100, seq_len=10, num_samples=100):
self.data = torch.randint(0, vocab_size, (num_samples, seq_len))
self.targets = torch.roll(self.data, -1, dims=1)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx], self.targets[idx]
@pytest.fixture
def gpt_model():
"""Фикстура для создания тестовой модели GPT"""
return GPT(
vocab_size=100,
max_seq_len=10,
emb_size=64,
num_heads=4,
head_size=16,
num_layers=2
)
@pytest.fixture
def data_loaders():
"""Фикстура для создания тестовых DataLoader"""
dataset = DummyDataset()
train_loader = DataLoader(dataset, batch_size=10)
valid_loader = DataLoader(dataset, batch_size=10)
return train_loader, valid_loader
def test_fit_basic(gpt_model, data_loaders):
"""Тестирование базовой работы метода fit"""
train_loader, valid_loader = data_loaders
# Запускаем обучение
gpt_model.fit(
train_loader=train_loader,
valid_loader=valid_loader,
num_epoch=2,
learning_rate=0.001
)
# Проверяем что атрибуты loss были установлены
assert hasattr(gpt_model, 'train_loss'), "train_loss не был сохранен"
assert hasattr(gpt_model, 'validation_loss'), "validation_loss не был сохранен"
assert isinstance(gpt_model.train_loss, float), "train_loss должен быть числом"
assert isinstance(gpt_model.validation_loss, float), "validation_loss должен быть числом"
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_fit_device(gpt_model, data_loaders, device):
"""Тестирование работы fit на разных устройствах"""
if device == "cuda" and not torch.cuda.is_available():
pytest.skip("CUDA не доступен для тестирования")
train_loader, valid_loader = data_loaders
# Переносим модель на нужное устройство
gpt_model._device = device
gpt_model.to(device)
gpt_model.fit(
train_loader=train_loader,
valid_loader=valid_loader,
num_epoch=1,
learning_rate=0.001
)
# Проверяем что модель действительно на нужном устройстве
assert next(gpt_model.parameters()).device.type == device
def test_fit_loss_decrease(gpt_model, data_loaders):
"""Тестирование что веса изменяются в процессе обучения"""
train_loader, valid_loader = data_loaders
# Сохраняем начальные веса
initial_state = {k: v.clone() for k, v in gpt_model.state_dict().items()}
gpt_model.fit(
train_loader=train_loader,
valid_loader=valid_loader,
num_epoch=3,
learning_rate=0.01
)
# Проверяем что веса изменились
changed = False
for k in initial_state:
if not torch.allclose(initial_state[k], gpt_model.state_dict()[k]):
changed = True
break
assert changed, "Веса модели не изменились после обучения"
def test_fit_without_validation(gpt_model, data_loaders):
"""Тестирование работы без валидационного набора"""
train_loader, _ = data_loaders
gpt_model.fit(
train_loader=train_loader,
valid_loader=None,
num_epoch=1,
learning_rate=0.001
)
assert hasattr(gpt_model, 'train_loss')
assert gpt_model.validation_loss is None
def test_fit_with_invalid_train_data(gpt_model):
"""Тестирование обработки невалидных train данных"""
with pytest.raises(ValueError, match="train_loader не может быть None"):
gpt_model.fit(
train_loader=None,
valid_loader=None,
num_epoch=1,
learning_rate=0.001
)