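"""Tests for checkpoint saving and training resumption in simple-llm.

Covers ModelCheckpointCallback, ResumeTrainingCallback, LRSchedulerCallback,
and restoration of optimizer and learning-rate state across runs
(simple-llm/tests/test_resume_training.py).
"""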
import os
import tempfile

import pytest
import torch
from torch.utils.data import DataLoader, TensorDataset

from simple_llm.transformer.gpt import GPT
from simple_llm.transformer.callback import (
    LRSchedulerCallback,
    ModelCheckpointCallback,
    ResumeTrainingCallback,
)


@pytest.fixture
def sample_data():
    # Create test data: 100 samples of random token ids, seq_len=10
    inputs = torch.randint(0, 100, (100, 10))
    targets = torch.randint(0, 100, (100, 10))
    return DataLoader(TensorDataset(inputs, targets), batch_size=10)


@pytest.fixture
def sample_model():
    return GPT(
        vocab_size=100,
        max_seq_len=10,
        emb_size=32,
        num_heads=4,
        head_size=8,
        num_layers=2
    )


def test_model_checkpoint_saving(sample_model, sample_data):
    """Tests that checkpoints are saved correctly."""
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_cb = ModelCheckpointCallback(tmpdir, save_best_only=False)
        sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])

        files = os.listdir(tmpdir)
        assert len(files) == 1
        assert files[0].startswith('checkpoint_epoch_')

        checkpoint = torch.load(os.path.join(tmpdir, files[0]))
        assert 'model_state_dict' in checkpoint
        assert 'optimizer_state_dict' in checkpoint
        assert 'epoch' in checkpoint
        assert 'train_loss' in checkpoint


def test_resume_training(sample_model, sample_data):
    """Tests resuming training from a checkpoint."""
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_cb = ModelCheckpointCallback(tmpdir, save_best_only=False)
        sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])

        # Check that a checkpoint was created
        files = os.listdir(tmpdir)
        assert any(f.startswith("checkpoint_epoch_") for f in files)

        new_model = GPT(
            vocab_size=100,
            max_seq_len=10,
            emb_size=32,
            num_heads=4,
            head_size=8,
            num_layers=2
        )
        resume_cb = ResumeTrainingCallback(tmpdir)
        resume_cb.on_train_begin(new_model)
        # After the checkpoint is created and on_train_begin runs, last_epoch must be 0
        assert resume_cb.last_epoch == 0

        # Continue training via fit with resume_training enabled
        new_model.fit(
            sample_data,
            num_epoch=2,
            callbacks=[resume_cb, checkpoint_cb],
            resume_training=True
        )
        files = [f for f in os.listdir(tmpdir) if f.startswith('checkpoint_epoch_')]
        assert len(files) == 3


def test_resume_with_missing_checkpoint(sample_model, sample_data):
    """Tests behaviour when no checkpoints are present."""
    with tempfile.TemporaryDirectory() as tmpdir:
        assert len(os.listdir(tmpdir)) == 0

        resume_cb = ResumeTrainingCallback(tmpdir)
        sample_model.fit(
            sample_data,
            num_epoch=1,
            callbacks=[resume_cb],
            resume_training=True
        )
        assert resume_cb.last_epoch == -1


def test_resume_with_corrupted_checkpoint(sample_model, sample_data):
    """Tests handling of corrupted checkpoints."""
    with tempfile.TemporaryDirectory() as tmpdir:
        bad_checkpoint = os.path.join(tmpdir, "checkpoint_epoch_0.pt")
        with open(bad_checkpoint, 'w') as f:
            f.write("corrupted data")

        resume_cb = ResumeTrainingCallback(tmpdir)
        sample_model.fit(
            sample_data,
            num_epoch=1,
            callbacks=[resume_cb],
            resume_training=True
        )
        assert resume_cb.last_epoch == -1


def test_optimizer_state_restoration(sample_model, sample_data):
    """Tests restoring the optimizer state."""
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_cb = ModelCheckpointCallback(tmpdir)
        sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])
        original_optimizer_state = sample_model.optimizer.state_dict()

        new_model = GPT(
            vocab_size=100,
            max_seq_len=10,
            emb_size=32,
            num_heads=4,
            head_size=8,
            num_layers=2
        )
        resume_cb = ResumeTrainingCallback(tmpdir)
        new_model.fit(
            sample_data,
            num_epoch=2,
            callbacks=[resume_cb, checkpoint_cb],
            resume_training=True
        )

        assert 'state' in new_model.optimizer.state_dict()
        assert 'param_groups' in new_model.optimizer.state_dict()
        # Compare all hyperparameters except 'params' and 'lr'
        # (lr is changed by the scheduler)
        for key in original_optimizer_state['param_groups'][0]:
            if key not in ['params', 'lr']:
                assert (
                    original_optimizer_state['param_groups'][0][key] ==
                    new_model.optimizer.state_dict()['param_groups'][0][key]
                )


def test_multiple_resumes(sample_model, sample_data):
    """Tests resuming training multiple times in a row."""
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_cb = ModelCheckpointCallback(tmpdir)
        sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])

        resume_cb = ResumeTrainingCallback(tmpdir)
        sample_model.fit(
            sample_data,
            num_epoch=2,
            callbacks=[resume_cb, checkpoint_cb],
            resume_training=True
        )

        resume_cb = ResumeTrainingCallback(tmpdir)
        sample_model.fit(
            sample_data,
            num_epoch=3,
            callbacks=[resume_cb, checkpoint_cb],
            resume_training=True
        )

        files = os.listdir(tmpdir)
        assert len(files) == 3


def test_scheduler_state_restoration(sample_model, sample_data):
    """Tests restoring the learning-rate state."""
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_cb = ModelCheckpointCallback(tmpdir)
        lr_scheduler_cb = LRSchedulerCallback(lr=0.001)
        sample_model.fit(
            sample_data,
            num_epoch=1,
            callbacks=[checkpoint_cb, lr_scheduler_cb],
            learning_rate=0.001
        )
        # Expected lr after one epoch of decay: lr * decay**epoch
        expected_lr = 0.001 * (0.95 ** 1)

        new_model = GPT(
            vocab_size=100,
            max_seq_len=10,
            emb_size=32,
            num_heads=4,
            head_size=8,
            num_layers=2
        )
        resume_cb = ResumeTrainingCallback(tmpdir)
        new_model.fit(
            sample_data,
            num_epoch=2,
            callbacks=[resume_cb, checkpoint_cb, lr_scheduler_cb],
            resume_training=True,
            learning_rate=0.001
        )
        # Check that the LR was restored with decay applied
        assert new_model.optimizer.param_groups[0]['lr'] == pytest.approx(expected_lr)
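

# Usage note (not part of the original file): with pytest installed, this module
# can be run on its own, e.g.
#   pytest simple-llm/tests/test_resume_training.py -q
# The -q flag only shortens pytest's output and is optional.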