Fix checkpoint resume logic, global epoch numbering, and robust recovery; update tests for checkpointing; all tests passing

This commit is contained in:
Sergey Penkovsky
2025-08-01 10:40:08 +03:00
parent f7364070f0
commit 08f0356b4d
6 changed files with 113 additions and 53 deletions

View File

@@ -297,7 +297,8 @@ class GPT(nn.Module):
callbacks=None,
checkpoint_dir=None,
resume_training=False,
start_epoch=0
start_epoch=0,
keep_last_n: int = 3
):
"""Обучает модель GPT с поддержкой callback-системы.
@@ -379,7 +380,7 @@ class GPT(nn.Module):
])
if checkpoint_dir:
callbacks.append(ModelCheckpointCallback(checkpoint_dir))
callbacks.append(ModelCheckpointCallback(checkpoint_dir, keep_last_n=keep_last_n))
# Инициализация оптимизатора ДО вызова callback'ов
self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
@@ -391,19 +392,29 @@ class GPT(nn.Module):
start_epoch = cb.last_epoch + 1
break
# Начало обучения (после инициализации оптимизатора)
for cb in callbacks:
cb.on_train_begin(self)
print(f"\nНачало обучения GPT на {num_epoch} эпох (с {start_epoch} по {start_epoch + num_epoch - 1})")
# Проверка через реальные файлы чекпоинтов (если они удалены или keep_last_n, брать max файл)
import glob, os
if checkpoint_dir is not None:
cp_files = glob.glob(os.path.join(checkpoint_dir, 'checkpoint_epoch_*.pt'))
if cp_files:
try:
max_cp = max([int(os.path.splitext(os.path.basename(f))[0].split('_')[-1]) for f in cp_files])
if max_cp + 1 > start_epoch:
start_epoch = max_cp + 1
except Exception:
pass
final_epoch_num = start_epoch + num_epoch
print(f"\nНачало обучения GPT на {num_epoch} эпох (с {start_epoch + 1} по {final_epoch_num})")
print(f"Размер батча: {train_loader.batch_size}")
print(f"Всего батчей: {len(train_loader)}")
print(f"Устройство: {device}\n")
for epoch in range(start_epoch, start_epoch + num_epoch):
for epoch_local in range(num_epoch):
global_epoch = start_epoch + epoch_local
# Вызов callback перед эпохой
for cb in callbacks:
cb.on_epoch_begin(epoch, self)
cb.on_epoch_begin(global_epoch, self)
self.train()
epoch_loss = 0.0
@@ -411,7 +422,7 @@ class GPT(nn.Module):
# Прогресс-бар для батчей
batch_pbar = tqdm(train_loader,
desc=f"Эпоха {epoch+1}/{num_epoch}",
desc=f"Эпоха {global_epoch+1}/{final_epoch_num}",
leave=False,
bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]')
@@ -447,7 +458,7 @@ class GPT(nn.Module):
self.train_loss = epoch_loss / len(train_loader)
epoch_time = time.time() - start_time
print(f"\nЭпоха {epoch+1}/{num_epoch} завершена за {epoch_time:.2f} сек")
print(f"\nЭпоха {global_epoch+1}/{final_epoch_num} завершена за {epoch_time:.2f} сек")
print(f"Средний Train Loss: {self.train_loss:.4f}")
if valid_loader is not None:
@@ -456,7 +467,7 @@ class GPT(nn.Module):
with torch.no_grad():
# Прогресс-бар для валидации
valid_pbar = tqdm(valid_loader,
desc=f"Валидация {epoch+1}/{num_epoch}",
desc=f"Валидация {global_epoch+1}/{final_epoch_num}",
leave=False)
for inputs, targets in valid_pbar:
@@ -473,11 +484,17 @@ class GPT(nn.Module):
self.validation_loss = valid_loss / len(valid_loader)
print(f"Средний Val Loss: {self.validation_loss:.4f}")
# Вызов callback после эпохи
# Вызов callback после эпохи — теперь передаём global_epoch (гарантируем уникальность чекпоинта)
stop_training = False
for cb in callbacks:
if cb.on_epoch_end(epoch, self, self.train_loss, self.validation_loss):
stop_training = True
# ModelCheckpointCallback — строго global_epoch
if isinstance(cb, ModelCheckpointCallback):
if cb.on_epoch_end(global_epoch, self, self.train_loss, self.validation_loss):
stop_training = True
else:
# Старые/другие коллбэки — epoch_local
if cb.on_epoch_end(epoch_local, self, self.train_loss, self.validation_loss):
stop_training = True
if stop_training:
print("Обучение остановлено callback-ом")