Files
simple-llm/simple_llm/transformer/callback/resume_training_callback.py

66 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# /Users/sergey/Projects/ML/simple-llm/simple_llm/transformer/callback/resume_training_callback.py
import os
import torch
from typing import Optional
from .callback import Callback
class ResumeTrainingCallback(Callback):
    """Callback that resumes training from the most recent checkpoint.

    When training begins it looks for ``checkpoint_*.pt`` files in
    ``checkpoint_dir``, loads the most recently modified one into the model
    (and into its optimizer / scheduler when their states were saved) and
    stores the epoch recorded in the checkpoint in ``last_epoch``.
    """

    def __init__(self, checkpoint_dir: str, resume: bool = True):
        """
        Args:
            checkpoint_dir: Path to the directory that holds checkpoint files.
            resume: Whether to restore the training state (default=True).
        """
        self.checkpoint_dir = checkpoint_dir
        self.resume = resume
        # Stays -1 until a checkpoint is restored or an epoch number is
        # recovered from the checkpoint file names.
        self.last_epoch = -1

    def on_train_begin(self, model):
        """Restore model, optimizer and scheduler state from the latest checkpoint.

        Does nothing when ``resume`` is False. When no checkpoint file exists,
        ``last_epoch`` is set to -1. When the newest checkpoint cannot be
        loaded, the error is printed and ``last_epoch`` falls back to the
        highest epoch number found in ``checkpoint_epoch_*.pt`` file names.

        Args:
            model: Object exposing ``_device``, ``load_state_dict`` and
                ``optimizer`` (and optionally ``scheduler``), as used below.
        """
        if not self.resume:
            return

        checkpoint_path = self._find_latest_checkpoint()
        if not checkpoint_path:
            # No checkpoint files at all.
            self.last_epoch = -1
            return

        try:
            print(f"\n⚡ Восстанавливаем обучение из {checkpoint_path}")
            # map_location makes sure tensors land on the model's device.
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from trusted sources (consider weights_only=True
            # where the installed torch version supports it).
            checkpoint = torch.load(checkpoint_path, map_location=model._device)
            model.load_state_dict(checkpoint['model_state_dict'])

            if 'optimizer_state_dict' in checkpoint:
                model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

            scheduler_state = checkpoint.get('scheduler_state_dict')
            if scheduler_state is not None and hasattr(model, 'scheduler'):
                model.scheduler.load_state_dict(scheduler_state)

            self.last_epoch = checkpoint.get('epoch', -1)
            print(f"➔ Продолжаем с эпохи {self.last_epoch + 1}")
            # A missing loss must not raise here: the state is already loaded,
            # and an exception would wrongly report the checkpoint as corrupted.
            last_loss = self._format_loss(checkpoint.get('train_loss'))
            print(f"➔ Последний loss: {last_loss}\n")
        except Exception as e:
            # Deliberate best-effort: a broken checkpoint must not abort training.
            print(f"⚠️ Чекпоинт поврежден или не читается: {checkpoint_path}\n{e}")
            # Fall back to the highest epoch number present on disk.
            self.last_epoch = self._max_epoch_from_filenames()

    @staticmethod
    def _format_loss(value) -> str:
        """Return ``value`` with four decimals, or a readable fallback.

        ``None`` (loss not stored) yields ``'N/A'``; a value that does not
        support the ``.4f`` format is returned via ``str()``.
        """
        if value is None:
            return 'N/A'
        try:
            return f"{value:.4f}"
        except (TypeError, ValueError):
            return str(value)

    def _max_epoch_from_filenames(self) -> int:
        """Return the highest N among ``checkpoint_epoch_N.pt`` files, or -1.

        File names whose trailing component is not an integer are skipped
        instead of invalidating the whole scan.
        """
        import glob  # local import: not part of the module-level imports

        pattern = os.path.join(self.checkpoint_dir, 'checkpoint_epoch_*.pt')
        epochs = []
        for path in glob.glob(pattern):
            stem = os.path.splitext(os.path.basename(path))[0]
            try:
                epochs.append(int(stem.split('_')[-1]))
            except ValueError:
                continue
        return max(epochs, default=-1)

    def _find_latest_checkpoint(self) -> Optional[str]:
        """Return the path of the most recently modified checkpoint file.

        Only files named ``checkpoint_*.pt`` directly inside
        ``checkpoint_dir`` are considered. Returns None when the directory
        does not exist (or is not a directory) or contains no such file.
        """
        if not os.path.isdir(self.checkpoint_dir):
            return None

        candidates = [
            os.path.join(self.checkpoint_dir, name)
            for name in os.listdir(self.checkpoint_dir)
            if name.startswith('checkpoint_') and name.endswith('.pt')
        ]
        if not candidates:
            return None

        # Pick by modification time (getmtime), not creation time.
        return max(candidates, key=os.path.getmtime)