mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
66 lines
3.3 KiB
Python
66 lines
3.3 KiB
Python
# /Users/sergey/Projects/ML/simple-llm/simple_llm/transformer/callback/resume_training_callback.py
|
||
import os
|
||
import torch
|
||
from typing import Optional
|
||
from .callback import Callback
|
||
|
||
class ResumeTrainingCallback(Callback):
|
||
"""Callback для восстановления обучения с последнего чекпоинта"""
|
||
|
||
def __init__(self, checkpoint_dir: str, resume: bool = True):
|
||
"""
|
||
Args:
|
||
checkpoint_dir: Путь к директории с чекпоинтами
|
||
resume: Флаг восстановления обучения (default=True)
|
||
"""
|
||
self.checkpoint_dir = checkpoint_dir
|
||
self.resume = resume
|
||
self.last_epoch = -1
|
||
|
||
def on_train_begin(self, model):
|
||
if not self.resume:
|
||
return
|
||
checkpoint_path = self._find_latest_checkpoint()
|
||
if checkpoint_path:
|
||
try:
|
||
print(f"\n⚡ Восстанавливаем обучение из {checkpoint_path}")
|
||
checkpoint = torch.load(checkpoint_path, map_location=model._device)
|
||
# Убедимся, что загружаем на правильное устройство
|
||
model.load_state_dict(checkpoint['model_state_dict'])
|
||
if 'optimizer_state_dict' in checkpoint:
|
||
model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||
if 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict'] is not None:
|
||
if hasattr(model, 'scheduler'):
|
||
model.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||
self.last_epoch = checkpoint.get('epoch', -1)
|
||
print(f"➔ Продолжаем с эпохи {self.last_epoch + 1}")
|
||
print(f"➔ Последний loss: {checkpoint.get('train_loss', 'N/A'):.4f}\n")
|
||
except Exception as e:
|
||
print(f"⚠️ Чекпоинт поврежден или не читается: {checkpoint_path}\n{e}")
|
||
# Найти максимальный существующий checkpoint по файловой системе
|
||
import glob, os
|
||
cp_files = glob.glob(os.path.join(self.checkpoint_dir, 'checkpoint_epoch_*.pt'))
|
||
if cp_files:
|
||
try:
|
||
self.last_epoch = max([int(os.path.splitext(os.path.basename(f))[0].split('_')[-1]) for f in cp_files])
|
||
except Exception:
|
||
self.last_epoch = -1
|
||
else:
|
||
self.last_epoch = -1
|
||
else:
|
||
# Если файлов совсем нет
|
||
self.last_epoch = -1
|
||
|
||
def _find_latest_checkpoint(self) -> Optional[str]:
|
||
if not os.path.exists(self.checkpoint_dir):
|
||
return None
|
||
|
||
checkpoints = [f for f in os.listdir(self.checkpoint_dir)
|
||
if f.startswith('checkpoint_') and f.endswith('.pt')]
|
||
|
||
if not checkpoints:
|
||
return None
|
||
|
||
# Сортируем по времени создания
|
||
checkpoints.sort(key=lambda x: os.path.getmtime(os.path.join(self.checkpoint_dir, x)))
|
||
return os.path.join(self.checkpoint_dir, checkpoints[-1]) |