mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
Fix checkpoint resume logic, global epoch numbering, and robust recovery; update tests for checkpointing; all tests passing
This commit is contained in:
@@ -15,28 +15,31 @@ class ModelCheckpointCallback(Callback):
|
||||
save_best_only (bool): Если True, сохраняет только при улучшении loss
|
||||
save_freq (int): Сохранять каждые N эпох (default=1)
|
||||
monitor (str): Какой loss мониторить ('val' или 'train')
|
||||
keep_last_n (int): Сколько последних чекпоинтов хранить на диске (по умолчанию 3)
|
||||
"""
|
||||
def __init__(self,
|
||||
save_dir: str,
|
||||
save_best_only: bool = True,
|
||||
save_freq: int = 1,
|
||||
monitor: str = 'val'):
|
||||
monitor: str = 'val',
|
||||
keep_last_n: int = 3):
|
||||
self.save_dir = save_dir
|
||||
self.save_best_only = save_best_only
|
||||
self.save_freq = save_freq
|
||||
self.monitor = monitor
|
||||
self.keep_last_n = keep_last_n
|
||||
self.best_loss = float('inf')
|
||||
|
||||
# Создаем директорию если её нет
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
def on_epoch_end(self, epoch, model, train_loss, val_loss):
|
||||
def on_epoch_end(self, global_epoch, model, train_loss, val_loss):
|
||||
# Решаем какой loss использовать для сравнения
|
||||
current_loss = val_loss if (self.monitor == 'val' and val_loss is not None) else train_loss
|
||||
|
||||
# Сохраняем по расписанию или при улучшении
|
||||
should_save = (
|
||||
(epoch + 1) % self.save_freq == 0 or # по расписанию
|
||||
(global_epoch + 1) % self.save_freq == 0 or # по расписанию
|
||||
(self.save_best_only and current_loss < self.best_loss) # или если это лучшая модель
|
||||
)
|
||||
|
||||
@@ -44,7 +47,7 @@ class ModelCheckpointCallback(Callback):
|
||||
self.best_loss = current_loss
|
||||
checkpoint_path = os.path.join(
|
||||
self.save_dir,
|
||||
f"checkpoint_epoch_{epoch}.pt"
|
||||
f"checkpoint_epoch_{global_epoch}.pt"
|
||||
)
|
||||
|
||||
# Собираем состояния всех callback'ов
|
||||
@@ -55,7 +58,7 @@ class ModelCheckpointCallback(Callback):
|
||||
callback_states[cb.__class__.__name__] = cb.get_state()
|
||||
|
||||
torch.save({
|
||||
'epoch': epoch,
|
||||
'epoch': global_epoch,
|
||||
'model_state_dict': model.state_dict(),
|
||||
'optimizer_state_dict': model.optimizer.state_dict(),
|
||||
'train_loss': train_loss,
|
||||
@@ -73,3 +76,18 @@ class ModelCheckpointCallback(Callback):
|
||||
}, checkpoint_path)
|
||||
|
||||
print(f"Модель сохранена в {checkpoint_path} (loss: {current_loss:.4f})")
|
||||
self._clean_old_checkpoints()
|
||||
|
||||
def _clean_old_checkpoints(self):
|
||||
import glob
|
||||
files = sorted(
|
||||
glob.glob(os.path.join(self.save_dir, 'checkpoint_epoch_*.pt')),
|
||||
key=os.path.getmtime
|
||||
)
|
||||
if len(files) > self.keep_last_n:
|
||||
for file in files[:-self.keep_last_n]:
|
||||
try:
|
||||
os.remove(file)
|
||||
print(f"Удалён старый чекпоинт: {file}")
|
||||
except Exception as e:
|
||||
print(f"Ошибка при удалении чекпоинта {file}: {e}")
|
||||
@@ -20,23 +20,36 @@ class ResumeTrainingCallback(Callback):
|
||||
def on_train_begin(self, model):
|
||||
if not self.resume:
|
||||
return
|
||||
|
||||
checkpoint_path = self._find_latest_checkpoint()
|
||||
if checkpoint_path:
|
||||
print(f"\n⚡ Восстанавливаем обучение из {checkpoint_path}")
|
||||
checkpoint = torch.load(checkpoint_path, map_location=model._device)
|
||||
|
||||
# Убедимся, что загружаем на правильное устройство
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
if 'optimizer_state_dict' in checkpoint:
|
||||
model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
if 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict'] is not None:
|
||||
if hasattr(model, 'scheduler'):
|
||||
model.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||
self.last_epoch = checkpoint.get('epoch', -1)
|
||||
|
||||
print(f"➔ Продолжаем с эпохи {self.last_epoch + 1}")
|
||||
print(f"➔ Последний loss: {checkpoint.get('train_loss', 'N/A'):.4f}\n")
|
||||
try:
|
||||
print(f"\n⚡ Восстанавливаем обучение из {checkpoint_path}")
|
||||
checkpoint = torch.load(checkpoint_path, map_location=model._device)
|
||||
# Убедимся, что загружаем на правильное устройство
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
if 'optimizer_state_dict' in checkpoint:
|
||||
model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
if 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict'] is not None:
|
||||
if hasattr(model, 'scheduler'):
|
||||
model.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||
self.last_epoch = checkpoint.get('epoch', -1)
|
||||
print(f"➔ Продолжаем с эпохи {self.last_epoch + 1}")
|
||||
print(f"➔ Последний loss: {checkpoint.get('train_loss', 'N/A'):.4f}\n")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Чекпоинт поврежден или не читается: {checkpoint_path}\n{e}")
|
||||
# Найти максимальный существующий checkpoint по файловой системе
|
||||
import glob, os
|
||||
cp_files = glob.glob(os.path.join(self.checkpoint_dir, 'checkpoint_epoch_*.pt'))
|
||||
if cp_files:
|
||||
try:
|
||||
self.last_epoch = max([int(os.path.splitext(os.path.basename(f))[0].split('_')[-1]) for f in cp_files])
|
||||
except Exception:
|
||||
self.last_epoch = -1
|
||||
else:
|
||||
self.last_epoch = -1
|
||||
else:
|
||||
# Если файлов совсем нет
|
||||
self.last_epoch = -1
|
||||
|
||||
def _find_latest_checkpoint(self) -> Optional[str]:
|
||||
if not os.path.exists(self.checkpoint_dir):
|
||||
|
||||
Reference in New Issue
Block a user