Fix checkpoint resume logic, global epoch numbering, and robust recovery; update tests for checkpointing; all tests passing

This commit is contained in:
Sergey Penkovsky
2025-08-01 10:40:08 +03:00
parent f7364070f0
commit 08f0356b4d
6 changed files with 113 additions and 53 deletions

View File

@@ -49,7 +49,11 @@ def test_resume_training(sample_model, sample_data):
with tempfile.TemporaryDirectory() as tmpdir:
checkpoint_cb = ModelCheckpointCallback(tmpdir, save_best_only=False)
sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])
# Проверим, что чекпоинт создан
files = os.listdir(tmpdir)
assert any(f.startswith("checkpoint_epoch_") for f in files)
new_model = GPT(
vocab_size=100,
max_seq_len=10,
@@ -58,18 +62,21 @@ def test_resume_training(sample_model, sample_data):
head_size=8,
num_layers=2
)
resume_cb = ResumeTrainingCallback(tmpdir)
resume_cb.on_train_begin(new_model)
# После создания чекпоинта и on_train_begin, last_epoch должен быть 0
assert resume_cb.last_epoch == 0
# Дальнейшее обучение с fit/resume
new_model.fit(
sample_data,
num_epoch=2,
sample_data,
num_epoch=2,
callbacks=[resume_cb, checkpoint_cb],
resume_training=True
)
assert resume_cb.last_epoch == 0
files = os.listdir(tmpdir)
assert len(files) == 2
files = [f for f in os.listdir(tmpdir) if f.startswith('checkpoint_epoch_')]
assert len(files) == 3
def test_resume_with_missing_checkpoint(sample_model, sample_data):
"""Тестирует поведение при отсутствии чекпоинтов"""
@@ -95,13 +102,13 @@ def test_resume_with_corrupted_checkpoint(sample_model, sample_data):
resume_cb = ResumeTrainingCallback(tmpdir)
with pytest.raises(Exception):
sample_model.fit(
sample_data,
num_epoch=1,
callbacks=[resume_cb],
resume_training=True
)
sample_model.fit(
sample_data,
num_epoch=1,
callbacks=[resume_cb],
resume_training=True
)
assert resume_cb.last_epoch == -1
def test_optimizer_state_restoration(sample_model, sample_data):
"""Тестирует восстановление состояния оптимизатора"""