mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-24 05:26:02 +00:00
docs: актуализировано описание автодовосстановления обучения и обновлена логика эпох в GPT
This commit is contained in:
@@ -8,6 +8,7 @@ from simple_llm.transformer.decoder import Decoder
|
||||
from simple_llm.transformer.callback import (
|
||||
EarlyStoppingCallback,
|
||||
LRSchedulerCallback,
|
||||
ResumeTrainingCallback,
|
||||
ModelCheckpointCallback
|
||||
)
|
||||
|
||||
@@ -294,7 +295,9 @@ class GPT(nn.Module):
|
||||
num_epoch: int = 1,
|
||||
learning_rate: float = 0.001,
|
||||
callbacks=None,
|
||||
checkpoint_dir=None
|
||||
checkpoint_dir=None,
|
||||
resume_training=False,
|
||||
start_epoch=0
|
||||
):
|
||||
"""Обучает модель GPT с поддержкой callback-системы.
|
||||
|
||||
@@ -360,27 +363,44 @@ class GPT(nn.Module):
|
||||
device = torch.device(self._device)
|
||||
self.to(device)
|
||||
|
||||
# Инициализация callbacks по умолчанию
|
||||
# Инициализация callback-ов
|
||||
if callbacks is None:
|
||||
callbacks = []
|
||||
|
||||
# Добавляем обязательные callback-и
|
||||
default_callbacks = [
|
||||
# Добавляем Resume callback если нужно
|
||||
if resume_training and checkpoint_dir:
|
||||
print(f"Восстановление обучения GPT")
|
||||
callbacks.insert(0, ResumeTrainingCallback(checkpoint_dir, resume=True))
|
||||
|
||||
# Стандартные callback-и
|
||||
callbacks.extend([
|
||||
EarlyStoppingCallback(patience=5),
|
||||
LRSchedulerCallback(lr=learning_rate)
|
||||
]
|
||||
callbacks = default_callbacks + callbacks
|
||||
if checkpoint_dir is not None:
|
||||
callbacks = [ModelCheckpointCallback(checkpoint_dir)] + callbacks
|
||||
|
||||
])
|
||||
|
||||
if checkpoint_dir:
|
||||
callbacks.append(ModelCheckpointCallback(checkpoint_dir))
|
||||
|
||||
# Инициализация оптимизатора ДО вызова callback'ов
|
||||
self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
|
||||
|
||||
# Определяем стартовую эпоху
|
||||
start_epoch = 0
|
||||
for cb in callbacks:
|
||||
if isinstance(cb, ResumeTrainingCallback) and cb.last_epoch >= 0:
|
||||
start_epoch = cb.last_epoch + 1
|
||||
break
|
||||
|
||||
# Начало обучения (после инициализации оптимизатора)
|
||||
for cb in callbacks:
|
||||
cb.on_train_begin(self)
|
||||
|
||||
print(f"\nНачало обучения GPT на {num_epoch} эпох")
|
||||
print(f"\nНачало обучения GPT на {num_epoch} эпох (с {start_epoch} по {start_epoch + num_epoch - 1})")
|
||||
print(f"Размер батча: {train_loader.batch_size}")
|
||||
print(f"Всего батчей: {len(train_loader)}")
|
||||
print(f"Устройство: {device}\n")
|
||||
|
||||
for epoch in range(num_epoch):
|
||||
for epoch in range(start_epoch, start_epoch + num_epoch):
|
||||
# Вызов callback перед эпохой
|
||||
for cb in callbacks:
|
||||
cb.on_epoch_begin(epoch, self)
|
||||
|
||||
Reference in New Issue
Block a user