Fix checkpoint resume logic, global epoch numbering, and robust recovery; update tests for checkpointing; all tests passing

This commit is contained in:
Sergey Penkovsky
2025-08-01 10:40:08 +03:00
parent f7364070f0
commit 08f0356b4d
6 changed files with 113 additions and 53 deletions

View File

@@ -20,23 +20,36 @@ class ResumeTrainingCallback(Callback):
def on_train_begin(self, model):
if not self.resume:
return
checkpoint_path = self._find_latest_checkpoint()
if checkpoint_path:
print(f"\n⚡ Восстанавливаем обучение из {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=model._device)
# Убедимся, что загружаем на правильное устройство
model.load_state_dict(checkpoint['model_state_dict'])
if 'optimizer_state_dict' in checkpoint:
model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict'] is not None:
if hasattr(model, 'scheduler'):
model.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
self.last_epoch = checkpoint.get('epoch', -1)
print(f"➔ Продолжаем с эпохи {self.last_epoch + 1}")
print(f"➔ Последний loss: {checkpoint.get('train_loss', 'N/A'):.4f}\n")
try:
print(f"\n⚡ Восстанавливаем обучение из {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=model._device)
# Убедимся, что загружаем на правильное устройство
model.load_state_dict(checkpoint['model_state_dict'])
if 'optimizer_state_dict' in checkpoint:
model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict'] is not None:
if hasattr(model, 'scheduler'):
model.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
self.last_epoch = checkpoint.get('epoch', -1)
print(f"➔ Продолжаем с эпохи {self.last_epoch + 1}")
print(f"➔ Последний loss: {checkpoint.get('train_loss', 'N/A'):.4f}\n")
except Exception as e:
print(f"⚠️ Чекпоинт поврежден или не читается: {checkpoint_path}\n{e}")
# Найти максимальный существующий checkpoint по файловой системе
import glob, os
cp_files = glob.glob(os.path.join(self.checkpoint_dir, 'checkpoint_epoch_*.pt'))
if cp_files:
try:
self.last_epoch = max([int(os.path.splitext(os.path.basename(f))[0].split('_')[-1]) for f in cp_files])
except Exception:
self.last_epoch = -1
else:
self.last_epoch = -1
else:
# Если файлов совсем нет
self.last_epoch = -1
def _find_latest_checkpoint(self) -> Optional[str]:
if not os.path.exists(self.checkpoint_dir):