mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
docs: актуализировано описание автодовосстановления обучения и обновлена логика эпох в GPT
This commit is contained in:
14
README.md
14
README.md
@@ -79,6 +79,20 @@ python bin/train_gpt_model.py \
|
||||
--num-layers 2
|
||||
```
|
||||
|
||||
### ✔️ Восстановление обучения с чекпоинта
|
||||
|
||||
Если обучение было прервано или вы хотите дообучить модель:
|
||||
- Просто перезапустите команду обучения с теми же параметрами (`bin/train_gpt_model.py ...`).
|
||||
- Скрипт сам определит последний checkpoint (checkpoint_epoch_X.pt) в папке модели.
|
||||
- Обучение продолжится с нужной эпохи, параметры останутся неизменными.
|
||||
- В консоли вы увидите сообщения вроде:
|
||||
```
|
||||
⚡ Восстанавливаем обучение с эпохи 12
|
||||
Восстановление обучения GPT
|
||||
...
|
||||
Начало обучения GPT на 18 эпох (с 12 по 29)
|
||||
```
|
||||
|
||||
### 4. Генерация текста
|
||||
```bash
|
||||
python bin/generate_text.py \
|
||||
|
||||
@@ -73,11 +73,22 @@ def main():
|
||||
)
|
||||
|
||||
# Обучение
|
||||
# Определяем стартовую эпоху
|
||||
start_epoch = 0
|
||||
if os.path.exists(output_dir):
|
||||
checkpoint_files = [f for f in os.listdir(output_dir) if f.startswith('checkpoint_epoch_')]
|
||||
if checkpoint_files:
|
||||
last_epoch = max([int(f.split('_')[2].split('.')[0]) for f in checkpoint_files])
|
||||
start_epoch = last_epoch + 1
|
||||
print(f"⚡ Восстанавливаем обучение с эпохи {start_epoch}")
|
||||
|
||||
model.fit(
|
||||
train_loader=loader,
|
||||
num_epoch=args.epochs,
|
||||
num_epoch=args.epochs - start_epoch,
|
||||
learning_rate=args.lr,
|
||||
checkpoint_dir=output_dir
|
||||
checkpoint_dir=output_dir,
|
||||
resume_training=True,
|
||||
start_epoch=start_epoch
|
||||
)
|
||||
torch.save(model.state_dict(), args.output)
|
||||
|
||||
|
||||
@@ -68,6 +68,21 @@ dataset = GetData(data=all_tokens, seq_len=seq_len, device='cuda')
|
||||
```
|
||||
|
||||
## 5. Обучение модели с помощью fit()
|
||||
|
||||
### Автоматическое восстановление обучения
|
||||
|
||||
- При запуске скрипта обучения в директории модели автоматически ищется последний checkpoint (`checkpoint_epoch_X.pt`) и обучение продолжается с соответствующей эпохи.
|
||||
- Не нужно вручную указывать эпоху старта — всё происходит автоматически.
|
||||
- Логи отображают, с какой эпохи продолжается обучение и сколько эпох осталось до конца (см. вывод в консоли).
|
||||
|
||||
**Пример из логов:**
|
||||
```
|
||||
⚡ Восстанавливаем обучение с эпохи 12
|
||||
Восстановление обучения GPT
|
||||
...
|
||||
Начало обучения GPT на 18 эпох (с 12 по 29)
|
||||
```
|
||||
- Чтобы воспользоваться функцией автодовосстановления, просто повторно запустите команду обучения с теми же CLI-параметрами!
|
||||
```python
|
||||
from torch.utils.data import DataLoader
|
||||
from simple_llm.transformer.gpt import GPT
|
||||
|
||||
@@ -8,6 +8,7 @@ from simple_llm.transformer.decoder import Decoder
|
||||
from simple_llm.transformer.callback import (
|
||||
EarlyStoppingCallback,
|
||||
LRSchedulerCallback,
|
||||
ResumeTrainingCallback,
|
||||
ModelCheckpointCallback
|
||||
)
|
||||
|
||||
@@ -294,7 +295,9 @@ class GPT(nn.Module):
|
||||
num_epoch: int = 1,
|
||||
learning_rate: float = 0.001,
|
||||
callbacks=None,
|
||||
checkpoint_dir=None
|
||||
checkpoint_dir=None,
|
||||
resume_training=False,
|
||||
start_epoch=0
|
||||
):
|
||||
"""Обучает модель GPT с поддержкой callback-системы.
|
||||
|
||||
@@ -360,27 +363,44 @@ class GPT(nn.Module):
|
||||
device = torch.device(self._device)
|
||||
self.to(device)
|
||||
|
||||
# Инициализация callbacks по умолчанию
|
||||
# Инициализация callback-ов
|
||||
if callbacks is None:
|
||||
callbacks = []
|
||||
|
||||
# Добавляем обязательные callback-и
|
||||
default_callbacks = [
|
||||
# Добавляем Resume callback если нужно
|
||||
if resume_training and checkpoint_dir:
|
||||
print(f"Восстановление обучения GPT")
|
||||
callbacks.insert(0, ResumeTrainingCallback(checkpoint_dir, resume=True))
|
||||
|
||||
# Стандартные callback-и
|
||||
callbacks.extend([
|
||||
EarlyStoppingCallback(patience=5),
|
||||
LRSchedulerCallback(lr=learning_rate)
|
||||
]
|
||||
callbacks = default_callbacks + callbacks
|
||||
if checkpoint_dir is not None:
|
||||
callbacks = [ModelCheckpointCallback(checkpoint_dir)] + callbacks
|
||||
])
|
||||
|
||||
if checkpoint_dir:
|
||||
callbacks.append(ModelCheckpointCallback(checkpoint_dir))
|
||||
|
||||
# Инициализация оптимизатора ДО вызова callback'ов
|
||||
self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
|
||||
|
||||
print(f"\nНачало обучения GPT на {num_epoch} эпох")
|
||||
# Определяем стартовую эпоху
|
||||
start_epoch = 0
|
||||
for cb in callbacks:
|
||||
if isinstance(cb, ResumeTrainingCallback) and cb.last_epoch >= 0:
|
||||
start_epoch = cb.last_epoch + 1
|
||||
break
|
||||
|
||||
# Начало обучения (после инициализации оптимизатора)
|
||||
for cb in callbacks:
|
||||
cb.on_train_begin(self)
|
||||
|
||||
print(f"\nНачало обучения GPT на {num_epoch} эпох (с {start_epoch} по {start_epoch + num_epoch - 1})")
|
||||
print(f"Размер батча: {train_loader.batch_size}")
|
||||
print(f"Всего батчей: {len(train_loader)}")
|
||||
print(f"Устройство: {device}\n")
|
||||
|
||||
for epoch in range(num_epoch):
|
||||
for epoch in range(start_epoch, start_epoch + num_epoch):
|
||||
# Вызов callback перед эпохой
|
||||
for cb in callbacks:
|
||||
cb.on_epoch_begin(epoch, self)
|
||||
|
||||
Reference in New Issue
Block a user