import os
import tempfile

import pytest
import torch
from torch.utils.data import DataLoader, TensorDataset

from simple_llm.transformer.callback import (
    LRSchedulerCallback,
    ModelCheckpointCallback,
    ResumeTrainingCallback,
)
from simple_llm.transformer.gpt import GPT


@pytest.fixture
def sample_data():
    # Create test data
    inputs = torch.randint(0, 100, (100, 10))  # 100 samples, seq_len=10
    targets = torch.randint(0, 100, (100, 10))
    return DataLoader(TensorDataset(inputs, targets), batch_size=10)


@pytest.fixture
def sample_model():
    return GPT(
        vocab_size=100,
        max_seq_len=10,
        emb_size=32,
        num_heads=4,
        head_size=8,
        num_layers=2
    )


def test_model_checkpoint_saving(sample_model, sample_data):
    """Tests that checkpoints are saved correctly."""
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_cb = ModelCheckpointCallback(tmpdir, save_best_only=False)
        sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])

        # One epoch with save_best_only=False should produce exactly one file
        files = os.listdir(tmpdir)
        assert len(files) == 1
        assert files[0].startswith('checkpoint_epoch_')

        checkpoint = torch.load(os.path.join(tmpdir, files[0]))
        assert 'model_state_dict' in checkpoint
        assert 'optimizer_state_dict' in checkpoint
        assert 'epoch' in checkpoint
        assert 'train_loss' in checkpoint
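

# A minimal sketch of the payload the assertions above expect
# ModelCheckpointCallback to write. This is a hypothetical helper for
# illustration only; the real callback lives in simple_llm.transformer.callback.
def _sketch_save_checkpoint(model, optimizer, epoch, train_loss, save_dir):
    torch.save(
        {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
            'train_loss': train_loss,
        },
        os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pt'),
    )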


def test_resume_training(sample_model, sample_data):
    """Tests resuming training from a checkpoint."""
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_cb = ModelCheckpointCallback(tmpdir, save_best_only=False)
        sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])

        # A fresh model with the same architecture picks up where the first left off
        new_model = GPT(
            vocab_size=100,
            max_seq_len=10,
            emb_size=32,
            num_heads=4,
            head_size=8,
            num_layers=2
        )

        resume_cb = ResumeTrainingCallback(tmpdir)
        new_model.fit(
            sample_data,
            num_epoch=2,
            callbacks=[resume_cb, checkpoint_cb],
            resume_training=True
        )

        # The callback found the epoch-0 checkpoint, and the second run
        # added one more checkpoint on top of the original
        assert resume_cb.last_epoch == 0
        files = os.listdir(tmpdir)
        assert len(files) == 2
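

# Hypothetical sketch of how a resume callback might locate the newest
# checkpoint in a directory; ResumeTrainingCallback's actual logic may differ.
def _sketch_find_latest_checkpoint(save_dir):
    files = [f for f in os.listdir(save_dir) if f.startswith('checkpoint_epoch_')]
    if not files:
        return None  # nothing to resume from
    files.sort(key=lambda name: int(name.split('_')[-1].split('.')[0]))
    return os.path.join(save_dir, files[-1])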


def test_resume_with_missing_checkpoint(sample_model, sample_data):
    """Tests behaviour when no checkpoints exist."""
    with tempfile.TemporaryDirectory() as tmpdir:
        assert len(os.listdir(tmpdir)) == 0

        resume_cb = ResumeTrainingCallback(tmpdir)
        sample_model.fit(
            sample_data,
            num_epoch=1,
            callbacks=[resume_cb],
            resume_training=True
        )

        # With an empty directory there is nothing to resume from
        assert resume_cb.last_epoch == -1


def test_resume_with_corrupted_checkpoint(sample_model, sample_data):
    """Tests handling of corrupted checkpoints."""
    with tempfile.TemporaryDirectory() as tmpdir:
        # Write garbage where a checkpoint is expected; torch.load should
        # fail while deserializing it
        bad_checkpoint = os.path.join(tmpdir, "checkpoint_epoch_0.pt")
        with open(bad_checkpoint, 'w') as f:
            f.write("corrupted data")

        resume_cb = ResumeTrainingCallback(tmpdir)

        with pytest.raises(Exception):
            sample_model.fit(
                sample_data,
                num_epoch=1,
                callbacks=[resume_cb],
                resume_training=True
            )


def test_optimizer_state_restoration(sample_model, sample_data):
    """Tests restoration of the optimizer state."""
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_cb = ModelCheckpointCallback(tmpdir)
        sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])

        original_optimizer_state = sample_model.optimizer.state_dict()

        new_model = GPT(
            vocab_size=100,
            max_seq_len=10,
            emb_size=32,
            num_heads=4,
            head_size=8,
            num_layers=2
        )

        resume_cb = ResumeTrainingCallback(tmpdir)
        new_model.fit(
            sample_data,
            num_epoch=2,
            callbacks=[resume_cb, checkpoint_cb],
            resume_training=True
        )

        assert 'state' in new_model.optimizer.state_dict()
        assert 'param_groups' in new_model.optimizer.state_dict()

        # Compare param-group settings, skipping 'params' (parameter ids
        # differ between models) and 'lr' (it is changed by the scheduler)
        for key in original_optimizer_state['param_groups'][0]:
            if key not in ['params', 'lr']:
                assert (
                    original_optimizer_state['param_groups'][0][key] ==
                    new_model.optimizer.state_dict()['param_groups'][0][key]
                )
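

# For reference, restoring model and optimizer state from such a checkpoint
# follows the standard PyTorch pattern; a minimal sketch assuming the
# checkpoint layout asserted in test_model_checkpoint_saving:
def _sketch_restore(model, optimizer, checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']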


def test_multiple_resumes(sample_model, sample_data):
    """Tests resuming training multiple times."""
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_cb = ModelCheckpointCallback(tmpdir)
        sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])

        resume_cb = ResumeTrainingCallback(tmpdir)
        sample_model.fit(
            sample_data,
            num_epoch=2,
            callbacks=[resume_cb, checkpoint_cb],
            resume_training=True
        )

        resume_cb = ResumeTrainingCallback(tmpdir)
        sample_model.fit(
            sample_data,
            num_epoch=3,
            callbacks=[resume_cb, checkpoint_cb],
            resume_training=True
        )

        # Three runs, one new checkpoint each
        files = os.listdir(tmpdir)
        assert len(files) == 3


def test_scheduler_state_restoration(sample_model, sample_data):
    """Tests restoration of the learning-rate state."""
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_cb = ModelCheckpointCallback(tmpdir)
        lr_scheduler_cb = LRSchedulerCallback(lr=0.001)

        sample_model.fit(
            sample_data,
            num_epoch=1,
            callbacks=[checkpoint_cb, lr_scheduler_cb],
            learning_rate=0.001
        )

        # The lr expected after one epoch of decay: base_lr * decay**epoch
        expected_lr = 0.001 * (0.95 ** 1)

        new_model = GPT(
            vocab_size=100,
            max_seq_len=10,
            emb_size=32,
            num_heads=4,
            head_size=8,
            num_layers=2
        )

        resume_cb = ResumeTrainingCallback(tmpdir)
        new_model.fit(
            sample_data,
            num_epoch=2,
            callbacks=[resume_cb, checkpoint_cb, lr_scheduler_cb],
            resume_training=True,
            learning_rate=0.001
        )

        # Check that the restored LR reflects the decay already applied
        assert new_model.optimizer.param_groups[0]['lr'] == pytest.approx(expected_lr)
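

# The expected_lr arithmetic above assumes LRSchedulerCallback applies an
# exponential decay of 0.95 per epoch; a minimal sketch of that update rule:
def _sketch_decayed_lr(base_lr, epoch, decay=0.95):
    return base_lr * (decay ** epoch)  # e.g. 0.001 * 0.95**1 == 0.00095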