mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
Основные изменения: 1. Реализация метода fit(): - Добавлен полный цикл обучения (forward/backward pass) - Поддержка обучения на CPU/GPU - Расчет и сохранение метрик (train_loss, validation_loss) - Интеграция с оптимизатором Adam 2. Документация: - Подробное описание метода в gpt_documentation_ru.md - Примеры использования в README.md - Параметры и требования к данным 3. Тестирование: - Тесты базовой функциональности - Проверка изменения весов - Тесты для разных устройств (CPU/CUDA) - Обработка edge-cases 4. Примеры: - train_gpt_example.py с полным workflow - Генерация синтетических данных - Сохранение/загрузка моделей
3.6 KiB
3.6 KiB
Simple-LLM Framework
Простая и понятная реализация языковой модели GPT-стиля с нуля на PyTorch
🔍 Обзор
Simple-LLM предоставляет:
- Полную реализацию архитектуры GPT
- Эффективный токенизатор BPE
- Модули трансформера (внимание, FFN, эмбеддинги)
- Гибкую систему генерации текста
- Примеры использования и документацию
🚀 Быстрый старт
- Установите зависимости:
pip install torch numpy tqdm
- Запустите примеры:
# Пример генерации текста
python example/example_gpt.py
# Пример обучения модели
python example/train_gpt_example.py
🧠 Основные компоненты
Обработка данных
from simple_llm.data.get_data import GetData
dataset = GetData(
data=[1, 2, 3, 4, 5], # Входная последовательность
seq_len=3, # Длина окна
device="cuda" # Устройство (опционально)
)
Модель GPT
from simple_llm.transformer.gpt import GPT
model = GPT(
vocab_size=10000,
max_seq_len=512,
emb_size=768,
num_heads=12,
num_layers=6
)
Генерация текста
output = model.generate(
input_ids,
max_new_tokens=100,
temperature=0.9,
top_k=50,
top_p=0.9
)
Обучение модели
from torch.utils.data import DataLoader
# Данные должны быть в формате (input_ids, targets)
# targets - это input_ids, сдвинутые на 1 токен вперед
train_loader = DataLoader(...)
model.fit(
train_loader=train_loader, # Обучающие данные (обязательно)
valid_loader=None, # Валидационные данные (опционально)
num_epoch=10, # Количество эпох
learning_rate=0.001 # Скорость обучения
)
# Сохранение модели
model.save("model.pt")
# Загрузка модели
loaded_model = GPT.load("model.pt", device="cuda")
Требования к данным:
- Формат:
(input_ids, targets)гдеtargets = roll(input_ids, -1) input_ids: тензор формы[batch_size, seq_len]- Поддерживаются как синтетические, так и реальные текстовые данные
📚 Документация
Полная документация доступна в doc/:
🛠 Тестирование
pytest tests/
🤝 Как внести вклад
- Форкните репозиторий
- Создайте ветку (
git checkout -b feature/AmazingFeature) - Сделайте коммит (
git commit -m 'Add some AmazingFeature') - Запушьте ветку (
git push origin feature/AmazingFeature) - Откройте Pull Request
📜 Лицензия
Распространяется под лицензией MIT. См. LICENSE