Fix checkpoint resume logic, global epoch numbering, and robust recovery; update tests for checkpointing; all tests passing

This commit is contained in:
Sergey Penkovsky
2025-08-01 10:40:08 +03:00
parent f7364070f0
commit 08f0356b4d
6 changed files with 113 additions and 53 deletions

View File

@@ -15,28 +15,31 @@ class ModelCheckpointCallback(Callback):
save_best_only (bool): Если True, сохраняет только при улучшении loss
save_freq (int): Сохранять каждые N эпох (default=1)
monitor (str): Какой loss мониторить ('val' или 'train')
keep_last_n (int): Сколько последних чекпоинтов хранить на диске (по умолчанию 3)
"""
def __init__(self,
save_dir: str,
save_best_only: bool = True,
save_freq: int = 1,
monitor: str = 'val'):
monitor: str = 'val',
keep_last_n: int = 3):
self.save_dir = save_dir
self.save_best_only = save_best_only
self.save_freq = save_freq
self.monitor = monitor
self.keep_last_n = keep_last_n
self.best_loss = float('inf')
# Создаем директорию если её нет
os.makedirs(save_dir, exist_ok=True)
def on_epoch_end(self, epoch, model, train_loss, val_loss):
def on_epoch_end(self, global_epoch, model, train_loss, val_loss):
# Решаем какой loss использовать для сравнения
current_loss = val_loss if (self.monitor == 'val' and val_loss is not None) else train_loss
# Сохраняем по расписанию или при улучшении
should_save = (
(epoch + 1) % self.save_freq == 0 or # по расписанию
(global_epoch + 1) % self.save_freq == 0 or # по расписанию
(self.save_best_only and current_loss < self.best_loss) # или если это лучшая модель
)
@@ -44,7 +47,7 @@ class ModelCheckpointCallback(Callback):
self.best_loss = current_loss
checkpoint_path = os.path.join(
self.save_dir,
f"checkpoint_epoch_{epoch}.pt"
f"checkpoint_epoch_{global_epoch}.pt"
)
# Собираем состояния всех callback'ов
@@ -55,7 +58,7 @@ class ModelCheckpointCallback(Callback):
callback_states[cb.__class__.__name__] = cb.get_state()
torch.save({
'epoch': epoch,
'epoch': global_epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': model.optimizer.state_dict(),
'train_loss': train_loss,
@@ -73,3 +76,18 @@ class ModelCheckpointCallback(Callback):
}, checkpoint_path)
print(f"Модель сохранена в {checkpoint_path} (loss: {current_loss:.4f})")
self._clean_old_checkpoints()
def _clean_old_checkpoints(self):
import glob
files = sorted(
glob.glob(os.path.join(self.save_dir, 'checkpoint_epoch_*.pt')),
key=os.path.getmtime
)
if len(files) > self.keep_last_n:
for file in files[:-self.keep_last_n]:
try:
os.remove(file)
print(f"Удалён старый чекпоинт: {file}")
except Exception as e:
print(f"Ошибка при удалении чекпоинта {file}: {e}")

View File

@@ -20,23 +20,36 @@ class ResumeTrainingCallback(Callback):
def on_train_begin(self, model):
if not self.resume:
return
checkpoint_path = self._find_latest_checkpoint()
if checkpoint_path:
print(f"\n⚡ Восстанавливаем обучение из {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=model._device)
# Убедимся, что загружаем на правильное устройство
model.load_state_dict(checkpoint['model_state_dict'])
if 'optimizer_state_dict' in checkpoint:
model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict'] is not None:
if hasattr(model, 'scheduler'):
model.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
self.last_epoch = checkpoint.get('epoch', -1)
print(f"➔ Продолжаем с эпохи {self.last_epoch + 1}")
print(f"➔ Последний loss: {checkpoint.get('train_loss', 'N/A'):.4f}\n")
try:
print(f"\n⚡ Восстанавливаем обучение из {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=model._device)
# Убедимся, что загружаем на правильное устройство
model.load_state_dict(checkpoint['model_state_dict'])
if 'optimizer_state_dict' in checkpoint:
model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict'] is not None:
if hasattr(model, 'scheduler'):
model.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
self.last_epoch = checkpoint.get('epoch', -1)
print(f"➔ Продолжаем с эпохи {self.last_epoch + 1}")
print(f"➔ Последний loss: {checkpoint.get('train_loss', 'N/A'):.4f}\n")
except Exception as e:
print(f"⚠️ Чекпоинт поврежден или не читается: {checkpoint_path}\n{e}")
# Найти максимальный существующий checkpoint по файловой системе
import glob, os
cp_files = glob.glob(os.path.join(self.checkpoint_dir, 'checkpoint_epoch_*.pt'))
if cp_files:
try:
self.last_epoch = max([int(os.path.splitext(os.path.basename(f))[0].split('_')[-1]) for f in cp_files])
except Exception:
self.last_epoch = -1
else:
self.last_epoch = -1
else:
# Если файлов совсем нет
self.last_epoch = -1
def _find_latest_checkpoint(self) -> Optional[str]:
if not os.path.exists(self.checkpoint_dir):