mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-24 13:48:01 +00:00
docs: актуализировано описание автодовосстановления обучения и обновлена логика эпох в GPT
This commit is contained in:
14
README.md
14
README.md
@@ -79,6 +79,20 @@ python bin/train_gpt_model.py \
|
|||||||
--num-layers 2
|
--num-layers 2
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### ✔️ Восстановление обучения с чекпоинта
|
||||||
|
|
||||||
|
Если обучение было прервано или вы хотите дообучить модель:
|
||||||
|
- Просто перезапустите команду обучения с теми же параметрами (`bin/train_gpt_model.py ...`).
|
||||||
|
- Скрипт сам определит последний checkpoint (checkpoint_epoch_X.pt) в папке модели.
|
||||||
|
- Обучение продолжится с нужной эпохи, параметры останутся неизменными.
|
||||||
|
- В консоли вы увидите сообщения вроде:
|
||||||
|
```
|
||||||
|
⚡ Восстанавливаем обучение с эпохи 12
|
||||||
|
Восстановление обучения GPT
|
||||||
|
...
|
||||||
|
Начало обучения GPT на 18 эпох (с 12 по 29)
|
||||||
|
```
|
||||||
|
|
||||||
### 4. Генерация текста
|
### 4. Генерация текста
|
||||||
```bash
|
```bash
|
||||||
python bin/generate_text.py \
|
python bin/generate_text.py \
|
||||||
|
|||||||
@@ -73,11 +73,22 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Обучение
|
# Обучение
|
||||||
|
# Определяем стартовую эпоху
|
||||||
|
start_epoch = 0
|
||||||
|
if os.path.exists(output_dir):
|
||||||
|
checkpoint_files = [f for f in os.listdir(output_dir) if f.startswith('checkpoint_epoch_')]
|
||||||
|
if checkpoint_files:
|
||||||
|
last_epoch = max([int(f.split('_')[2].split('.')[0]) for f in checkpoint_files])
|
||||||
|
start_epoch = last_epoch + 1
|
||||||
|
print(f"⚡ Восстанавливаем обучение с эпохи {start_epoch}")
|
||||||
|
|
||||||
model.fit(
|
model.fit(
|
||||||
train_loader=loader,
|
train_loader=loader,
|
||||||
num_epoch=args.epochs,
|
num_epoch=args.epochs - start_epoch,
|
||||||
learning_rate=args.lr,
|
learning_rate=args.lr,
|
||||||
checkpoint_dir=output_dir
|
checkpoint_dir=output_dir,
|
||||||
|
resume_training=True,
|
||||||
|
start_epoch=start_epoch
|
||||||
)
|
)
|
||||||
torch.save(model.state_dict(), args.output)
|
torch.save(model.state_dict(), args.output)
|
||||||
|
|
||||||
|
|||||||
@@ -68,6 +68,21 @@ dataset = GetData(data=all_tokens, seq_len=seq_len, device='cuda')
|
|||||||
```
|
```
|
||||||
|
|
||||||
## 5. Обучение модели с помощью fit()
|
## 5. Обучение модели с помощью fit()
|
||||||
|
|
||||||
|
### Автоматическое восстановление обучения
|
||||||
|
|
||||||
|
- При запуске скрипта обучения в директории модели автоматически ищется последний checkpoint (`checkpoint_epoch_X.pt`) и обучение продолжается с соответствующей эпохи.
|
||||||
|
- Не нужно вручную указывать эпоху старта — всё происходит автоматически.
|
||||||
|
- Логи отображают, с какой эпохи продолжается обучение и сколько эпох осталось до конца (см. вывод в консоли).
|
||||||
|
|
||||||
|
**Пример из логов:**
|
||||||
|
```
|
||||||
|
⚡ Восстанавливаем обучение с эпохи 12
|
||||||
|
Восстановление обучения GPT
|
||||||
|
...
|
||||||
|
Начало обучения GPT на 18 эпох (с 12 по 29)
|
||||||
|
```
|
||||||
|
- Чтобы воспользоваться функцией автодовосстановления, просто повторно запустите команду обучения с теми же CLI-параметрами!
|
||||||
```python
|
```python
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from simple_llm.transformer.gpt import GPT
|
from simple_llm.transformer.gpt import GPT
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from simple_llm.transformer.decoder import Decoder
|
|||||||
from simple_llm.transformer.callback import (
|
from simple_llm.transformer.callback import (
|
||||||
EarlyStoppingCallback,
|
EarlyStoppingCallback,
|
||||||
LRSchedulerCallback,
|
LRSchedulerCallback,
|
||||||
|
ResumeTrainingCallback,
|
||||||
ModelCheckpointCallback
|
ModelCheckpointCallback
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -294,7 +295,9 @@ class GPT(nn.Module):
|
|||||||
num_epoch: int = 1,
|
num_epoch: int = 1,
|
||||||
learning_rate: float = 0.001,
|
learning_rate: float = 0.001,
|
||||||
callbacks=None,
|
callbacks=None,
|
||||||
checkpoint_dir=None
|
checkpoint_dir=None,
|
||||||
|
resume_training=False,
|
||||||
|
start_epoch=0
|
||||||
):
|
):
|
||||||
"""Обучает модель GPT с поддержкой callback-системы.
|
"""Обучает модель GPT с поддержкой callback-системы.
|
||||||
|
|
||||||
@@ -360,27 +363,44 @@ class GPT(nn.Module):
|
|||||||
device = torch.device(self._device)
|
device = torch.device(self._device)
|
||||||
self.to(device)
|
self.to(device)
|
||||||
|
|
||||||
# Инициализация callbacks по умолчанию
|
# Инициализация callback-ов
|
||||||
if callbacks is None:
|
if callbacks is None:
|
||||||
callbacks = []
|
callbacks = []
|
||||||
|
|
||||||
# Добавляем обязательные callback-и
|
# Добавляем Resume callback если нужно
|
||||||
default_callbacks = [
|
if resume_training and checkpoint_dir:
|
||||||
|
print(f"Восстановление обучения GPT")
|
||||||
|
callbacks.insert(0, ResumeTrainingCallback(checkpoint_dir, resume=True))
|
||||||
|
|
||||||
|
# Стандартные callback-и
|
||||||
|
callbacks.extend([
|
||||||
EarlyStoppingCallback(patience=5),
|
EarlyStoppingCallback(patience=5),
|
||||||
LRSchedulerCallback(lr=learning_rate)
|
LRSchedulerCallback(lr=learning_rate)
|
||||||
]
|
])
|
||||||
callbacks = default_callbacks + callbacks
|
|
||||||
if checkpoint_dir is not None:
|
|
||||||
callbacks = [ModelCheckpointCallback(checkpoint_dir)] + callbacks
|
|
||||||
|
|
||||||
|
if checkpoint_dir:
|
||||||
|
callbacks.append(ModelCheckpointCallback(checkpoint_dir))
|
||||||
|
|
||||||
|
# Инициализация оптимизатора ДО вызова callback'ов
|
||||||
self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
|
self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
|
||||||
|
|
||||||
print(f"\nНачало обучения GPT на {num_epoch} эпох")
|
# Определяем стартовую эпоху
|
||||||
|
start_epoch = 0
|
||||||
|
for cb in callbacks:
|
||||||
|
if isinstance(cb, ResumeTrainingCallback) and cb.last_epoch >= 0:
|
||||||
|
start_epoch = cb.last_epoch + 1
|
||||||
|
break
|
||||||
|
|
||||||
|
# Начало обучения (после инициализации оптимизатора)
|
||||||
|
for cb in callbacks:
|
||||||
|
cb.on_train_begin(self)
|
||||||
|
|
||||||
|
print(f"\nНачало обучения GPT на {num_epoch} эпох (с {start_epoch} по {start_epoch + num_epoch - 1})")
|
||||||
print(f"Размер батча: {train_loader.batch_size}")
|
print(f"Размер батча: {train_loader.batch_size}")
|
||||||
print(f"Всего батчей: {len(train_loader)}")
|
print(f"Всего батчей: {len(train_loader)}")
|
||||||
print(f"Устройство: {device}\n")
|
print(f"Устройство: {device}\n")
|
||||||
|
|
||||||
for epoch in range(num_epoch):
|
for epoch in range(start_epoch, start_epoch + num_epoch):
|
||||||
# Вызов callback перед эпохой
|
# Вызов callback перед эпохой
|
||||||
for cb in callbacks:
|
for cb in callbacks:
|
||||||
cb.on_epoch_begin(epoch, self)
|
cb.on_epoch_begin(epoch, self)
|
||||||
|
|||||||
Reference in New Issue
Block a user