Fix checkpoint resume logic, global epoch numbering, and robust recovery; update tests for checkpointing; all tests passing

This commit is contained in:
Sergey Penkovsky
2025-08-01 10:40:08 +03:00
parent f7364070f0
commit 08f0356b4d
6 changed files with 113 additions and 53 deletions

View File

@@ -30,6 +30,7 @@ source venv/bin/activate # Linux/Mac
# 3. Установите зависимости
pip install torch
pip install dill tqdm # Основные зависимости для работы
pip install pytest
# Установка simple_llm пакета
pip install .
@@ -94,8 +95,10 @@ python bin/train_gpt_model.py \
--epochs 3 \
--emb-size 64 \
--num-heads 2 \
--num-layers 2
--num-layers 2 \
--keep-last-n 3
```
> Аргумент `--keep-last-n` определяет, сколько последних чекпоинтов (snapshot'ов обучения) хранить в папке модели. По умолчанию — 3. Старые файлы удаляются автоматически для экономии места.
### ✔️ Восстановление обучения с чекпоинта
@@ -130,7 +133,7 @@ python bin/generate_text.py \
# Последовательно выполните:
./bin/train_tokenizer.py --corpus data/corpus/sample --output data/tokenizer/bpe.json
./bin/tokenize_corpus.py --corpus data/corpus/sample --tokenizer data/tokenizer/bpe.json
./bin/train_gpt_model.py --tokens data/tokens/corpus_tokens.pkl --tokenizer data/tokenizer/bpe.json
./bin/train_gpt_model.py --tokens data/tokens/corpus_tokens.pkl --tokenizer data/tokenizer/bpe.json --keep-last-n 3
./bin/generate_text.py --model data/model/gpt_model.pth --tokenizer data/tokenizer/bpe.json --prompt "Привет"
```

View File

@@ -42,6 +42,7 @@ def main():
parser.add_argument('--lr', type=float, default=0.0001,
help='Learning rate')
parser.add_argument('--keep-last-n', type=int, default=3, help='Сколько последних чекпоинтов хранить')
args = parser.parse_args()
# Проверяем и создаем директорию для сохранения
@@ -88,7 +89,8 @@ def main():
learning_rate=args.lr,
checkpoint_dir=output_dir,
resume_training=True,
start_epoch=start_epoch
start_epoch=start_epoch,
keep_last_n=args.keep_last_n
)
torch.save(model.state_dict(), args.output)

View File

@@ -15,28 +15,31 @@ class ModelCheckpointCallback(Callback):
save_best_only (bool): Если True, сохраняет только при улучшении loss
save_freq (int): Сохранять каждые N эпох (default=1)
monitor (str): Какой loss мониторить ('val' или 'train')
keep_last_n (int): Сколько последних чекпоинтов хранить на диске (по умолчанию 3)
"""
def __init__(self,
save_dir: str,
save_best_only: bool = True,
save_freq: int = 1,
monitor: str = 'val'):
monitor: str = 'val',
keep_last_n: int = 3):
self.save_dir = save_dir
self.save_best_only = save_best_only
self.save_freq = save_freq
self.monitor = monitor
self.keep_last_n = keep_last_n
self.best_loss = float('inf')
# Создаем директорию если её нет
os.makedirs(save_dir, exist_ok=True)
def on_epoch_end(self, epoch, model, train_loss, val_loss):
def on_epoch_end(self, global_epoch, model, train_loss, val_loss):
# Решаем какой loss использовать для сравнения
current_loss = val_loss if (self.monitor == 'val' and val_loss is not None) else train_loss
# Сохраняем по расписанию или при улучшении
should_save = (
(epoch + 1) % self.save_freq == 0 or # по расписанию
(global_epoch + 1) % self.save_freq == 0 or # по расписанию
(self.save_best_only and current_loss < self.best_loss) # или если это лучшая модель
)
@@ -44,7 +47,7 @@ class ModelCheckpointCallback(Callback):
self.best_loss = current_loss
checkpoint_path = os.path.join(
self.save_dir,
f"checkpoint_epoch_{epoch}.pt"
f"checkpoint_epoch_{global_epoch}.pt"
)
# Собираем состояния всех callback'ов
@@ -55,7 +58,7 @@ class ModelCheckpointCallback(Callback):
callback_states[cb.__class__.__name__] = cb.get_state()
torch.save({
'epoch': epoch,
'epoch': global_epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': model.optimizer.state_dict(),
'train_loss': train_loss,
@@ -73,3 +76,18 @@ class ModelCheckpointCallback(Callback):
}, checkpoint_path)
print(f"Модель сохранена в {checkpoint_path} (loss: {current_loss:.4f})")
self._clean_old_checkpoints()
def _clean_old_checkpoints(self):
import glob
files = sorted(
glob.glob(os.path.join(self.save_dir, 'checkpoint_epoch_*.pt')),
key=os.path.getmtime
)
if len(files) > self.keep_last_n:
for file in files[:-self.keep_last_n]:
try:
os.remove(file)
print(f"Удалён старый чекпоинт: {file}")
except Exception as e:
print(f"Ошибка при удалении чекпоинта {file}: {e}")

View File

@@ -20,23 +20,36 @@ class ResumeTrainingCallback(Callback):
def on_train_begin(self, model):
if not self.resume:
return
checkpoint_path = self._find_latest_checkpoint()
if checkpoint_path:
print(f"\n⚡ Восстанавливаем обучение из {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=model._device)
# Убедимся, что загружаем на правильное устройство
model.load_state_dict(checkpoint['model_state_dict'])
if 'optimizer_state_dict' in checkpoint:
model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict'] is not None:
if hasattr(model, 'scheduler'):
model.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
self.last_epoch = checkpoint.get('epoch', -1)
print(f"➔ Продолжаем с эпохи {self.last_epoch + 1}")
print(f"➔ Последний loss: {checkpoint.get('train_loss', 'N/A'):.4f}\n")
try:
print(f"\n⚡ Восстанавливаем обучение из {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=model._device)
# Убедимся, что загружаем на правильное устройство
model.load_state_dict(checkpoint['model_state_dict'])
if 'optimizer_state_dict' in checkpoint:
model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict'] is not None:
if hasattr(model, 'scheduler'):
model.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
self.last_epoch = checkpoint.get('epoch', -1)
print(f"➔ Продолжаем с эпохи {self.last_epoch + 1}")
print(f"➔ Последний loss: {checkpoint.get('train_loss', 'N/A'):.4f}\n")
except Exception as e:
print(f"⚠️ Чекпоинт поврежден или не читается: {checkpoint_path}\n{e}")
# Найти максимальный существующий checkpoint по файловой системе
import glob, os
cp_files = glob.glob(os.path.join(self.checkpoint_dir, 'checkpoint_epoch_*.pt'))
if cp_files:
try:
self.last_epoch = max([int(os.path.splitext(os.path.basename(f))[0].split('_')[-1]) for f in cp_files])
except Exception:
self.last_epoch = -1
else:
self.last_epoch = -1
else:
# Если файлов совсем нет
self.last_epoch = -1
def _find_latest_checkpoint(self) -> Optional[str]:
if not os.path.exists(self.checkpoint_dir):

View File

@@ -297,7 +297,8 @@ class GPT(nn.Module):
callbacks=None,
checkpoint_dir=None,
resume_training=False,
start_epoch=0
start_epoch=0,
keep_last_n: int = 3
):
"""Обучает модель GPT с поддержкой callback-системы.
@@ -379,7 +380,7 @@ class GPT(nn.Module):
])
if checkpoint_dir:
callbacks.append(ModelCheckpointCallback(checkpoint_dir))
callbacks.append(ModelCheckpointCallback(checkpoint_dir, keep_last_n=keep_last_n))
# Инициализация оптимизатора ДО вызова callback'ов
self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
@@ -391,19 +392,29 @@ class GPT(nn.Module):
start_epoch = cb.last_epoch + 1
break
# Начало обучения (после инициализации оптимизатора)
for cb in callbacks:
cb.on_train_begin(self)
# Проверка через реальные файлы чекпоинтов (если они удалены или keep_last_n, брать max файл)
import glob, os
if checkpoint_dir is not None:
cp_files = glob.glob(os.path.join(checkpoint_dir, 'checkpoint_epoch_*.pt'))
if cp_files:
try:
max_cp = max([int(os.path.splitext(os.path.basename(f))[0].split('_')[-1]) for f in cp_files])
if max_cp + 1 > start_epoch:
start_epoch = max_cp + 1
except Exception:
pass
print(f"\nНачало обучения GPT на {num_epoch} эпох (с {start_epoch} по {start_epoch + num_epoch - 1})")
final_epoch_num = start_epoch + num_epoch
print(f"\nНачало обучения GPT на {num_epoch} эпох (с {start_epoch + 1} по {final_epoch_num})")
print(f"Размер батча: {train_loader.batch_size}")
print(f"Всего батчей: {len(train_loader)}")
print(f"Устройство: {device}\n")
for epoch in range(start_epoch, start_epoch + num_epoch):
for epoch_local in range(num_epoch):
global_epoch = start_epoch + epoch_local
# Вызов callback перед эпохой
for cb in callbacks:
cb.on_epoch_begin(epoch, self)
cb.on_epoch_begin(global_epoch, self)
self.train()
epoch_loss = 0.0
@@ -411,7 +422,7 @@ class GPT(nn.Module):
# Прогресс-бар для батчей
batch_pbar = tqdm(train_loader,
desc=f"Эпоха {epoch+1}/{num_epoch}",
desc=f"Эпоха {global_epoch+1}/{final_epoch_num}",
leave=False,
bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]')
@@ -447,7 +458,7 @@ class GPT(nn.Module):
self.train_loss = epoch_loss / len(train_loader)
epoch_time = time.time() - start_time
print(f"\nЭпоха {epoch+1}/{num_epoch} завершена за {epoch_time:.2f} сек")
print(f"\nЭпоха {global_epoch+1}/{final_epoch_num} завершена за {epoch_time:.2f} сек")
print(f"Средний Train Loss: {self.train_loss:.4f}")
if valid_loader is not None:
@@ -456,7 +467,7 @@ class GPT(nn.Module):
with torch.no_grad():
# Прогресс-бар для валидации
valid_pbar = tqdm(valid_loader,
desc=f"Валидация {epoch+1}/{num_epoch}",
desc=f"Валидация {global_epoch+1}/{final_epoch_num}",
leave=False)
for inputs, targets in valid_pbar:
@@ -473,11 +484,17 @@ class GPT(nn.Module):
self.validation_loss = valid_loss / len(valid_loader)
print(f"Средний Val Loss: {self.validation_loss:.4f}")
# Вызов callback после эпохи
# Вызов callback после эпохи — теперь передаём global_epoch (гарантируем уникальность чекпоинта)
stop_training = False
for cb in callbacks:
if cb.on_epoch_end(epoch, self, self.train_loss, self.validation_loss):
stop_training = True
# ModelCheckpointCallback — строго global_epoch
if isinstance(cb, ModelCheckpointCallback):
if cb.on_epoch_end(global_epoch, self, self.train_loss, self.validation_loss):
stop_training = True
else:
# Старые/другие коллбэки — epoch_local
if cb.on_epoch_end(epoch_local, self, self.train_loss, self.validation_loss):
stop_training = True
if stop_training:
print("Обучение остановлено callback-ом")

View File

@@ -50,6 +50,10 @@ def test_resume_training(sample_model, sample_data):
checkpoint_cb = ModelCheckpointCallback(tmpdir, save_best_only=False)
sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])
# Проверим, что чекпоинт создан
files = os.listdir(tmpdir)
assert any(f.startswith("checkpoint_epoch_") for f in files)
new_model = GPT(
vocab_size=100,
max_seq_len=10,
@@ -60,16 +64,19 @@ def test_resume_training(sample_model, sample_data):
)
resume_cb = ResumeTrainingCallback(tmpdir)
resume_cb.on_train_begin(new_model)
# После создания чекпоинта и on_train_begin, last_epoch должен быть 0
assert resume_cb.last_epoch == 0
# Дальнейшее обучение с fit/resume
new_model.fit(
sample_data,
num_epoch=2,
callbacks=[resume_cb, checkpoint_cb],
resume_training=True
)
assert resume_cb.last_epoch == 0
files = os.listdir(tmpdir)
assert len(files) == 2
files = [f for f in os.listdir(tmpdir) if f.startswith('checkpoint_epoch_')]
assert len(files) == 3
def test_resume_with_missing_checkpoint(sample_model, sample_data):
"""Тестирует поведение при отсутствии чекпоинтов"""
@@ -95,13 +102,13 @@ def test_resume_with_corrupted_checkpoint(sample_model, sample_data):
resume_cb = ResumeTrainingCallback(tmpdir)
with pytest.raises(Exception):
sample_model.fit(
sample_data,
num_epoch=1,
callbacks=[resume_cb],
resume_training=True
)
sample_model.fit(
sample_data,
num_epoch=1,
callbacks=[resume_cb],
resume_training=True
)
assert resume_cb.last_epoch == -1
def test_optimizer_state_restoration(sample_model, sample_data):
"""Тестирует восстановление состояния оптимизатора"""