Mirror of https://github.com/pese-git/simple-llm.git
Fix checkpoint resume logic and global epoch numbering, add robust recovery; update checkpointing tests; all tests passing
@@ -20,23 +20,36 @@ class ResumeTrainingCallback(Callback):
     def on_train_begin(self, model):
         if not self.resume:
             return
 
         checkpoint_path = self._find_latest_checkpoint()
         if checkpoint_path:
-            print(f"\n⚡ Resuming training from {checkpoint_path}")
-            checkpoint = torch.load(checkpoint_path, map_location=model._device)
-
-            # Make sure we load onto the correct device
-            model.load_state_dict(checkpoint['model_state_dict'])
-            if 'optimizer_state_dict' in checkpoint:
-                model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-            if 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict'] is not None:
-                if hasattr(model, 'scheduler'):
-                    model.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
-            self.last_epoch = checkpoint.get('epoch', -1)
-
-            print(f"➔ Continuing from epoch {self.last_epoch + 1}")
-            print(f"➔ Last loss: {checkpoint.get('train_loss', 'N/A'):.4f}\n")
+            try:
+                print(f"\n⚡ Resuming training from {checkpoint_path}")
+                checkpoint = torch.load(checkpoint_path, map_location=model._device)
+                # Make sure we load onto the correct device
+                model.load_state_dict(checkpoint['model_state_dict'])
+                if 'optimizer_state_dict' in checkpoint:
+                    model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+                if 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict'] is not None:
+                    if hasattr(model, 'scheduler'):
+                        model.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+                self.last_epoch = checkpoint.get('epoch', -1)
+                print(f"➔ Continuing from epoch {self.last_epoch + 1}")
+                print(f"➔ Last loss: {checkpoint.get('train_loss', 'N/A'):.4f}\n")
+            except Exception as e:
+                print(f"⚠️ Checkpoint is corrupted or unreadable: {checkpoint_path}\n{e}")
+                # Find the highest existing checkpoint epoch on the filesystem
+                import glob, os
+                cp_files = glob.glob(os.path.join(self.checkpoint_dir, 'checkpoint_epoch_*.pt'))
+                if cp_files:
+                    try:
+                        self.last_epoch = max([int(os.path.splitext(os.path.basename(f))[0].split('_')[-1]) for f in cp_files])
+                    except Exception:
+                        self.last_epoch = -1
+                else:
+                    self.last_epoch = -1
+        else:
+            # No checkpoint files at all
+            self.last_epoch = -1
 
     def _find_latest_checkpoint(self) -> Optional[str]:
         if not os.path.exists(self.checkpoint_dir):
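
The hunk cuts off inside _find_latest_checkpoint. From the 'checkpoint_epoch_*.pt' glob and the epoch parsing in the except branch, a plausible completion is sketched below; the real body is not shown in this diff, so take it as an inference rather than the commit's code.

import glob
import os
from typing import Optional

# Plausible completion (method of ResumeTrainingCallback), inferred from the
# filename convention used elsewhere in the hunk; not taken from the commit.
def _find_latest_checkpoint(self) -> Optional[str]:
    if not os.path.exists(self.checkpoint_dir):
        return None
    cp_files = glob.glob(os.path.join(self.checkpoint_dir, 'checkpoint_epoch_*.pt'))
    if not cp_files:
        return None
    # Highest epoch number in the filename wins: 'checkpoint_epoch_12.pt' -> 12
    return max(cp_files,
               key=lambda f: int(os.path.splitext(os.path.basename(f))[0].split('_')[-1]))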
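
One caveat survives the rewrite: formatting the 'N/A' fallback with :.4f raises ValueError, so a checkpoint that loads correctly but merely lacks 'train_loss' is now reported as corrupted and routed into the filesystem fallback. A safer formatting, sketched here rather than taken from the commit:

# Format the loss only when it is actually numeric; f"{'N/A':.4f}" raises ValueError.
train_loss = checkpoint.get('train_loss')
loss_str = f"{train_loss:.4f}" if isinstance(train_loss, (int, float)) else "N/A"
print(f"➔ Last loss: {loss_str}\n")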
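
For context, the resume logic assumes checkpoints that carry the keys it reads back ('epoch', 'train_loss', 'model_state_dict', 'optimizer_state_dict', 'scheduler_state_dict') under filenames matching checkpoint_epoch_*.pt. The saving side is not part of this diff; the sketch below is a hypothetical helper consistent with those assumptions, including the model.optimizer attribute the resume code relies on.

import os
import torch

# Hypothetical save-side helper, not from the commit: writes a checkpoint in
# the exact shape ResumeTrainingCallback.on_train_begin expects to read back.
def save_checkpoint(model, epoch, train_loss, checkpoint_dir):
    os.makedirs(checkpoint_dir, exist_ok=True)
    payload = {
        'epoch': epoch,                      # restored via checkpoint.get('epoch', -1)
        'train_loss': train_loss,            # printed on resume
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': model.optimizer.state_dict(),
        'scheduler_state_dict': (model.scheduler.state_dict()
                                 if hasattr(model, 'scheduler') else None),
    }
    # The filename must match the 'checkpoint_epoch_*.pt' glob used on resume,
    # and the epoch suffix must stay parseable as an int for the fallback.
    torch.save(payload, os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt'))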