mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-24 05:26:02 +00:00
docs, logic: update documentation and auto-resume of model training, refresh index.md
53  simple_llm/transformer/callback/resume_training_callback.py  Normal file
@@ -0,0 +1,53 @@
# /Users/sergey/Projects/ML/simple-llm/simple_llm/transformer/callback/resume_training_callback.py
import os
from typing import Optional

import torch

from .callback import Callback


class ResumeTrainingCallback(Callback):
    """Callback that resumes training from the latest checkpoint."""

    def __init__(self, checkpoint_dir: str, resume: bool = True):
        """
        Args:
            checkpoint_dir: Path to the directory containing checkpoints.
            resume: Whether to resume training (default=True).
        """
        self.checkpoint_dir = checkpoint_dir
        self.resume = resume
        self.last_epoch = -1

    def on_train_begin(self, model):
        if not self.resume:
            return

        checkpoint_path = self._find_latest_checkpoint()
        if checkpoint_path:
            print(f"\n⚡ Resuming training from {checkpoint_path}")
            # Make sure the checkpoint is loaded onto the model's device
            checkpoint = torch.load(checkpoint_path, map_location=model._device)

            model.load_state_dict(checkpoint['model_state_dict'])
            if 'optimizer_state_dict' in checkpoint:
                model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            if 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict'] is not None:
                if hasattr(model, 'scheduler'):
                    model.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            self.last_epoch = checkpoint.get('epoch', -1)

            print(f"➔ Continuing from epoch {self.last_epoch + 1}")
            # Format the loss only if it was actually stored in the checkpoint
            train_loss = checkpoint.get('train_loss')
            loss_str = f"{train_loss:.4f}" if train_loss is not None else "N/A"
            print(f"➔ Last loss: {loss_str}\n")

    def _find_latest_checkpoint(self) -> Optional[str]:
        if not os.path.exists(self.checkpoint_dir):
            return None

        checkpoints = [f for f in os.listdir(self.checkpoint_dir)
                       if f.startswith('checkpoint_') and f.endswith('.pt')]

        if not checkpoints:
            return None

        # Sort by modification time and take the most recent checkpoint
        checkpoints.sort(key=lambda x: os.path.getmtime(os.path.join(self.checkpoint_dir, x)))
        return os.path.join(self.checkpoint_dir, checkpoints[-1])
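Minimal usage sketch of the callback above. It assumes the trainer exposes a fit-style method that accepts a callbacks list and invokes on_train_begin(model) once before the first epoch, and that a companion checkpoint-saving callback writes checkpoint_*.pt files containing the keys read here ('model_state_dict', 'optimizer_state_dict', 'epoch', 'train_loss'); the exact simple_llm training API may differ, so the fit call is shown commented out as a hypothetical.

    # Hypothetical wiring of ResumeTrainingCallback into a training run.
    from simple_llm.transformer.callback.resume_training_callback import ResumeTrainingCallback

    # Point the callback at the directory where checkpoints are written.
    resume_cb = ResumeTrainingCallback(checkpoint_dir="checkpoints", resume=True)

    # Assumed trainer call (not confirmed by this diff): the callback restores
    # model/optimizer/scheduler state before the first epoch, and the epoch to
    # continue from is then available as resume_cb.last_epoch + 1.
    # model.fit(train_loader, num_epochs=10, callbacks=[resume_cb])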