docs: актуализировано описание автодовосстановления обучения и обновлена логика эпох в GPT

This commit is contained in:
Sergey Penkovsky
2025-07-30 22:06:26 +03:00
parent 789d2f3848
commit 25c067af4a
4 changed files with 73 additions and 13 deletions

View File

@@ -8,6 +8,7 @@ from simple_llm.transformer.decoder import Decoder
from simple_llm.transformer.callback import (
EarlyStoppingCallback,
LRSchedulerCallback,
ResumeTrainingCallback,
ModelCheckpointCallback
)
@@ -294,7 +295,9 @@ class GPT(nn.Module):
num_epoch: int = 1,
learning_rate: float = 0.001,
callbacks=None,
checkpoint_dir=None
checkpoint_dir=None,
resume_training=False,
start_epoch=0
):
"""Обучает модель GPT с поддержкой callback-системы.
@@ -360,27 +363,44 @@ class GPT(nn.Module):
device = torch.device(self._device)
self.to(device)
# Инициализация callbacks по умолчанию
# Инициализация callback-ов
if callbacks is None:
callbacks = []
# Добавляем обязательные callback-и
default_callbacks = [
# Добавляем Resume callback если нужно
if resume_training and checkpoint_dir:
print(f"Восстановление обучения GPT")
callbacks.insert(0, ResumeTrainingCallback(checkpoint_dir, resume=True))
# Стандартные callback-и
callbacks.extend([
EarlyStoppingCallback(patience=5),
LRSchedulerCallback(lr=learning_rate)
]
callbacks = default_callbacks + callbacks
if checkpoint_dir is not None:
callbacks = [ModelCheckpointCallback(checkpoint_dir)] + callbacks
])
if checkpoint_dir:
callbacks.append(ModelCheckpointCallback(checkpoint_dir))
# Инициализация оптимизатора ДО вызова callback'ов
self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
# Определяем стартовую эпоху
start_epoch = 0
for cb in callbacks:
if isinstance(cb, ResumeTrainingCallback) and cb.last_epoch >= 0:
start_epoch = cb.last_epoch + 1
break
# Начало обучения (после инициализации оптимизатора)
for cb in callbacks:
cb.on_train_begin(self)
print(f"\nНачало обучения GPT на {num_epoch} эпох")
print(f"\nНачало обучения GPT на {num_epoch} эпох (с {start_epoch} по {start_epoch + num_epoch - 1})")
print(f"Размер батча: {train_loader.batch_size}")
print(f"Всего батчей: {len(train_loader)}")
print(f"Устройство: {device}\n")
for epoch in range(num_epoch):
for epoch in range(start_epoch, start_epoch + num_epoch):
# Вызов callback перед эпохой
for cb in callbacks:
cb.on_epoch_begin(epoch, self)