docs: актуализировано описание автодовосстановления обучения и обновлена логика эпох в GPT

This commit is contained in:
Sergey Penkovsky
2025-07-30 22:06:26 +03:00
parent 789d2f3848
commit 25c067af4a
4 changed files with 73 additions and 13 deletions

View File

@@ -79,6 +79,20 @@ python bin/train_gpt_model.py \
--num-layers 2 --num-layers 2
``` ```
### ✔️ Восстановление обучения с чекпоинта
Если обучение было прервано или вы хотите дообучить модель:
- Просто перезапустите команду обучения с теми же параметрами (`bin/train_gpt_model.py ...`).
- Скрипт сам определит последний checkpoint (checkpoint_epoch_X.pt) в папке модели.
- Обучение продолжится с нужной эпохи, параметры останутся неизменными.
- В консоли вы увидите сообщения вроде:
```
⚡ Восстанавливаем обучение с эпохи 12
Восстановление обучения GPT
...
Начало обучения GPT на 18 эпох (с 12 по 29)
```
### 4. Генерация текста ### 4. Генерация текста
```bash ```bash
python bin/generate_text.py \ python bin/generate_text.py \

View File

@@ -73,11 +73,22 @@ def main():
) )
# Обучение # Обучение
# Определяем стартовую эпоху
start_epoch = 0
if os.path.exists(output_dir):
checkpoint_files = [f for f in os.listdir(output_dir) if f.startswith('checkpoint_epoch_')]
if checkpoint_files:
last_epoch = max([int(f.split('_')[2].split('.')[0]) for f in checkpoint_files])
start_epoch = last_epoch + 1
print(f"⚡ Восстанавливаем обучение с эпохи {start_epoch}")
model.fit( model.fit(
train_loader=loader, train_loader=loader,
num_epoch=args.epochs, num_epoch=args.epochs - start_epoch,
learning_rate=args.lr, learning_rate=args.lr,
checkpoint_dir=output_dir checkpoint_dir=output_dir,
resume_training=True,
start_epoch=start_epoch
) )
torch.save(model.state_dict(), args.output) torch.save(model.state_dict(), args.output)

View File

@@ -68,6 +68,21 @@ dataset = GetData(data=all_tokens, seq_len=seq_len, device='cuda')
``` ```
## 5. Обучение модели с помощью fit() ## 5. Обучение модели с помощью fit()
### Автоматическое восстановление обучения
- При запуске скрипта обучения в директории модели автоматически ищется последний checkpoint (`checkpoint_epoch_X.pt`) и обучение продолжается с соответствующей эпохи.
- Не нужно вручную указывать эпоху старта — всё происходит автоматически.
- Логи отображают, с какой эпохи продолжается обучение и сколько эпох осталось до конца (см. вывод в консоли).
**Пример из логов:**
```
⚡ Восстанавливаем обучение с эпохи 12
Восстановление обучения GPT
...
Начало обучения GPT на 18 эпох (с 12 по 29)
```
- Чтобы воспользоваться функцией автодовосстановления, просто повторно запустите команду обучения с теми же CLI-параметрами!
```python ```python
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from simple_llm.transformer.gpt import GPT from simple_llm.transformer.gpt import GPT

View File

@@ -8,6 +8,7 @@ from simple_llm.transformer.decoder import Decoder
from simple_llm.transformer.callback import ( from simple_llm.transformer.callback import (
EarlyStoppingCallback, EarlyStoppingCallback,
LRSchedulerCallback, LRSchedulerCallback,
ResumeTrainingCallback,
ModelCheckpointCallback ModelCheckpointCallback
) )
@@ -294,7 +295,9 @@ class GPT(nn.Module):
num_epoch: int = 1, num_epoch: int = 1,
learning_rate: float = 0.001, learning_rate: float = 0.001,
callbacks=None, callbacks=None,
checkpoint_dir=None checkpoint_dir=None,
resume_training=False,
start_epoch=0
): ):
"""Обучает модель GPT с поддержкой callback-системы. """Обучает модель GPT с поддержкой callback-системы.
@@ -360,27 +363,44 @@ class GPT(nn.Module):
device = torch.device(self._device) device = torch.device(self._device)
self.to(device) self.to(device)
# Инициализация callbacks по умолчанию # Инициализация callback-ов
if callbacks is None: if callbacks is None:
callbacks = [] callbacks = []
# Добавляем обязательные callback-и # Добавляем Resume callback если нужно
default_callbacks = [ if resume_training and checkpoint_dir:
print(f"Восстановление обучения GPT")
callbacks.insert(0, ResumeTrainingCallback(checkpoint_dir, resume=True))
# Стандартные callback-и
callbacks.extend([
EarlyStoppingCallback(patience=5), EarlyStoppingCallback(patience=5),
LRSchedulerCallback(lr=learning_rate) LRSchedulerCallback(lr=learning_rate)
] ])
callbacks = default_callbacks + callbacks
if checkpoint_dir is not None: if checkpoint_dir:
callbacks = [ModelCheckpointCallback(checkpoint_dir)] + callbacks callbacks.append(ModelCheckpointCallback(checkpoint_dir))
# Инициализация оптимизатора ДО вызова callback'ов
self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate) self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
# Определяем стартовую эпоху
start_epoch = 0
for cb in callbacks:
if isinstance(cb, ResumeTrainingCallback) and cb.last_epoch >= 0:
start_epoch = cb.last_epoch + 1
break
# Начало обучения (после инициализации оптимизатора)
for cb in callbacks:
cb.on_train_begin(self)
print(f"\nНачало обучения GPT на {num_epoch} эпох") print(f"\nНачало обучения GPT на {num_epoch} эпох (с {start_epoch} по {start_epoch + num_epoch - 1})")
print(f"Размер батча: {train_loader.batch_size}") print(f"Размер батча: {train_loader.batch_size}")
print(f"Всего батчей: {len(train_loader)}") print(f"Всего батчей: {len(train_loader)}")
print(f"Устройство: {device}\n") print(f"Устройство: {device}\n")
for epoch in range(num_epoch): for epoch in range(start_epoch, start_epoch + num_epoch):
# Вызов callback перед эпохой # Вызов callback перед эпохой
for cb in callbacks: for cb in callbacks:
cb.on_epoch_begin(epoch, self) cb.on_epoch_begin(epoch, self)