Files
simple-llm/simple_llm/transformer/callback/model_checkpoint_callback.py

93 lines
4.0 KiB
Python

from .callback import Callback
import torch
import os
from typing import Optional
class ModelCheckpointCallback(Callback):
"""Сохраняет чекпоинты модели во время обучения.
Пример:
>>> checkpoint = ModelCheckpointCallback('checkpoints/')
>>> model.fit(callbacks=[checkpoint])
Args:
save_dir (str): Директория для сохранения
save_best_only (bool): Если True, сохраняет только при улучшении loss
save_freq (int): Сохранять каждые N эпох (default=1)
monitor (str): Какой loss мониторить ('val' или 'train')
keep_last_n (int): Сколько последних чекпоинтов хранить на диске (по умолчанию 3)
"""
def __init__(self,
save_dir: str,
save_best_only: bool = True,
save_freq: int = 1,
monitor: str = 'val',
keep_last_n: int = 3):
self.save_dir = save_dir
self.save_best_only = save_best_only
self.save_freq = save_freq
self.monitor = monitor
self.keep_last_n = keep_last_n
self.best_loss = float('inf')
# Создаем директорию если её нет
os.makedirs(save_dir, exist_ok=True)
def on_epoch_end(self, global_epoch, model, train_loss, val_loss):
# Решаем какой loss использовать для сравнения
current_loss = val_loss if (self.monitor == 'val' and val_loss is not None) else train_loss
# Сохраняем по расписанию или при улучшении
should_save = (
(global_epoch + 1) % self.save_freq == 0 or # по расписанию
(self.save_best_only and current_loss < self.best_loss) # или если это лучшая модель
)
if should_save:
self.best_loss = current_loss
checkpoint_path = os.path.join(
self.save_dir,
f"checkpoint_epoch_{global_epoch}.pt"
)
# Собираем состояния всех callback'ов
callback_states = {}
if hasattr(model, '_callbacks'):
for cb in model._callbacks:
if hasattr(cb, 'get_state'):
callback_states[cb.__class__.__name__] = cb.get_state()
torch.save({
'epoch': global_epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': model.optimizer.state_dict(),
'train_loss': train_loss,
'val_loss': val_loss,
'best_loss': self.best_loss,
'callback_states': callback_states,
'config': {
'vocab_size': model._vocab_size,
'max_seq_len': model._max_seq_len,
'emb_size': model._emb_size,
'num_heads': model._num_heads,
'head_size': model._head_size,
'num_layers': model._num_layers
}
}, checkpoint_path)
print(f"Модель сохранена в {checkpoint_path} (loss: {current_loss:.4f})")
self._clean_old_checkpoints()
def _clean_old_checkpoints(self):
import glob
files = sorted(
glob.glob(os.path.join(self.save_dir, 'checkpoint_epoch_*.pt')),
key=os.path.getmtime
)
if len(files) > self.keep_last_n:
for file in files[:-self.keep_last_n]:
try:
os.remove(file)
print(f"Удалён старый чекпоинт: {file}")
except Exception as e:
print(f"Ошибка при удалении чекпоинта {file}: {e}")