mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-24 05:26:02 +00:00
Fix checkpoint resume logic, global epoch numbering, and robust recovery; update tests for checkpointing; all tests passing
This commit is contained in:
@@ -297,7 +297,8 @@ class GPT(nn.Module):
|
||||
callbacks=None,
|
||||
checkpoint_dir=None,
|
||||
resume_training=False,
|
||||
start_epoch=0
|
||||
start_epoch=0,
|
||||
keep_last_n: int = 3
|
||||
):
|
||||
"""Обучает модель GPT с поддержкой callback-системы.
|
||||
|
||||
@@ -379,7 +380,7 @@ class GPT(nn.Module):
|
||||
])
|
||||
|
||||
if checkpoint_dir:
|
||||
callbacks.append(ModelCheckpointCallback(checkpoint_dir))
|
||||
callbacks.append(ModelCheckpointCallback(checkpoint_dir, keep_last_n=keep_last_n))
|
||||
|
||||
# Инициализация оптимизатора ДО вызова callback'ов
|
||||
self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
|
||||
@@ -391,19 +392,29 @@ class GPT(nn.Module):
|
||||
start_epoch = cb.last_epoch + 1
|
||||
break
|
||||
|
||||
# Начало обучения (после инициализации оптимизатора)
|
||||
for cb in callbacks:
|
||||
cb.on_train_begin(self)
|
||||
|
||||
print(f"\nНачало обучения GPT на {num_epoch} эпох (с {start_epoch} по {start_epoch + num_epoch - 1})")
|
||||
# Проверка через реальные файлы чекпоинтов (если они удалены или keep_last_n, брать max файл)
|
||||
import glob, os
|
||||
if checkpoint_dir is not None:
|
||||
cp_files = glob.glob(os.path.join(checkpoint_dir, 'checkpoint_epoch_*.pt'))
|
||||
if cp_files:
|
||||
try:
|
||||
max_cp = max([int(os.path.splitext(os.path.basename(f))[0].split('_')[-1]) for f in cp_files])
|
||||
if max_cp + 1 > start_epoch:
|
||||
start_epoch = max_cp + 1
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
final_epoch_num = start_epoch + num_epoch
|
||||
print(f"\nНачало обучения GPT на {num_epoch} эпох (с {start_epoch + 1} по {final_epoch_num})")
|
||||
print(f"Размер батча: {train_loader.batch_size}")
|
||||
print(f"Всего батчей: {len(train_loader)}")
|
||||
print(f"Устройство: {device}\n")
|
||||
|
||||
for epoch in range(start_epoch, start_epoch + num_epoch):
|
||||
for epoch_local in range(num_epoch):
|
||||
global_epoch = start_epoch + epoch_local
|
||||
# Вызов callback перед эпохой
|
||||
for cb in callbacks:
|
||||
cb.on_epoch_begin(epoch, self)
|
||||
cb.on_epoch_begin(global_epoch, self)
|
||||
|
||||
self.train()
|
||||
epoch_loss = 0.0
|
||||
@@ -411,7 +422,7 @@ class GPT(nn.Module):
|
||||
|
||||
# Прогресс-бар для батчей
|
||||
batch_pbar = tqdm(train_loader,
|
||||
desc=f"Эпоха {epoch+1}/{num_epoch}",
|
||||
desc=f"Эпоха {global_epoch+1}/{final_epoch_num}",
|
||||
leave=False,
|
||||
bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]')
|
||||
|
||||
@@ -447,7 +458,7 @@ class GPT(nn.Module):
|
||||
self.train_loss = epoch_loss / len(train_loader)
|
||||
epoch_time = time.time() - start_time
|
||||
|
||||
print(f"\nЭпоха {epoch+1}/{num_epoch} завершена за {epoch_time:.2f} сек")
|
||||
print(f"\nЭпоха {global_epoch+1}/{final_epoch_num} завершена за {epoch_time:.2f} сек")
|
||||
print(f"Средний Train Loss: {self.train_loss:.4f}")
|
||||
|
||||
if valid_loader is not None:
|
||||
@@ -456,7 +467,7 @@ class GPT(nn.Module):
|
||||
with torch.no_grad():
|
||||
# Прогресс-бар для валидации
|
||||
valid_pbar = tqdm(valid_loader,
|
||||
desc=f"Валидация {epoch+1}/{num_epoch}",
|
||||
desc=f"Валидация {global_epoch+1}/{final_epoch_num}",
|
||||
leave=False)
|
||||
|
||||
for inputs, targets in valid_pbar:
|
||||
@@ -473,11 +484,17 @@ class GPT(nn.Module):
|
||||
self.validation_loss = valid_loss / len(valid_loader)
|
||||
print(f"Средний Val Loss: {self.validation_loss:.4f}")
|
||||
|
||||
# Вызов callback после эпохи
|
||||
# Вызов callback после эпохи — теперь передаём global_epoch (гарантируем уникальность чекпоинта)
|
||||
stop_training = False
|
||||
for cb in callbacks:
|
||||
if cb.on_epoch_end(epoch, self, self.train_loss, self.validation_loss):
|
||||
stop_training = True
|
||||
# ModelCheckpointCallback — строго global_epoch
|
||||
if isinstance(cb, ModelCheckpointCallback):
|
||||
if cb.on_epoch_end(global_epoch, self, self.train_loss, self.validation_loss):
|
||||
stop_training = True
|
||||
else:
|
||||
# Старые/другие коллбэки — epoch_local
|
||||
if cb.on_epoch_end(epoch_local, self, self.train_loss, self.validation_loss):
|
||||
stop_training = True
|
||||
|
||||
if stop_training:
|
||||
print("Обучение остановлено callback-ом")
|
||||
|
||||
Reference in New Issue
Block a user