Files
simple-llm/simple_llm/transformer/callback/resume_training_callback.py

66 lines
3.3 KiB
Python
Raw Normal View History

# /Users/sergey/Projects/ML/simple-llm/simple_llm/transformer/callback/resume_training_callback.py
import os
import torch
from typing import Optional
from .callback import Callback
class ResumeTrainingCallback(Callback):
"""Callback для восстановления обучения с последнего чекпоинта"""
def __init__(self, checkpoint_dir: str, resume: bool = True):
"""
Args:
checkpoint_dir: Путь к директории с чекпоинтами
resume: Флаг восстановления обучения (default=True)
"""
self.checkpoint_dir = checkpoint_dir
self.resume = resume
self.last_epoch = -1
def on_train_begin(self, model):
if not self.resume:
return
checkpoint_path = self._find_latest_checkpoint()
if checkpoint_path:
try:
print(f"\n⚡ Восстанавливаем обучение из {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=model._device)
# Убедимся, что загружаем на правильное устройство
model.load_state_dict(checkpoint['model_state_dict'])
if 'optimizer_state_dict' in checkpoint:
model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict'] is not None:
if hasattr(model, 'scheduler'):
model.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
self.last_epoch = checkpoint.get('epoch', -1)
print(f"➔ Продолжаем с эпохи {self.last_epoch + 1}")
print(f"➔ Последний loss: {checkpoint.get('train_loss', 'N/A'):.4f}\n")
except Exception as e:
print(f"⚠️ Чекпоинт поврежден или не читается: {checkpoint_path}\n{e}")
# Найти максимальный существующий checkpoint по файловой системе
import glob, os
cp_files = glob.glob(os.path.join(self.checkpoint_dir, 'checkpoint_epoch_*.pt'))
if cp_files:
try:
self.last_epoch = max([int(os.path.splitext(os.path.basename(f))[0].split('_')[-1]) for f in cp_files])
except Exception:
self.last_epoch = -1
else:
self.last_epoch = -1
else:
# Если файлов совсем нет
self.last_epoch = -1
def _find_latest_checkpoint(self) -> Optional[str]:
if not os.path.exists(self.checkpoint_dir):
return None
checkpoints = [f for f in os.listdir(self.checkpoint_dir)
if f.startswith('checkpoint_') and f.endswith('.pt')]
if not checkpoints:
return None
# Сортируем по времени создания
checkpoints.sort(key=lambda x: os.path.getmtime(os.path.join(self.checkpoint_dir, x)))
return os.path.join(self.checkpoint_dir, checkpoints[-1])