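"""Tests for checkpoint saving and training resumption in simple_llm.

Covers ModelCheckpointCallback and ResumeTrainingCallback: checkpoint
contents, resuming from existing, missing, and corrupted checkpoints, and
restoration of optimizer and learning-rate state.
"""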
import os
import tempfile
import torch
import pytest
from simple_llm.transformer.gpt import GPT
from simple_llm.transformer.callback import (
    LRSchedulerCallback,
    ModelCheckpointCallback,
    ResumeTrainingCallback,
)
from torch.utils.data import DataLoader, TensorDataset

@pytest.fixture
def sample_data():
    # Create synthetic training data
    inputs = torch.randint(0, 100, (100, 10))  # 100 samples, seq_len=10
    targets = torch.randint(0, 100, (100, 10))
    return DataLoader(TensorDataset(inputs, targets), batch_size=10)

@pytest.fixture
def sample_model():
    return GPT(
        vocab_size=100,
        max_seq_len=10,
        emb_size=32,
        num_heads=4,
        head_size=8,
        num_layers=2
    )

def test_model_checkpoint_saving(sample_model, sample_data):
    """Tests that checkpoints are saved correctly."""
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_cb = ModelCheckpointCallback(tmpdir, save_best_only=False)
        sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])

        # Exactly one checkpoint should exist after one epoch
        files = os.listdir(tmpdir)
        assert len(files) == 1
        assert files[0].startswith('checkpoint_epoch_')

        # The checkpoint must contain everything needed to resume training
        checkpoint = torch.load(os.path.join(tmpdir, files[0]))
        assert 'model_state_dict' in checkpoint
        assert 'optimizer_state_dict' in checkpoint
        assert 'epoch' in checkpoint
        assert 'train_loss' in checkpoint

def test_resume_training(sample_model, sample_data):
    """Tests resuming training from a checkpoint."""
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_cb = ModelCheckpointCallback(tmpdir, save_best_only=False)
        sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])

        # Check that a checkpoint was created
        files = os.listdir(tmpdir)
        assert any(f.startswith("checkpoint_epoch_") for f in files)

        new_model = GPT(
            vocab_size=100,
            max_seq_len=10,
            emb_size=32,
            num_heads=4,
            head_size=8,
            num_layers=2
        )

        resume_cb = ResumeTrainingCallback(tmpdir)
        resume_cb.on_train_begin(new_model)
        # After the checkpoint has been created and on_train_begin has run,
        # last_epoch should be 0
        assert resume_cb.last_epoch == 0

        # Continue training with fit/resume
        new_model.fit(
            sample_data,
            num_epoch=2,
            callbacks=[resume_cb, checkpoint_cb],
            resume_training=True
        )
        files = [f for f in os.listdir(tmpdir) if f.startswith('checkpoint_epoch_')]
        assert len(files) == 3

def test_resume_with_missing_checkpoint(sample_model, sample_data):
    """Tests behaviour when no checkpoints are present."""
    with tempfile.TemporaryDirectory() as tmpdir:
        assert len(os.listdir(tmpdir)) == 0

        resume_cb = ResumeTrainingCallback(tmpdir)
        sample_model.fit(
            sample_data,
            num_epoch=1,
            callbacks=[resume_cb],
            resume_training=True
        )

        # With nothing to resume from, last_epoch remains -1
        assert resume_cb.last_epoch == -1

def test_resume_with_corrupted_checkpoint(sample_model, sample_data):
    """Tests handling of corrupted checkpoints."""
    with tempfile.TemporaryDirectory() as tmpdir:
        # Write a checkpoint file that cannot be deserialized by torch.load
        bad_checkpoint = os.path.join(tmpdir, "checkpoint_epoch_0.pt")
        with open(bad_checkpoint, 'w') as f:
            f.write("corrupted data")

        resume_cb = ResumeTrainingCallback(tmpdir)

        sample_model.fit(
            sample_data,
            num_epoch=1,
            callbacks=[resume_cb],
            resume_training=True
        )
        # The corrupted checkpoint must be ignored rather than crash training
        assert resume_cb.last_epoch == -1

def test_optimizer_state_restoration(sample_model, sample_data):
    """Tests restoration of the optimizer state."""
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_cb = ModelCheckpointCallback(tmpdir)
        sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])

        original_optimizer_state = sample_model.optimizer.state_dict()

        new_model = GPT(
            vocab_size=100,
            max_seq_len=10,
            emb_size=32,
            num_heads=4,
            head_size=8,
            num_layers=2
        )

        resume_cb = ResumeTrainingCallback(tmpdir)
        new_model.fit(
            sample_data,
            num_epoch=2,
            callbacks=[resume_cb, checkpoint_cb],
            resume_training=True
        )

        assert 'state' in new_model.optimizer.state_dict()
        assert 'param_groups' in new_model.optimizer.state_dict()

        # Compare all parameter-group entries except 'params' and 'lr'
        # (lr is changed by the scheduler)
        for key in original_optimizer_state['param_groups'][0]:
            if key not in ['params', 'lr']:
                assert (
                    original_optimizer_state['param_groups'][0][key] ==
                    new_model.optimizer.state_dict()['param_groups'][0][key]
                )

def test_multiple_resumes(sample_model, sample_data):
    """Tests resuming training multiple times."""
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_cb = ModelCheckpointCallback(tmpdir)
        sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])

        # First resume
        resume_cb = ResumeTrainingCallback(tmpdir)
        sample_model.fit(
            sample_data,
            num_epoch=2,
            callbacks=[resume_cb, checkpoint_cb],
            resume_training=True
        )

        # Second resume
        resume_cb = ResumeTrainingCallback(tmpdir)
        sample_model.fit(
            sample_data,
            num_epoch=3,
            callbacks=[resume_cb, checkpoint_cb],
            resume_training=True
        )

        files = os.listdir(tmpdir)
        assert len(files) == 3

def test_scheduler_state_restoration(sample_model, sample_data):
    """Tests restoration of the learning-rate state."""
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_cb = ModelCheckpointCallback(tmpdir)
        lr_scheduler_cb = LRSchedulerCallback(lr=0.001)

        sample_model.fit(
            sample_data,
            num_epoch=1,
            callbacks=[checkpoint_cb, lr_scheduler_cb],
            learning_rate=0.001
        )

        # The lr expected after one epoch of decay
        expected_lr = 0.001 * (0.95 ** 1)  # decay^epoch

        new_model = GPT(
            vocab_size=100,
            max_seq_len=10,
            emb_size=32,
            num_heads=4,
            head_size=8,
            num_layers=2
        )

        resume_cb = ResumeTrainingCallback(tmpdir)
        new_model.fit(
            sample_data,
            num_epoch=2,
            callbacks=[resume_cb, checkpoint_cb, lr_scheduler_cb],
            resume_training=True,
            learning_rate=0.001
        )

        # Check that the LR was restored with decay applied
        assert new_model.optimizer.param_groups[0]['lr'] == pytest.approx(expected_lr)