# /Users/sergey/Projects/ML/simple-llm/simple_llm/transformer/callback/resume_training_callback.py import os import torch from typing import Optional from .callback import Callback class ResumeTrainingCallback(Callback): """Callback для восстановления обучения с последнего чекпоинта""" def __init__(self, checkpoint_dir: str, resume: bool = True): """ Args: checkpoint_dir: Путь к директории с чекпоинтами resume: Флаг восстановления обучения (default=True) """ self.checkpoint_dir = checkpoint_dir self.resume = resume self.last_epoch = -1 def on_train_begin(self, model): if not self.resume: return checkpoint_path = self._find_latest_checkpoint() if checkpoint_path: try: print(f"\n⚡ Восстанавливаем обучение из {checkpoint_path}") checkpoint = torch.load(checkpoint_path, map_location=model._device) # Убедимся, что загружаем на правильное устройство model.load_state_dict(checkpoint['model_state_dict']) if 'optimizer_state_dict' in checkpoint: model.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) if 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict'] is not None: if hasattr(model, 'scheduler'): model.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) self.last_epoch = checkpoint.get('epoch', -1) print(f"➔ Продолжаем с эпохи {self.last_epoch + 1}") print(f"➔ Последний loss: {checkpoint.get('train_loss', 'N/A'):.4f}\n") except Exception as e: print(f"⚠️ Чекпоинт поврежден или не читается: {checkpoint_path}\n{e}") # Найти максимальный существующий checkpoint по файловой системе import glob, os cp_files = glob.glob(os.path.join(self.checkpoint_dir, 'checkpoint_epoch_*.pt')) if cp_files: try: self.last_epoch = max([int(os.path.splitext(os.path.basename(f))[0].split('_')[-1]) for f in cp_files]) except Exception: self.last_epoch = -1 else: self.last_epoch = -1 else: # Если файлов совсем нет self.last_epoch = -1 def _find_latest_checkpoint(self) -> Optional[str]: if not os.path.exists(self.checkpoint_dir): return None checkpoints = [f for f in os.listdir(self.checkpoint_dir) if f.startswith('checkpoint_') and f.endswith('.pt')] if not checkpoints: return None # Сортируем по времени создания checkpoints.sort(key=lambda x: os.path.getmtime(os.path.join(self.checkpoint_dir, x))) return os.path.join(self.checkpoint_dir, checkpoints[-1])