docs, logic: обновление документации и автодовосстановления обучения модели, актуализация index.md

This commit is contained in:
Sergey Penkovsky
2025-07-30 22:22:20 +03:00
parent 25c067af4a
commit 73e7a164f9
7 changed files with 654 additions and 14 deletions

View File

@@ -7,14 +7,17 @@ Callback-система для управления обучением GPT.
- LRSchedulerCallback - регулировка learning rate
"""
# /Users/sergey/Projects/ML/simple-llm/simple_llm/transformer/callback/__init__.py
from .callback import Callback
from .early_stopping_callback import EarlyStoppingCallback
from .lrs_scheduler_callback import LRSchedulerCallback
from .model_checkpoint_callback import ModelCheckpointCallback
from .resume_training_callback import ResumeTrainingCallback
__all__ = [
'Callback',
'EarlyStoppingCallback',
'LRSchedulerCallback',
'ModelCheckpointCallback'
'LRSchedulerCallback',
'ModelCheckpointCallback',
'ResumeTrainingCallback'
]

View File

@@ -12,6 +12,9 @@ class Callback:
- on_batch_end - после обработки батча
- on_epoch_end - в конце эпохи
"""
def on_train_begin(self, model):
pass
def on_epoch_begin(self, epoch, model):
"""Вызывается перед началом эпохи.

View File

@@ -14,8 +14,22 @@ class LRSchedulerCallback(Callback):
def __init__(self, lr, decay=0.95):
self.base_lr = lr
self.decay = decay
self.last_epoch = -1 # Добавляем отслеживание эпохи
def get_state(self):
return {
'base_lr': self.base_lr,
'decay': self.decay,
'last_epoch': self.last_epoch
}
def set_state(self, state):
self.base_lr = state['base_lr']
self.decay = state['decay']
self.last_epoch = state['last_epoch']
def on_epoch_begin(self, epoch, model):
self.last_epoch = epoch # Сохраняем текущую эпоху
new_lr = self.base_lr * (self.decay ** epoch)
for param_group in model.optimizer.param_groups:
param_group['lr'] = new_lr

View File

@@ -1,6 +1,7 @@
from .callback import Callback
import torch
import os
from typing import Optional
class ModelCheckpointCallback(Callback):
"""Сохраняет чекпоинты модели во время обучения.
@@ -11,24 +12,64 @@ class ModelCheckpointCallback(Callback):
Args:
save_dir (str): Директория для сохранения
save_best_only (bool): Сохранять только лучшие модели
save_best_only (bool): Если True, сохраняет только при улучшении loss
save_freq (int): Сохранять каждые N эпох (default=1)
monitor (str): Какой loss мониторить ('val' или 'train')
"""
def __init__(self, save_dir, save_best_only=True):
def __init__(self,
save_dir: str,
save_best_only: bool = True,
save_freq: int = 1,
monitor: str = 'val'):
self.save_dir = save_dir
self.save_best_only = save_best_only
self.save_freq = save_freq
self.monitor = monitor
self.best_loss = float('inf')
def on_epoch_end(self, epoch, model, train_loss, val_loss):
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)
current_loss = val_loss if val_loss else train_loss
# Создаем директорию если её нет
os.makedirs(save_dir, exist_ok=True)
if not self.save_best_only or current_loss < self.best_loss:
def on_epoch_end(self, epoch, model, train_loss, val_loss):
# Решаем какой loss использовать для сравнения
current_loss = val_loss if (self.monitor == 'val' and val_loss is not None) else train_loss
# Сохраняем по расписанию или при улучшении
should_save = (
(epoch + 1) % self.save_freq == 0 or # по расписанию
(self.save_best_only and current_loss < self.best_loss) # или если это лучшая модель
)
if should_save:
self.best_loss = current_loss
path = os.path.join(self.save_dir, f"checkpoint_epoch_{epoch}.pt")
checkpoint_path = os.path.join(
self.save_dir,
f"checkpoint_epoch_{epoch}.pt"
)
# Собираем состояния всех callback'ов
callback_states = {}
if hasattr(model, '_callbacks'):
for cb in model._callbacks:
if hasattr(cb, 'get_state'):
callback_states[cb.__class__.__name__] = cb.get_state()
torch.save({
'epoch': epoch,
'model_state': model.state_dict(),
'loss': current_loss
}, path)
'model_state_dict': model.state_dict(),
'optimizer_state_dict': model.optimizer.state_dict(),
'train_loss': train_loss,
'val_loss': val_loss,
'best_loss': self.best_loss,
'callback_states': callback_states,
'config': {
'vocab_size': model._vocab_size,
'max_seq_len': model._max_seq_len,
'emb_size': model._emb_size,
'num_heads': model._num_heads,
'head_size': model._head_size,
'num_layers': model._num_layers
}
}, checkpoint_path)
print(f"Модель сохранена в {checkpoint_path} (loss: {current_loss:.4f})")

View File

@@ -0,0 +1,53 @@
# /Users/sergey/Projects/ML/simple-llm/simple_llm/transformer/callback/resume_training_callback.py
import os
import torch
from typing import Optional
from .callback import Callback
class ResumeTrainingCallback(Callback):
"""Callback для восстановления обучения с последнего чекпоинта"""
def __init__(self, checkpoint_dir: str, resume: bool = True):
"""
Args:
checkpoint_dir: Путь к директории с чекпоинтами
resume: Флаг восстановления обучения (default=True)
"""
self.checkpoint_dir = checkpoint_dir
self.resume = resume
self.last_epoch = -1
def on_train_begin(self, model):
if not self.resume:
return
checkpoint_path = self._find_latest_checkpoint()
if checkpoint_path:
print(f"\n⚡ Восстанавливаем обучение из {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=model._device)
# Убедимся, что загружаем на правильное устройство
model.load_state_dict(checkpoint['model_state_dict'])
if 'optimizer_state_dict' in checkpoint:
model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict'] is not None:
if hasattr(model, 'scheduler'):
model.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
self.last_epoch = checkpoint.get('epoch', -1)
print(f"➔ Продолжаем с эпохи {self.last_epoch + 1}")
print(f"➔ Последний loss: {checkpoint.get('train_loss', 'N/A'):.4f}\n")
def _find_latest_checkpoint(self) -> Optional[str]:
if not os.path.exists(self.checkpoint_dir):
return None
checkpoints = [f for f in os.listdir(self.checkpoint_dir)
if f.startswith('checkpoint_') and f.endswith('.pt')]
if not checkpoints:
return None
# Сортируем по времени создания
checkpoints.sort(key=lambda x: os.path.getmtime(os.path.join(self.checkpoint_dir, x)))
return os.path.join(self.checkpoint_dir, checkpoints[-1])