mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-26 14:51:36 +00:00
Реализация и документирование метода fit() для обучения GPT
Основные изменения: 1. Реализация метода fit(): - Добавлен полный цикл обучения (forward/backward pass) - Поддержка обучения на CPU/GPU - Расчет и сохранение метрик (train_loss, validation_loss) - Интеграция с оптимизатором Adam 2. Документация: - Подробное описание метода в gpt_documentation_ru.md - Примеры использования в README.md - Параметры и требования к данным 3. Тестирование: - Тесты базовой функциональности - Проверка изменения весов - Тесты для разных устройств (CPU/CUDA) - Обработка edge-cases 4. Примеры: - train_gpt_example.py с полным workflow - Генерация синтетических данных - Сохранение/загрузка моделей
This commit is contained in:
33
README.md
33
README.md
@@ -21,9 +21,13 @@ Simple-LLM предоставляет:
|
||||
pip install torch numpy tqdm
|
||||
```
|
||||
|
||||
2. Запустите пример генерации:
|
||||
2. Запустите примеры:
|
||||
```bash
|
||||
# Пример генерации текста
|
||||
python example/example_gpt.py
|
||||
|
||||
# Пример обучения модели
|
||||
python example/train_gpt_example.py
|
||||
```
|
||||
|
||||
## 🧠 Основные компоненты
|
||||
@@ -63,6 +67,33 @@ output = model.generate(
|
||||
)
|
||||
```
|
||||
|
||||
### Обучение модели
|
||||
```python
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Данные должны быть в формате (input_ids, targets)
|
||||
# targets - это input_ids, сдвинутые на 1 токен вперед
|
||||
train_loader = DataLoader(...)
|
||||
|
||||
model.fit(
|
||||
train_loader=train_loader, # Обучающие данные (обязательно)
|
||||
valid_loader=None, # Валидационные данные (опционально)
|
||||
num_epoch=10, # Количество эпох
|
||||
learning_rate=0.001 # Скорость обучения
|
||||
)
|
||||
|
||||
# Сохранение модели
|
||||
model.save("model.pt")
|
||||
|
||||
# Загрузка модели
|
||||
loaded_model = GPT.load("model.pt", device="cuda")
|
||||
```
|
||||
|
||||
**Требования к данным:**
|
||||
- Формат: `(input_ids, targets)` где `targets = roll(input_ids, -1)`
|
||||
- `input_ids`: тензор формы `[batch_size, seq_len]`
|
||||
- Поддерживаются как синтетические, так и реальные текстовые данные
|
||||
|
||||
## 📚 Документация
|
||||
|
||||
Полная документация доступна в [doc/](./doc/):
|
||||
|
||||
Reference in New Issue
Block a user