simple-llm/tests/test_resume_training.py
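"""Tests for checkpoint saving and training resumption in simple_llm.

Covers checkpoint file contents, resuming from the latest checkpoint,
behaviour with missing and corrupted checkpoints, and restoration of
optimizer and learning-rate scheduler state.
"""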

import os
import tempfile

import pytest
import torch
from torch.utils.data import DataLoader, TensorDataset

from simple_llm.transformer.gpt import GPT
from simple_llm.transformer.callback import (
    LRSchedulerCallback,
    ModelCheckpointCallback,
    ResumeTrainingCallback,
)


@pytest.fixture
def sample_data():
    # Create synthetic test data
    inputs = torch.randint(0, 100, (100, 10))   # 100 samples, seq_len=10
    targets = torch.randint(0, 100, (100, 10))
    return DataLoader(TensorDataset(inputs, targets), batch_size=10)


@pytest.fixture
def sample_model():
    return GPT(
        vocab_size=100,
        max_seq_len=10,
        emb_size=32,
        num_heads=4,
        head_size=8,
        num_layers=2
    )


def test_model_checkpoint_saving(sample_model, sample_data):
    """Checkpoints are saved with the expected name and contents."""
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_cb = ModelCheckpointCallback(tmpdir, save_best_only=False)
        sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])

        files = os.listdir(tmpdir)
        assert len(files) == 1
        assert files[0].startswith('checkpoint_epoch_')

        checkpoint = torch.load(os.path.join(tmpdir, files[0]))
        assert 'model_state_dict' in checkpoint
        assert 'optimizer_state_dict' in checkpoint
        assert 'epoch' in checkpoint
        assert 'train_loss' in checkpoint


def test_resume_training(sample_model, sample_data):
    """Training resumes from the latest checkpoint."""
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_cb = ModelCheckpointCallback(tmpdir, save_best_only=False)
        sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])

        new_model = GPT(
            vocab_size=100,
            max_seq_len=10,
            emb_size=32,
            num_heads=4,
            head_size=8,
            num_layers=2
        )
        resume_cb = ResumeTrainingCallback(tmpdir)
        new_model.fit(
            sample_data,
            num_epoch=2,
            callbacks=[resume_cb, checkpoint_cb],
            resume_training=True
        )

        # The epoch-0 checkpoint was picked up, and a second one was written.
        assert resume_cb.last_epoch == 0
        files = os.listdir(tmpdir)
        assert len(files) == 2


def test_resume_with_missing_checkpoint(sample_model, sample_data):
    """Training proceeds normally when no checkpoint exists."""
    with tempfile.TemporaryDirectory() as tmpdir:
        assert len(os.listdir(tmpdir)) == 0

        resume_cb = ResumeTrainingCallback(tmpdir)
        sample_model.fit(
            sample_data,
            num_epoch=1,
            callbacks=[resume_cb],
            resume_training=True
        )
        # Nothing to resume from: the callback reports no restored epoch.
        assert resume_cb.last_epoch == -1


def test_resume_with_corrupted_checkpoint(sample_model, sample_data):
    """A corrupted checkpoint file raises an error instead of being ignored."""
    with tempfile.TemporaryDirectory() as tmpdir:
        bad_checkpoint = os.path.join(tmpdir, "checkpoint_epoch_0.pt")
        with open(bad_checkpoint, 'w') as f:
            f.write("corrupted data")

        resume_cb = ResumeTrainingCallback(tmpdir)
        with pytest.raises(Exception):
            sample_model.fit(
                sample_data,
                num_epoch=1,
                callbacks=[resume_cb],
                resume_training=True
            )


def test_optimizer_state_restoration(sample_model, sample_data):
    """The optimizer state is restored from the checkpoint."""
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_cb = ModelCheckpointCallback(tmpdir)
        sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])
        original_optimizer_state = sample_model.optimizer.state_dict()

        new_model = GPT(
            vocab_size=100,
            max_seq_len=10,
            emb_size=32,
            num_heads=4,
            head_size=8,
            num_layers=2
        )
        resume_cb = ResumeTrainingCallback(tmpdir)
        new_model.fit(
            sample_data,
            num_epoch=2,
            callbacks=[resume_cb, checkpoint_cb],
            resume_training=True
        )

        assert 'state' in new_model.optimizer.state_dict()
        assert 'param_groups' in new_model.optimizer.state_dict()
        # Compare all hyperparameters except 'params' and 'lr'
        # (lr is modified by the scheduler during training).
        for key in original_optimizer_state['param_groups'][0]:
            if key not in ['params', 'lr']:
                assert (
                    original_optimizer_state['param_groups'][0][key] ==
                    new_model.optimizer.state_dict()['param_groups'][0][key]
                )


def test_multiple_resumes(sample_model, sample_data):
    """Training can be resumed several times in a row."""
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_cb = ModelCheckpointCallback(tmpdir)
        sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])

        resume_cb = ResumeTrainingCallback(tmpdir)
        sample_model.fit(
            sample_data,
            num_epoch=2,
            callbacks=[resume_cb, checkpoint_cb],
            resume_training=True
        )

        resume_cb = ResumeTrainingCallback(tmpdir)
        sample_model.fit(
            sample_data,
            num_epoch=3,
            callbacks=[resume_cb, checkpoint_cb],
            resume_training=True
        )

        # One checkpoint per completed epoch: epochs 0, 1, and 2.
        files = os.listdir(tmpdir)
        assert len(files) == 3


def test_scheduler_state_restoration(sample_model, sample_data):
    """The learning-rate schedule is restored after a resume."""
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_cb = ModelCheckpointCallback(tmpdir)
        lr_scheduler_cb = LRSchedulerCallback(lr=0.001)
        sample_model.fit(
            sample_data,
            num_epoch=1,
            callbacks=[checkpoint_cb, lr_scheduler_cb],
            learning_rate=0.001
        )
        # Expected lr after one epoch of exponential decay: lr * decay**epoch
        expected_lr = 0.001 * (0.95 ** 1)

        new_model = GPT(
            vocab_size=100,
            max_seq_len=10,
            emb_size=32,
            num_heads=4,
            head_size=8,
            num_layers=2
        )
        resume_cb = ResumeTrainingCallback(tmpdir)
        new_model.fit(
            sample_data,
            num_epoch=2,
            callbacks=[resume_cb, checkpoint_cb, lr_scheduler_cb],
            resume_training=True,
            learning_rate=0.001
        )

        # Check that the restored LR reflects the decay already applied.
        assert new_model.optimizer.param_groups[0]['lr'] == pytest.approx(expected_lr)