mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-24 05:26:02 +00:00
Fix checkpoint resume logic, global epoch numbering, and robust recovery; update tests for checkpointing; all tests passing
This commit is contained in:
@@ -49,7 +49,11 @@ def test_resume_training(sample_model, sample_data):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
checkpoint_cb = ModelCheckpointCallback(tmpdir, save_best_only=False)
|
||||
sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])
|
||||
|
||||
|
||||
# Проверим, что чекпоинт создан
|
||||
files = os.listdir(tmpdir)
|
||||
assert any(f.startswith("checkpoint_epoch_") for f in files)
|
||||
|
||||
new_model = GPT(
|
||||
vocab_size=100,
|
||||
max_seq_len=10,
|
||||
@@ -58,18 +62,21 @@ def test_resume_training(sample_model, sample_data):
|
||||
head_size=8,
|
||||
num_layers=2
|
||||
)
|
||||
|
||||
|
||||
resume_cb = ResumeTrainingCallback(tmpdir)
|
||||
resume_cb.on_train_begin(new_model)
|
||||
# После создания чекпоинта и on_train_begin, last_epoch должен быть 0
|
||||
assert resume_cb.last_epoch == 0
|
||||
|
||||
# Дальнейшее обучение с fit/resume
|
||||
new_model.fit(
|
||||
sample_data,
|
||||
num_epoch=2,
|
||||
sample_data,
|
||||
num_epoch=2,
|
||||
callbacks=[resume_cb, checkpoint_cb],
|
||||
resume_training=True
|
||||
)
|
||||
|
||||
assert resume_cb.last_epoch == 0
|
||||
files = os.listdir(tmpdir)
|
||||
assert len(files) == 2
|
||||
files = [f for f in os.listdir(tmpdir) if f.startswith('checkpoint_epoch_')]
|
||||
assert len(files) == 3
|
||||
|
||||
def test_resume_with_missing_checkpoint(sample_model, sample_data):
|
||||
"""Тестирует поведение при отсутствии чекпоинтов"""
|
||||
@@ -95,13 +102,13 @@ def test_resume_with_corrupted_checkpoint(sample_model, sample_data):
|
||||
|
||||
resume_cb = ResumeTrainingCallback(tmpdir)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
sample_model.fit(
|
||||
sample_data,
|
||||
num_epoch=1,
|
||||
callbacks=[resume_cb],
|
||||
resume_training=True
|
||||
)
|
||||
sample_model.fit(
|
||||
sample_data,
|
||||
num_epoch=1,
|
||||
callbacks=[resume_cb],
|
||||
resume_training=True
|
||||
)
|
||||
assert resume_cb.last_epoch == -1
|
||||
|
||||
def test_optimizer_state_restoration(sample_model, sample_data):
|
||||
"""Тестирует восстановление состояния оптимизатора"""
|
||||
|
||||
Reference in New Issue
Block a user