mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-24 05:26:02 +00:00
Реализация и документирование метода fit() для обучения GPT
Основные изменения: 1. Реализация метода fit(): - Добавлен полный цикл обучения (forward/backward pass) - Поддержка обучения на CPU/GPU - Расчет и сохранение метрик (train_loss, validation_loss) - Интеграция с оптимизатором Adam 2. Документация: - Подробное описание метода в gpt_documentation_ru.md - Примеры использования в README.md - Параметры и требования к данным 3. Тестирование: - Тесты базовой функциональности - Проверка изменения весов - Тесты для разных устройств (CPU/CUDA) - Обработка edge-cases 4. Примеры: - train_gpt_example.py с полным workflow - Генерация синтетических данных - Сохранение/загрузка моделей
This commit is contained in:
@@ -134,7 +134,44 @@ model = GPT(
|
||||
| Повторы в генерации | Настройка repetition_penalty |
|
||||
| Медленная генерация | Кэширование ключей/значений |
|
||||
|
||||
## 6. Дополнительные материалы
|
||||
## 6. Обучение модели
|
||||
|
||||
Метод `fit()` позволяет обучать модель GPT на ваших данных.
|
||||
|
||||
### 6.1 Параметры обучения
|
||||
```python
|
||||
model.fit(
|
||||
train_loader, # DataLoader с обучающими данными
|
||||
valid_loader=None, # Опциональный DataLoader для валидации
|
||||
num_epoch=10, # Количество эпох
|
||||
learning_rate=0.001 # Скорость обучения
|
||||
)
|
||||
```
|
||||
|
||||
### 6.2 Пример обучения
|
||||
```python
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
|
||||
# Создание синтетических данных
|
||||
inputs = torch.randint(0, 10000, (1000, 32)) # 1000 примеров по 32 токена
|
||||
targets = torch.roll(inputs, -1, dims=1) # Сдвиг на 1 токен вперед
|
||||
|
||||
dataset = TensorDataset(inputs, targets)
|
||||
train_loader = DataLoader(dataset, batch_size=32)
|
||||
|
||||
model = GPT(...) # Инициализация модели
|
||||
model.fit(train_loader, num_epoch=5)
|
||||
```
|
||||
|
||||
### 6.3 Особенности обучения
|
||||
- Автоматическое определение устройства (CPU/GPU)
|
||||
- Сохранение истории потерь в атрибутах:
|
||||
- `model.train_loss` - потери на обучении
|
||||
- `model.validation_loss` - потери на валидации (если есть)
|
||||
- Поддержка ранней остановки (через callback)
|
||||
|
||||
## 7. Дополнительные материалы
|
||||
- [Примеры использования](../example/example_gpt.py)
|
||||
- [Пример обучения](../example/train_gpt_example.py)
|
||||
- [Тесты](../tests/test_gpt.py)
|
||||
- [Оптимизация производительности](performance_guide.md)
|
||||
|
||||
Reference in New Issue
Block a user