mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
130
README.md
130
README.md
@@ -1,15 +1,16 @@
|
|||||||
# LLM Architecture Research
|
# LLM Architecture Research
|
||||||
|
|
||||||
Исследовательский проект для разработки и обучения архитектур больших языковых моделей (LLM).
|
Исследовательский проект по разработке, обучению и сравнительному анализу современных архитектур больших языковых моделей (LLM): **GPT, GPT-2, LLaMA, Mistral**. Прямая поддержка интеграции с HuggingFace (через модуль `hf-proxy`).
|
||||||
|
|
||||||
|
|
||||||
## 🏗️ Архитектура проекта
|
## 🏗️ Архитектура проекта
|
||||||
|
|
||||||
Проект организован как монорепозиторий с использованием **uv** workspace:
|
Проект организован как монорепозиторий с использованием **uv** workspace:
|
||||||
|
|
||||||
- **`llm`** — основная библиотека с реализацией архитектур LLM (GPT, GPT-2)
|
- **`llm`** — основная библиотека с реализацией архитектур LLM (**GPT, GPT-2, LLaMA, Mistral**)
|
||||||
- **`hf-proxy`** — адаптер для интеграции с HuggingFace
|
- **`hf-proxy`** — экспериментальный адаптер для интеграции с HuggingFace (загрузка, токенизация, экспериментальные скрипты). Функционал может изменяться и не гарантирует полной совместимости с будущими версиями HuggingFace Transformers.
|
||||||
- **`experiments`** — скрипты обучения и экспериментов
|
- **`experiments`** — скрипты обучения и генерации (включая HF и собственные модели)
|
||||||
- **`notebooks`** — исследовательские ноутбуки
|
- **`notebooks`** — исследовательские ноутбуки, анализ архитектур
|
||||||
|
|
||||||
## 📁 Структура проекта
|
## 📁 Структура проекта
|
||||||
|
|
||||||
@@ -41,8 +42,11 @@ llm-arch-research/
|
|||||||
│ │ │ ├── gpt.py
|
│ │ │ ├── gpt.py
|
||||||
│ │ │ ├── gpt2.py
|
│ │ │ ├── gpt2.py
|
||||||
│ │ │ └── __init__.py
|
│ │ │ └── __init__.py
|
||||||
│ │ └── llama/ # LLaMA архитектура
|
│ │ ├── llama/ # LLaMA архитектура
|
||||||
│ │ ├── llama.py
|
│ │ │ ├── llama.py
|
||||||
|
│ │ │ └── __init__.py
|
||||||
|
│ │ └── mistral/ # Mistral архитектура
|
||||||
|
│ │ ├── mistral.py
|
||||||
│ │ └── __init__.py
|
│ │ └── __init__.py
|
||||||
│ ├── training/ # утилиты обучения
|
│ ├── training/ # утилиты обучения
|
||||||
│ │ ├── dataset.py
|
│ │ ├── dataset.py
|
||||||
@@ -81,6 +85,18 @@ llm-arch-research/
|
|||||||
|
|
||||||
## 🚀 Быстрый старт
|
## 🚀 Быстрый старт
|
||||||
|
|
||||||
|
**Пример запуска обучения и генерации для любых архитектур:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python experiments/llm_only/run_llm_experiment.py --model mistral --action generate --config experiments/llm_only/configs/mistral_generate.json
|
||||||
|
```
|
||||||
|
|
||||||
|
**Использование собственных моделей с HuggingFace-интерфейсом:**
|
||||||
|
```python
|
||||||
|
from hf_proxy.hf_adapter import HFAdapter
|
||||||
|
hf_model = HFAdapter("mistralai/Mistral-7B-v0.1")
|
||||||
|
```
|
||||||
|
|
||||||
### Установка зависимостей
|
### Установка зависимостей
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -91,73 +107,15 @@ uv sync
|
|||||||
uv sync --extra dev
|
uv sync --extra dev
|
||||||
```
|
```
|
||||||
|
|
||||||
## ⚡ Работа с экспериментами (experiments/llm_only)
|
## ⚡ Работа с экспериментами (experiments/llm_only, experiments/hf_integration)
|
||||||
|
|
||||||
В папке `experiments/llm_only` вы найдете универсальный скрипт для обучения и генерации LLM без HuggingFace.
|
- В `experiments/llm_only`: универсальный скрипт для обучения и генерации LLM (включая LLaMA и Mistral) без HuggingFace — всё через собственную реализацию.
|
||||||
Архитектура позволяет управлять выбором модели, типом действия и параметрами через аргументы командной строки и отдельные JSON-конфиги.
|
- В `experiments/hf_integration`: скрипты и примеры для генерации, обучения и тестирования моделей с помощью HuggingFace API (через hf-proxy). Позволяет использовать свои модели и токенизаторы как стандартные HF-объекты.
|
||||||
|
|
||||||
### Основные файлы и директории
|
**Для моделей Mistral/Llama доступны оба сценария: прямая работа или через HuggingFace-прокси.**
|
||||||
|
|
||||||
- `run_llm_experiment.py` — универсальный скрипт-стартер для обучения и генерации.
|
*Конфиги и примеры см. в соответствующих папках.*
|
||||||
- `configs/` — примеры конфигурационных файлов (`*.json`) для разных моделей и сценариев.
|
|
||||||
|
|
||||||
### Использование универсального скрипта
|
|
||||||
|
|
||||||
1. **Настройте конфиг**
|
|
||||||
Для каждой модели и режима работы есть отдельный JSON-файл с параметрами:
|
|
||||||
- `configs/gpt_train.json`, `configs/gpt_generate.json`
|
|
||||||
- `configs/gpt2_train.json`, `configs/gpt2_generate.json`
|
|
||||||
- `configs/llama_train.json`, `configs/llama_generate.json`
|
|
||||||
|
|
||||||
2. **Запустите обучение или генерацию**
|
|
||||||
Стандартная команда:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python experiments/llm_only/run_llm_experiment.py --model <название_модели> --action <train/generate> --config experiments/llm_only/configs/<имя_конфига>.json
|
|
||||||
```
|
|
||||||
|
|
||||||
Примеры:
|
|
||||||
|
|
||||||
- Обучить GPT:
|
|
||||||
```bash
|
|
||||||
python experiments/llm_only/run_llm_experiment.py --model gpt --action train --config experiments/llm_only/configs/gpt_train.json
|
|
||||||
```
|
|
||||||
|
|
||||||
- Генерировать текст GPT2:
|
|
||||||
```bash
|
|
||||||
python experiments/llm_only/run_llm_experiment.py --model gpt2 --action generate --config experiments/llm_only/configs/gpt2_generate.json
|
|
||||||
```
|
|
||||||
|
|
||||||
- Обучить Llama:
|
|
||||||
```bash
|
|
||||||
python experiments/llm_only/run_llm_experiment.py --model llama --action train --config experiments/llm_only/configs/llama_train.json
|
|
||||||
```
|
|
||||||
|
|
||||||
- Генерировать текст Llama:
|
|
||||||
```bash
|
|
||||||
python experiments/llm_only/run_llm_experiment.py --model llama --action generate --config experiments/llm_only/configs/llama_generate.json
|
|
||||||
```
|
|
||||||
|
|
||||||
3. **Конфигурирование параметров**
|
|
||||||
- Все гиперпараметры (архитектура, обучение, генерация, пути) задаются в json-файле.
|
|
||||||
- Для новых моделей создайте копию существующего конфига, укажите другие веса и параметры, и используйте нужное название модели в команде.
|
|
||||||
|
|
||||||
4. **Структура конфига**
|
|
||||||
Минимальный пример конфига для обучения:
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
|
|
||||||
"model_config": {
|
|
||||||
"vocab_size": null,
|
|
||||||
"embed_dim": 256,
|
|
||||||
"num_heads": 4,
|
|
||||||
"num_layers": 4,
|
|
||||||
"max_position_embeddings": 128,
|
|
||||||
"dropout": 0.1
|
|
||||||
},
|
|
||||||
"model_weights": "checkpoints/gpt-bpe/model.pt"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -272,33 +230,23 @@ dependencies = [
|
|||||||
|
|
||||||
## 🎯 Реализованные возможности
|
## 🎯 Реализованные возможности
|
||||||
|
|
||||||
### Архитектуры GPT и GPT-2
|
### Архитектуры
|
||||||
- ✅ Токенные и позиционные эмбеддинги
|
- ✅ GPT, GPT-2: Полностью воспроизводимые реализации, токенные и позиционные эмбеддинги, causal multi-head attention, LayerNorm
|
||||||
- ✅ Многоголовое внимание с causal mask
|
- ✅ LLaMA: Rotary Positional Embeddings (RoPE), RMSNorm, SwiGLU, оптимизированная память
|
||||||
- ✅ Декодерные блоки с residual connections
|
- ✅ Mistral: Sliding Window Attention (оконное внимание), Grouped Query Attention (GQA), совместимость с HF
|
||||||
- ✅ Layer normalization
|
- ✅ Все архитектуры поддерживают обучение и генерацию текста
|
||||||
- ✅ Dropout регуляризация
|
|
||||||
- ✅ Отдельные реализации GPT и GPT-2 (различия в масштабе и деталях архитектуры)
|
|
||||||
|
|
||||||
### Генерация текста
|
### Генерация текста
|
||||||
- ✅ Жадный поиск (greedy decoding)
|
- ✅ Greedy, sampling (Top-k, Top-p), контроль температуры, efficient caching
|
||||||
- ✅ Вероятностное сэмплирование
|
|
||||||
- ✅ Top-k сэмплирование
|
|
||||||
- ✅ Nucleus sampling (top-p)
|
|
||||||
- ✅ Контроль температуры
|
|
||||||
|
|
||||||
### Обучение
|
### Обучение
|
||||||
- ✅ Датасет для языкового моделирования
|
- ✅ Языковое моделирование с кастомными и HF-токенизаторами
|
||||||
- ✅ Базовый тренировочный цикл
|
- ✅ AdamW, кастомные датасеты, сохранение чекпоинтов
|
||||||
- ✅ Оптимизатор AdamW
|
|
||||||
- ✅ Сохранение чекпоинтов
|
|
||||||
|
|
||||||
### Интеграция с HuggingFace (hf-proxy)
|
### Интеграция с HuggingFace (hf-proxy)
|
||||||
- ✅ Адаптер моделей для совместимости с HF интерфейсами
|
- ✅ Экспорт/импорт моделей и токенизаторов в HF совместимый формат
|
||||||
- ✅ Адаптер токенизаторов с поддержкой всех методов HF
|
- ✅ Генерация и обучение через HF Trainer, pipelines и т.д.
|
||||||
- ✅ Сохранение и загрузка в HF формате
|
- ✅ Двусторонняя поддержка: собственные модели становятся HF-совместимыми и наоборот
|
||||||
- ✅ Совместимость с HF Trainer и pipelines
|
|
||||||
- ✅ Генерация через стандартные HF интерфейсы
|
|
||||||
|
|
||||||
## 🔬 Эксперименты с hf-proxy
|
## 🔬 Эксперименты с hf-proxy
|
||||||
|
|
||||||
|
|||||||
19
experiments/llm_only/configs/mistral_generate.json
Normal file
19
experiments/llm_only/configs/mistral_generate.json
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
{
|
||||||
|
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
|
||||||
|
"test_prompts": [
|
||||||
|
"Open weights",
|
||||||
|
"The Llama model is",
|
||||||
|
"Efficient transformers"
|
||||||
|
],
|
||||||
|
"model_config_path": "checkpoints/mistral-bpe/config.json",
|
||||||
|
"model_weights": "checkpoints/mistral-bpe/model.pt",
|
||||||
|
"generation": {
|
||||||
|
"max_new_tokens": 40,
|
||||||
|
"temperature": 0.8,
|
||||||
|
"do_sample": true,
|
||||||
|
"top_k": null,
|
||||||
|
"top_p": null
|
||||||
|
},
|
||||||
|
"log_path": "checkpoints/mistral_only_generation_logs.json"
|
||||||
|
}
|
||||||
|
|
||||||
26
experiments/llm_only/configs/mistral_train.json
Normal file
26
experiments/llm_only/configs/mistral_train.json
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
{
|
||||||
|
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
|
||||||
|
"bpe_vocab_size": 1000,
|
||||||
|
"bpe_special_tokens": ["<pad>", "<unk>", "<bos>", "<eos>"],
|
||||||
|
"test_prompts": ["Open source AI", "What is Llama?"],
|
||||||
|
"model_config": {
|
||||||
|
"vocab_size": null,
|
||||||
|
"embed_dim": 256,
|
||||||
|
"num_q_heads": 4,
|
||||||
|
"num_kv_heads": 2,
|
||||||
|
"head_size": 64,
|
||||||
|
"num_layers": 4,
|
||||||
|
"max_position_embeddings": 512,
|
||||||
|
"window_size": 16,
|
||||||
|
"dropout": 0.1
|
||||||
|
},
|
||||||
|
"model_weights": "checkpoints/mistral-bpe/model.pt",
|
||||||
|
"model_config_path": "checkpoints/mistral-bpe/config.json",
|
||||||
|
"training": {
|
||||||
|
"learning_rate": 0.0003,
|
||||||
|
"batch_size": 2,
|
||||||
|
"num_epochs": 3,
|
||||||
|
"warmup_steps": 50
|
||||||
|
},
|
||||||
|
"log_path": "checkpoints/mistral_only_training_logs.json"
|
||||||
|
}
|
||||||
@@ -16,7 +16,7 @@ import torch
|
|||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
from llm.tokenizers import BPETokenizer
|
from llm.tokenizers import BPETokenizer
|
||||||
from llm.training.dataset import TextDataset
|
from llm.datasets.text_dataset import TextDataset
|
||||||
from llm.training.trainer import Trainer
|
from llm.training.trainer import Trainer
|
||||||
|
|
||||||
from shared.data import (
|
from shared.data import (
|
||||||
@@ -42,6 +42,9 @@ def load_model_class(model_name):
|
|||||||
elif model_name.lower() == 'llama':
|
elif model_name.lower() == 'llama':
|
||||||
from llm.models.llama import Llama
|
from llm.models.llama import Llama
|
||||||
return Llama
|
return Llama
|
||||||
|
elif model_name.lower() == 'mistral':
|
||||||
|
from llm.models.mistral import Mistral
|
||||||
|
return Mistral
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Модель '{model_name}' не поддерживается.")
|
raise ValueError(f"Модель '{model_name}' не поддерживается.")
|
||||||
|
|
||||||
|
|||||||
@@ -27,14 +27,19 @@ llm/
|
|||||||
│ │ ├── gpt.py # Базовая GPT
|
│ │ ├── gpt.py # Базовая GPT
|
||||||
│ │ ├── gpt2.py # GPT-2 реализация
|
│ │ ├── gpt2.py # GPT-2 реализация
|
||||||
│ │ └── __init__.py
|
│ │ └── __init__.py
|
||||||
│ └── llama/ # LLaMA архитектура
|
│ ├── llama/ # LLaMA архитектура
|
||||||
│ ├── llama.py # LLaMA реализация
|
│ │ ├── llama.py # LLaMA реализация
|
||||||
|
│ │ └── __init__.py
|
||||||
|
│ └── mistral/ # Mistral архитектура
|
||||||
|
│ ├── mistral.py # Mistral реализация
|
||||||
│ └── __init__.py
|
│ └── __init__.py
|
||||||
├── tokenizers/ # Токенизаторы
|
├── tokenizers/ # Токенизаторы
|
||||||
│ ├── base_tokenizer.py # Базовый интерфейс
|
│ ├── base_tokenizer.py # Базовый интерфейс
|
||||||
│ └── bpe_tokenizer.py # BPE токенизатор
|
│ └── bpe_tokenizer.py # BPE токенизатор
|
||||||
|
├── datasets/ # Работа с датасетами
|
||||||
|
│ ├── text_dataset.py # Стандартный датасет
|
||||||
|
│ └── streaming_text_dataset.py # Стриминговый датасет
|
||||||
└── training/ # Утилиты обучения
|
└── training/ # Утилиты обучения
|
||||||
├── dataset.py # Датасеты
|
|
||||||
├── trainer.py # Тренировочный цикл
|
├── trainer.py # Тренировочный цикл
|
||||||
├── optimizer.py # Оптимизаторы
|
├── optimizer.py # Оптимизаторы
|
||||||
└── scheduler.py # Планировщики обучения
|
└── scheduler.py # Планировщики обучения
|
||||||
@@ -176,12 +181,11 @@ generated = model.generate(input_ids, max_length=100)
|
|||||||
- ✅ Базовая архитектура трансформер-декодера
|
- ✅ Базовая архитектура трансформер-декодера
|
||||||
|
|
||||||
### GPT-2 Особенности
|
### GPT-2 Особенности
|
||||||
- ✅ Улучшенная версия оригинальной GPT
|
|
||||||
- ✅ Layer Normalization (перед вниманием и FFN)
|
- ✅ Layer Normalization (перед вниманием и FFN)
|
||||||
- ✅ GELU активация
|
- ✅ GELU активация
|
||||||
- ✅ Learned positional embeddings
|
- ✅ Learned positional embeddings
|
||||||
- ✅ Кэширование для эффективной генерации
|
- ✅ Кэширование KV для быстрой генерации
|
||||||
- ✅ Оптимизированные веса инициализации
|
- ✅ Улучшенная инициализация слоёв
|
||||||
|
|
||||||
### LLaMA Особенности
|
### LLaMA Особенности
|
||||||
- ✅ Rotary Positional Embeddings (RoPE)
|
- ✅ Rotary Positional Embeddings (RoPE)
|
||||||
@@ -190,6 +194,21 @@ generated = model.generate(input_ids, max_length=100)
|
|||||||
- ✅ Оптимизированная структура декодера
|
- ✅ Оптимизированная структура декодера
|
||||||
- ✅ Эффективное кэширование KV-памяти
|
- ✅ Эффективное кэширование KV-памяти
|
||||||
|
|
||||||
|
### Mistral Особенности
|
||||||
|
- ✅ Sliding Window Attention (оконное внимание)
|
||||||
|
- ✅ Grouped Query Attention (GQA)
|
||||||
|
- ✅ RoPE
|
||||||
|
- ✅ RMSNorm
|
||||||
|
- ✅ Разделённая архитектура на блоки с эффективным управлением памятью
|
||||||
|
- ✅ Совместимость с HuggingFace через hf-proxy
|
||||||
|
|
||||||
|
## 🤝 Интеграция с HuggingFace и BPE
|
||||||
|
|
||||||
|
- Встроенная поддержка собственных BPE токенизаторов и экспериментальная поддержка токенизаторов через HuggingFace (см. hf-proxy).
|
||||||
|
- hf-proxy — экспериментальный модуль! Совместимость с будущими версиями Transformers не гарантируется; API может меняться.
|
||||||
|
- Допускается загрузка/конвертация моделей в формат HF для использования экосистемы Transformers.
|
||||||
|
- Для запуска моделей с токенизаторами HF используйте `hf-proxy` и соответствующие эксперименты из `experiments/hf_integration/`.
|
||||||
|
|
||||||
## 🧪 Тестирование
|
## 🧪 Тестирование
|
||||||
|
|
||||||
Запуск всех тестов:
|
Запуск всех тестов:
|
||||||
@@ -198,7 +217,7 @@ cd llm
|
|||||||
python -m pytest tests/ -v
|
python -m pytest tests/ -v
|
||||||
```
|
```
|
||||||
|
|
||||||
**Статус тестов:** ✅ 101 тест пройден
|
**Статус тестов:** ✅ 101+ тест, охвачены все основные компоненты (ядро, ядро-токенизация, архитектуры, обучение)
|
||||||
|
|
||||||
## 📚 Научные концепции
|
## 📚 Научные концепции
|
||||||
|
|
||||||
|
|||||||
@@ -9,33 +9,46 @@ from .rope import RoPE
|
|||||||
|
|
||||||
class CachedDecoder(nn.Module):
|
class CachedDecoder(nn.Module):
|
||||||
"""
|
"""
|
||||||
Универсальный декодерный блок для современных LLM (GPT, LLaMA, др.), поддерживает кэширование key-value для эффективной генерации.
|
CachedDecoder — Transformer-декодер с key/value-кэшированием (реализация накладывающегося masked multi-head attention).
|
||||||
|
|
||||||
Научная идея:
|
Назначение:
|
||||||
Автопагрессивная авторегрессия в трансформерах требует быстрого доступа к ранее вычисленным self-attention ключам/значениям — этот класс позволяет прозрачно кэшировать такие состояния для быстрой инференс-генерации.
|
-----------
|
||||||
|
Позволяет быстро и эффективно реализовывать autoregressive генерацию текста в стиле GPT-2/3/4:
|
||||||
|
- На шаге генерации используются только нужные токены, “прошлые” key/value значения не пересчитываются, а подаются из кэша.
|
||||||
|
- Позволяет значительно ускорять inferece (особенно на длинных последовательностях).
|
||||||
|
- Вдохновлено реализациями в HuggingFace transformers, GPT-2/3 и других LLM.
|
||||||
|
|
||||||
Алгоритм:
|
Архитектурные особенности:
|
||||||
- Input -> LayerNorm -> Многоголовое внимание с кэшем (может быть RoPE)
|
--------------------------
|
||||||
- Суммируем residual
|
- Использует классическую multi-head attention (с causal mask — запрещает видеть “будущее”).
|
||||||
- LayerNorm -> FeedForward (любой, например SwiGLU) -> Residual
|
- Предусматривает передачу и накопление KV-cache для каждого слоя (hidden state attention).
|
||||||
- Возвращается кортеж (output, kvcache)
|
- Поддерживает передачу внимания через стек attention-блоков.
|
||||||
|
- Применяется layernorm и feed-forward block (GELU).
|
||||||
|
|
||||||
Args:
|
Параметры конструктора:
|
||||||
feed_forward_layer (nn.Module): FeedForward или SwiGLU слой
|
-----------------------
|
||||||
num_heads (int): Количество голов внимания
|
num_heads : int — число attention heads
|
||||||
emb_size (int): Размерность эмбеддингов
|
emb_size : int — embedding размерность
|
||||||
head_size (int): Размерность головы внимания
|
head_size : int — размер каждой attention head (обычно emb_size // num_heads)
|
||||||
max_seq_len (int): Максимальная длина
|
feed_forward_layer : nn.Module — feedforward блок (mLP), может быть любым PyTorch-слоем
|
||||||
norm_layer (тип nn.Module): Normalization слой (LayerNorm или RMSNorm)
|
max_seq_len : int — максимально допустимая длина последовательности
|
||||||
dropout (float): Dropout
|
dropout : float — dropout на attention/ffn
|
||||||
rope (RoPE|None): Экземпляр RoPE (для LLaMA)
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> from llm.core.feed_forward import FeedForward
|
||||||
|
>>> ff_block = FeedForward(emb_size=256, dropout=0.1, activation=\"gelu\")
|
||||||
|
>>> decoder = CachedDecoder(num_heads=4, emb_size=256, head_size=64, feed_forward_layer=ff_block, max_seq_len=2048, dropout=0.1)
|
||||||
|
>>> x = torch.randn(2, 100, 256)
|
||||||
|
>>> y, kv_cache = decoder(x, use_cache=True, cache=None)
|
||||||
|
>>> print(y.shape) # torch.Size([2, 100, 256])
|
||||||
|
|
||||||
|
Подробнее:
|
||||||
|
----------
|
||||||
|
- GPT-2: https://cdn.openai.com/better-language-models/language-models.pdf
|
||||||
|
- HuggingFace cache mechanics: https://huggingface.co/docs/transformers/main/en/model_doc/gpt2
|
||||||
|
- Объяснения autoregressive cache: https://jalammar.github.io/illustrated-gpt2/
|
||||||
|
|
||||||
Пример (GPT2 style):
|
|
||||||
>>> decoder = CachedDecoder(
|
|
||||||
... feed_forward_layer=FeedForward(...),
|
|
||||||
... norm_layer=nn.LayerNorm,
|
|
||||||
... num_heads=4, emb_size=256, head_size=64, max_seq_len=128)
|
|
||||||
>>> out, cache = decoder(x, use_cache=True)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -50,20 +63,22 @@ class CachedDecoder(nn.Module):
|
|||||||
rope: RoPE = None,
|
rope: RoPE = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Инициализация декодера с кэшированием.
|
Конструктор CachedDecoder.
|
||||||
|
|
||||||
Поведение аналогично блоку TransformerDecoderLayer,
|
Аргументы:
|
||||||
но с гибкой возможностью подмены любых подкомпонент (активация, norm, позиции).
|
----------
|
||||||
|
num_heads : int
|
||||||
Args:
|
Сколько attention heads используется в каждом attention слое.
|
||||||
feed_forward_layer: Слой feed-forward (должен быть экземпляром, а не классом)
|
emb_size : int
|
||||||
num_heads: Количество голов внимания
|
Размерность входного вектора x.
|
||||||
emb_size: Размерность эмбеддингов
|
head_size : int
|
||||||
head_size: Размерность каждой головы
|
Размерность каждой attention head; emb_size = num_heads * head_size должно быть True!
|
||||||
max_seq_len: Максимальная длина последовательности
|
feed_forward_layer : nn.Module
|
||||||
norm_layer: Класс нормализации (по умолчанию LayerNorm)
|
Feed-forward слой (например, обычный двухслойный MLP), который применяется после нормы и внимания, и после второй нормы.
|
||||||
dropout: Вероятность dropout
|
max_seq_len : int
|
||||||
rope: Rotary Positional Embeddings (опционально)
|
Максимальная поддерживаемая длина последовательности (выделяет буфер для causal-маски).
|
||||||
|
dropout : float, default=0.1
|
||||||
|
Dropout после внимания и/или feedforward.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._heads = MultiHeadAttention(
|
self._heads = MultiHeadAttention(
|
||||||
@@ -86,19 +101,30 @@ class CachedDecoder(nn.Module):
|
|||||||
cache: list = None,
|
cache: list = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Прямой проход с поддержкой кэша.
|
Прямой проход через Decoder Block с поддержкой KV-кэша.
|
||||||
|
|
||||||
|
В этом методе применяется:
|
||||||
|
- Causal multi-head attention (masked, не смотрит вперёд)
|
||||||
|
- Быстрая обработка длинных последовательностей за счёт сохранения и передачи KV-кэша
|
||||||
|
- LayerNorm перед каждым блоком
|
||||||
|
- Feed-forward блок и вторая LayerNorm
|
||||||
|
- Dropout
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
x : torch.Tensor
|
||||||
|
Вход [batch, seq_len, emb_size]
|
||||||
|
use_cache : bool, по умолчанию True
|
||||||
|
Включать ли накопление и возврат KV-кэша для autoregressive inferece.
|
||||||
|
cache : list, опционально
|
||||||
|
Список предыдущего KV-кеша для attention.
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
-----------
|
||||||
|
x_ff_out : torch.Tensor
|
||||||
|
Результат после attention, модуля и их рез. связей (shape == x)
|
||||||
|
new_cache : new KV-cache (или None)
|
||||||
|
|
||||||
Args:
|
|
||||||
x (Tensor[float]): [batch, seq_len, emb_size] — скрытые состояния
|
|
||||||
mask (Optional[Tensor]): маска внимания (или causal mask), shape [seq_len, seq_len]
|
|
||||||
use_cache (bool): использовать кэширование KV
|
|
||||||
cache (list): кэш self-attention для быстрого авторегрессива
|
|
||||||
Returns:
|
|
||||||
output (Tensor[float]): выходные состояния [batch, seq_len, emb_size]
|
|
||||||
kv_caches (list): обновленный кэш, если use_cache
|
|
||||||
Пример:
|
|
||||||
>>> out, new_cache = decoder(x, use_cache=True, cache=old_cache)
|
|
||||||
>>> out.shape # [batch, seq_len, emb_size]
|
|
||||||
"""
|
"""
|
||||||
norm1_out = self._norm1(x)
|
norm1_out = self._norm1(x)
|
||||||
# Передаём все cache/use_cache дальше в attention
|
# Передаём все cache/use_cache дальше в attention
|
||||||
|
|||||||
@@ -6,37 +6,45 @@ from .multi_head_attention import MultiHeadAttention
|
|||||||
|
|
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
"""
|
"""
|
||||||
Базовый автогерессивный блок-декодер трансформера (без кэша KV).
|
Decoder — базовый transformer decoder block (pre-LN), классический строительный блок современных языковых моделей.
|
||||||
|
|
||||||
Предназначен для:
|
Назначение:
|
||||||
- Обработки последовательностей с учетом контекста (самовнимание)
|
-----------
|
||||||
- Постепенного генерирования выходной последовательности
|
- Инкапсулирует архитектуру: norm → multi-head self-attention → residual → norm → feed-forward → residual
|
||||||
- Учета масок для предотвращения "заглядывания в будущее"
|
- Подходит как для LLM/GPT, так и для любых autoregressive sequence моделей.
|
||||||
|
- Использует masked self-attention: каждый токен видит только предыдущие (никакого \"заглядывания в будущее\").
|
||||||
|
- Стабильность обеспечивается через residual connections и LayerNorm после каждого sub-layer.
|
||||||
|
|
||||||
Алгоритм работы:
|
Почему это важно?
|
||||||
1. Входной тензор (batch_size, seq_len, emb_size)
|
-----------------
|
||||||
2. Многоголовое внимание с residual connection и LayerNorm
|
- Все современные языковые модели состоят из подобных блоков, соединённых в стек.
|
||||||
3. FeedForward сеть с residual connection и LayerNorm
|
- Алгоритм residual+norm позволяет проще обучать очень глубокие сети.
|
||||||
4. Выходной тензор (batch_size, seq_len, emb_size)
|
- Разделение на attention+FFN дает и локальные, и глобальные взаимодействия между токенами.
|
||||||
|
|
||||||
Основные характеристики:
|
Формула работы (псевдокод):
|
||||||
- Поддержка масок внимания
|
---------------------------
|
||||||
- Residual connections для стабилизации градиентов
|
y1 = norm1(x)
|
||||||
- Layer Normalization после каждого sub-layer
|
attn_out = Attention(y1)
|
||||||
- Конфигурируемые параметры внимания
|
x2 = x + attn_out # residual
|
||||||
|
y2 = norm2(x2)
|
||||||
|
ffn_out = FFN(y2)
|
||||||
|
out = x2 + ffn_out # residual
|
||||||
|
|
||||||
|
Архитектурные особенности:
|
||||||
|
--------------------------
|
||||||
|
- Поддержка внимания с маской (causal mask или произвольная attention mask)
|
||||||
|
- Residual connections для каждого блока (attention, FFN)
|
||||||
|
- Pre-LN (norm перед каждым подблоком)
|
||||||
|
- Зависит от переданных блоков self_attention и feed_forward, а не их реализации
|
||||||
|
|
||||||
|
References:
|
||||||
|
-----------
|
||||||
|
- Vaswani et al., \"Attention is All You Need\" (2017): https://arxiv.org/abs/1706.03762
|
||||||
|
- Illustrated Transformer: https://jalammar.github.io/illustrated-transformer/
|
||||||
|
- Transformer Circuits (дружественное описание): https://transformer-circuits.pub/2021/framework/index.html
|
||||||
|
|
||||||
Научная суть:
|
|
||||||
- Осуществляет посимвольное предсказание: каждый токен видит только предыдущие (masked attention)
|
|
||||||
- Состоит из self-attention + feedforward + residual + нормализация
|
|
||||||
- Residual connection и normalization дают стабильность и градиентный “flow” при обучении
|
|
||||||
- Механизм предложен в Vaswani et al., "Attention is All You Need", 2017
|
|
||||||
Args:
|
|
||||||
num_heads (int): количество attention-голов
|
|
||||||
emb_size (int): размер эмбеддинга
|
|
||||||
head_size (int): размер одной attention-головы
|
|
||||||
max_seq_len (int): максимальная длина последовательности
|
|
||||||
dropout (float): вероятность dropout
|
|
||||||
Пример:
|
Пример:
|
||||||
|
-------
|
||||||
>>> decoder = Decoder(num_heads=8, emb_size=512, head_size=64, max_seq_len=1024)
|
>>> decoder = Decoder(num_heads=8, emb_size=512, head_size=64, max_seq_len=1024)
|
||||||
>>> x = torch.randn(1, 10, 512)
|
>>> x = torch.randn(1, 10, 512)
|
||||||
>>> out = decoder(x)
|
>>> out = decoder(x)
|
||||||
@@ -52,14 +60,27 @@ class Decoder(nn.Module):
|
|||||||
dropout: float = 0.1,
|
dropout: float = 0.1,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Инициализация декодера.
|
Инициализация стандартного decoder-блока для Transformer.
|
||||||
|
|
||||||
Параметры:
|
Аргументы:
|
||||||
num_heads: int - количество голов внимания
|
----------
|
||||||
emb_size: int - размерность эмбеддингов
|
num_heads: int
|
||||||
head_size: int - размерность каждой головы внимания
|
Количество attention голов (как делить emb_size на heads)
|
||||||
max_seq_len: int - максимальная длина последовательности
|
emb_size: int
|
||||||
dropout: float (default=0.1) - вероятность dropout
|
Размерность эмбеддингов (и входа и выхода)
|
||||||
|
head_size: int
|
||||||
|
Размерность одной attention-головы (emb_size = num_heads * head_size)
|
||||||
|
max_seq_len: int
|
||||||
|
Максимальная длина последовательности (важно для mask)
|
||||||
|
dropout: float, default=0.1
|
||||||
|
Dropout после внимания и FFN
|
||||||
|
|
||||||
|
Внутри:
|
||||||
|
-------
|
||||||
|
- Создаёт слой MultiHeadAttention (masked/casual)
|
||||||
|
- Создаёт двухслойный FeedForward (SwiGLU или GELU)
|
||||||
|
- Применяет 2 слоя LayerNorm для стабилизации градиентов
|
||||||
|
- Все блоки реализованы как PyTorch-модули
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._heads = MultiHeadAttention(
|
self._heads = MultiHeadAttention(
|
||||||
@@ -75,20 +96,26 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
|
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Прямой проход через декодер.
|
Один прямой проход через Transformer decoder block.
|
||||||
|
|
||||||
Вход:
|
Аргументы:
|
||||||
x: torch.Tensor - входной тензор [batch_size, seq_len, emb_size]
|
----------
|
||||||
mask: torch.Tensor (optional) - маска внимания [seq_len, seq_len]
|
x : torch.Tensor
|
||||||
|
Входной тензор [batch_size, seq_len, emb_size]
|
||||||
|
mask : torch.Tensor, optional
|
||||||
|
Attention/causal mask (по умолчанию None, тогда будет casual mask по длине seq_len)
|
||||||
|
|
||||||
Возвращает:
|
Возвращает:
|
||||||
torch.Tensor - выходной тензор [batch_size, seq_len, emb_size]
|
-----------
|
||||||
|
out : torch.Tensor
|
||||||
|
Выходной тензор той же формы, что и x
|
||||||
|
|
||||||
Алгоритм forward:
|
Алгоритм:
|
||||||
1. Применяем MultiHeadAttention к входу
|
---------
|
||||||
2. Добавляем residual connection и LayerNorm
|
- Применяем attention к нормализованному входу (layernorm)
|
||||||
3. Применяем FeedForward сеть
|
- Добавляем residual-связь (attention + исходный вход)
|
||||||
4. Добавляем residual connection и LayerNorm
|
- Применяем FFN к нормализованному результату (layernorm)
|
||||||
|
- Добавляем residual-связь (ffn + предыдущий выход)
|
||||||
"""
|
"""
|
||||||
# Self-Attention блок
|
# Self-Attention блок
|
||||||
attention, _ = self._heads(x, mask, use_cache=False, cache=None)
|
attention, _ = self._heads(x, mask, use_cache=False, cache=None)
|
||||||
|
|||||||
@@ -6,53 +6,71 @@ from .gelu import GELU
|
|||||||
|
|
||||||
class FeedForward(nn.Module):
|
class FeedForward(nn.Module):
|
||||||
"""
|
"""
|
||||||
Классический слой прямого распространения (FeedForward, или FFN) для архитектуры Transformer.
|
FeedForward — классический позиционно-независимый блок для Transformer, применяется к каждому токену отдельно.
|
||||||
|
|
||||||
Этот слой состоит из двух линейных преобразований с расширением внутренней размерности
|
Назначение и роль:
|
||||||
в 4 раза и механизмом dropout для регуляризации. Между линейными слоями применяется
|
------------------
|
||||||
активация ReLU.
|
- Реализует двухслойную (или более сложную) нейронную сеть, которая обрабатывает каждый токен ПОРЯДОЧНО независимо (по последней измерении).
|
||||||
|
- Дает модели "нелинейную мощность": любой токен может быть переосмыслен вне глобального контекста.
|
||||||
|
- После слоя внимания (MHA) FFN помогает связать смысл локальных (внутри токена) “скрытых” значений.
|
||||||
|
|
||||||
Научная суть:
|
Архитектурные детали:
|
||||||
- После внимания каждому токену применяется одинаковая двухслойная нейросеть.
|
---------------------
|
||||||
- Дает глубокую нелинейность; позволяет модели не только сопоставлять, но и моделировать сложные связи между токенами.
|
- Обычно используется блок: (Linear → Activation → Dropout → Linear → Dropout)
|
||||||
- Изначально предложен в «Attention is All You Need» (Vaswani et al., 2017).
|
- В современных LLM обычно в 4 раза расширяют скрытый слой (inner_dim = 4 * emb_size).
|
||||||
|
- Активация часто GELU или SiLU (Swish), иногда SwiGLU, ReGLU, GeGLU (см. PaLM, Llama).
|
||||||
|
|
||||||
Формула:
|
Формула (обычная версия):
|
||||||
FFN(x) = Dropout(W2·act(W1·x))
|
-------------------------
|
||||||
где act — ReLU, GELU и др., обычно expansion x4.
|
FFN(x) = Linear2(Dropout(Activation(Linear1(x))))
|
||||||
|
где Linear1: [emb_size → 4*emb_size], Activation: GELU/SiLU, Linear2: [4*emb_size → emb_size]
|
||||||
|
|
||||||
Алгоритм работы:
|
Параметры конструктора:
|
||||||
1. Входной тензор x (размерность: [batch_size, seq_len, emb_size])
|
-----------------------
|
||||||
2. Линейное преобразование: emb_size -> 4*emb_size
|
emb_size: int — размерность входа/выхода токена
|
||||||
3. Активация ReLU
|
inner_dim: int (необязательно) — размер скрытого слоя (по умолчанию 4*emb_size)
|
||||||
4. Линейное преобразование: 4*emb_size -> emb_size
|
activation: str — тип активации ('gelu', 'silu', 'relu', ...), см. варианты ниже
|
||||||
5. Применение dropout
|
dropout: float — dropout после каждой линейной проекции
|
||||||
6. Возврат результата (размерность: [batch_size, seq_len, emb_size])
|
|
||||||
|
|
||||||
Предназначение:
|
Пример использования:
|
||||||
- Добавляет нелинейность в архитектуру трансформера
|
---------------------
|
||||||
- Обеспечивает взаимодействие между различными размерностями эмбеддингов
|
>>> ffn = FeedForward(emb_size=256, dropout=0.1, activation='gelu')
|
||||||
- Работает независимо для каждого токена в последовательности
|
>>> x = torch.randn(2, 32, 256) # [batch, seq_len, emb_size]
|
||||||
|
>>> y = ffn(x)
|
||||||
|
>>> print(y.shape) # torch.Size([2, 32, 256])
|
||||||
|
|
||||||
Args:
|
Пояснения:
|
||||||
emb_size (int): размерность входных эмбеддингов
|
----------
|
||||||
dropout (float): вероятность(dropout)
|
- FeedForward не использует позицию токена — это МLP, применяемый к каждому токену независимо.
|
||||||
activation (str): нелинейная функция (relu, gelu, gelu_exact)
|
- Длина последовательности и размер батча не имеют значения (broadcast/reshape по [-2, -1]).
|
||||||
|
- Используется во всех декодерах/энкодерах трансформеров.
|
||||||
|
|
||||||
|
Подробнее смотри:
|
||||||
|
-----------------
|
||||||
|
- Vaswani et al., "Attention is All You Need": https://arxiv.org/abs/1706.03762
|
||||||
|
- GELU: https://arxiv.org/abs/1606.08415
|
||||||
|
- SwiGLU (PaLM, Llama): https://arxiv.org/abs/2002.05202
|
||||||
|
|
||||||
Пример:
|
|
||||||
>>> ff = FeedForward(emb_size=512, dropout=0.1)
|
|
||||||
>>> x = torch.randn(32, 10, 512)
|
|
||||||
>>> output = ff(x)
|
|
||||||
>>> print(output.shape) # torch.Size([32, 10, 512])
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, emb_size: int, dropout: float = 0.1, activation: str = "relu"):
|
def __init__(self, emb_size: int, dropout: float = 0.1, activation: str = "relu"):
|
||||||
"""
|
"""
|
||||||
Инициализация слоя Feed Forward Network.
|
Инициализация FeedForward блока для трансформера.
|
||||||
|
|
||||||
Args:
|
Аргументы:
|
||||||
emb_size: Размерность входных эмбеддингов
|
----------
|
||||||
dropout: Вероятность dropout для регуляризации (по умолчанию: 0.1)
|
emb_size: int
|
||||||
|
Размерность входного и выходного эмбеддинга модели.
|
||||||
|
dropout: float, по умолчанию 0.1
|
||||||
|
Dropout после линии и/или активации (уменьшает переобучение).
|
||||||
|
activation: str, по умолчанию 'gelu'
|
||||||
|
Какая нелинейность использовать ('gelu', 'silu', 'relu' и т.д.).
|
||||||
|
inner_dim: int, опционально
|
||||||
|
Размер скрытого слоя (по умолчанию 4 * emb_size, как в оригинальном Transformer).
|
||||||
|
|
||||||
|
Внутри:
|
||||||
|
-------
|
||||||
|
- Задает структуру: Linear → Activation → Dropout → Linear → Dropout.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Первый линейный слой (расширение размерности)
|
# Первый линейный слой (расширение размерности)
|
||||||
@@ -73,13 +91,23 @@ class FeedForward(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
def forward(self, x: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
Прямой проход через слой Feed Forward Network.
|
Прямой проход через FeedForward блок.
|
||||||
|
|
||||||
Args:
|
Аргументы:
|
||||||
x: Входной тензор размерности [batch_size, seq_len, emb_size]
|
----------
|
||||||
|
x : torch.Tensor
|
||||||
|
Входной тензор формы [..., emb_size] (используется на каждом токене отдельно!)
|
||||||
|
|
||||||
Returns:
|
Возвращает:
|
||||||
Тензор той же размерности, что и входной
|
-----------
|
||||||
|
torch.Tensor — выход такой же формы, как вход (только последняя размерность сохраняется).
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> ffn = FeedForward(emb_size=256)
|
||||||
|
>>> x = torch.randn(8, 16, 256)
|
||||||
|
>>> y = ffn(x)
|
||||||
|
>>> y.shape # [8, 16, 256]
|
||||||
"""
|
"""
|
||||||
# Сохраняем dtype входных данных
|
# Сохраняем dtype входных данных
|
||||||
input_dtype = x.dtype
|
input_dtype = x.dtype
|
||||||
|
|||||||
@@ -1,22 +1,45 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
import math
|
||||||
|
|
||||||
class GELU(nn.Module):
|
class GELU(nn.Module):
|
||||||
"""
|
"""
|
||||||
Гауссовская Эрф-активация (GELU, Gaussian Error Linear Unit).
|
GELU (Gaussian Error Linear Unit) — современная сглаженная функция активации для нейросетей.
|
||||||
|
|
||||||
Научная суть:
|
Мотивация и назначение:
|
||||||
- Одна из самых популярных smooth активаций для трансформеров.
|
-----------------------
|
||||||
- Дает более гибкие аппроксимации, чем ReLU/SiLU, улучшает flow градиентов для больших LLM.
|
- GELU используется во всех современных трансформерах (BERT, GPT, Llama) вместо ReLU, поскольку лучше передает градиенты и даёт более "мягкое" обучение.
|
||||||
- Используется в BERT, GPT, GPT2 и почти всех современных NLP-моделях.
|
- Формирует плавный переход между активированным и неактивированным состоянием, что улучшает устойчивость и общую производительность больших моделей.
|
||||||
Формула:
|
- Дает возможность обучению «решать», насколько сильно и в каких диапазонах нужно передавать сигнал (в отличие от жёсткого ReLU).
|
||||||
GELU(x) = 0.5 * x * (1 + tanh(\sqrt{2/π} * (x + 0.044715 x³)))
|
|
||||||
Подробнее: Hendrycks & Gimpel, "Gaussian Error Linear Units (GELUs)", arXiv:1606.08415
|
Математическая формула:
|
||||||
Пример:
|
-----------------------
|
||||||
|
GELU(x) = 0.5 * x * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3) ))
|
||||||
|
- Статья (Hendrycks & Gimpel, 2016): https://arxiv.org/abs/1606.08415
|
||||||
|
- В PyTorch с версии 1.4+ встроена как torch.nn.functional.gelu и torch.nn.GELU.
|
||||||
|
|
||||||
|
Как это работает:
|
||||||
|
-----------------
|
||||||
|
- Для каждого входного значения x:
|
||||||
|
- x при больших значениях (большие положительные) почти полностью передается дальше.
|
||||||
|
- x при малых (или сильно отрицательных) "заглушается" к нулю.
|
||||||
|
- На промежуточных значениях — плавный переход.
|
||||||
|
- Является аппроксимацией случайного бинома с гауссовским шумом.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
-----
|
||||||
|
Нет learnable параметров — GELU работает одинаково для всех входов.
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
>>> gelu = GELU()
|
>>> gelu = GELU()
|
||||||
>>> y = gelu(torch.tensor([-1.0, 0.0, 1.0]))
|
>>> x = torch.tensor([-2.0, 0.0, 2.0])
|
||||||
>>> print(y)
|
>>> print(gelu(x)) # тензор из плавно переходящих значений
|
||||||
|
|
||||||
|
References:
|
||||||
|
-----------
|
||||||
|
- Hendrycks & Gimpel: https://arxiv.org/abs/1606.08415
|
||||||
|
- BERT, GPT-2 papers (везде используется GELU)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -24,6 +47,24 @@ class GELU(nn.Module):
|
|||||||
self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)
|
self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Прямой проход через GELU-активацию.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
-----
|
||||||
|
x : torch.Tensor
|
||||||
|
Любой входной тензор.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
torch.Tensor — тензор той же формы, где к каждому элементу применён GELU.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> gelu = GELU()
|
||||||
|
>>> x = torch.linspace(-3, 3, 7)
|
||||||
|
>>> y = gelu(x)
|
||||||
|
"""
|
||||||
return (
|
return (
|
||||||
0.5
|
0.5
|
||||||
* x
|
* x
|
||||||
|
|||||||
413
llm/src/llm/core/group_query_attention.py
Normal file
413
llm/src/llm/core/group_query_attention.py
Normal file
@@ -0,0 +1,413 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from llm.core.rope import RoPE
|
||||||
|
|
||||||
|
class GroupedQueryAttention(nn.Module):
|
||||||
|
"""
|
||||||
|
Grouped Query Attention (GQA)
|
||||||
|
=============================
|
||||||
|
|
||||||
|
Что такое Grouped Query Attention?
|
||||||
|
----------------------------------
|
||||||
|
Это разновидность многоголового внимания (multi-head), где для Q (query) голов может быть больше, чем для K/V (key/value) голов:
|
||||||
|
вместо стандартного MHA (num_q_heads == num_kv_heads) — меньшее число K/V разделяет информацию для всех Q.
|
||||||
|
Такой подход экономит память и ускоряет инференс, сохраняя высокое качество внимания (используется например в Mistral, Llama-2, GPT-4 и др.).
|
||||||
|
|
||||||
|
Зачем это нужно?
|
||||||
|
----------------
|
||||||
|
- Сокращает количество вычислений и размер KV-кэша в больших LLM.
|
||||||
|
- Позволяет эффективно масштабировать число attention-глав для моделирования сложных связей, не увеличивая размер всех матриц.
|
||||||
|
|
||||||
|
Как работает?
|
||||||
|
-------------
|
||||||
|
1. Q формируется для каждого query-head (их много)
|
||||||
|
2. K и V вычисляется только для меньшего числа KV-heads (обычно в 2-4 раза меньше, чем Q)
|
||||||
|
3. К/V heads дублируются (repeat) так, чтобы на каждую Q-head был свой набор
|
||||||
|
4. Всё внимание (Q,K,V) — стандартное scaled dot-product, только более эффективно и с компрессией
|
||||||
|
|
||||||
|
Поддержка дополнительных фич:
|
||||||
|
-----------------------------
|
||||||
|
- Rotary Position Encoding (RoPE) для Q и K (для относительной позиции)
|
||||||
|
- Sliding-window attention mask (можно ограничить исторический контекст, как в Mistral)
|
||||||
|
- Кэширование Q/K/V (ускоряет генерацию автоагретивно)
|
||||||
|
|
||||||
|
Аргументы конструктора:
|
||||||
|
-----------------------
|
||||||
|
num_q_heads: int — количество query голов (Q)
|
||||||
|
num_kv_heads: int — количество key/value голов (обычно меньше Q)
|
||||||
|
emb_size: int — embedding размерность
|
||||||
|
head_size: int — размер каждой attention-head
|
||||||
|
max_seq_len: int — максимальная длина последовательности
|
||||||
|
window_size: int — размер sliding window (макс. количество токенов в контексте внимания)
|
||||||
|
rope: RoPE (по желанию) — если задан, то будет применяться RoPE для Q и K
|
||||||
|
dropout: float — dropout после линейной проекции
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> gqa = GroupedQueryAttention(num_q_heads=8, num_kv_heads=2, emb_size=256, head_size=32, max_seq_len=1024, window_size=256)
|
||||||
|
>>> x = torch.randn(2, 128, 256)
|
||||||
|
>>> y, cache = gqa(x)
|
||||||
|
>>> print(y.shape) # torch.Size([2, 128, 256])
|
||||||
|
|
||||||
|
Где прочитать подробнее:
|
||||||
|
------------------------
|
||||||
|
- LlamaV2 (Section 2.3): https://arxiv.org/abs/2307.09288
|
||||||
|
- Mistral: https://arxiv.org/abs/2310.06825
|
||||||
|
- \"Self-attention with linear complexity\" (Vila et al.): https://arxiv.org/abs/2302.05442
|
||||||
|
- Обзор: https://huggingface.co/blog/mistral
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_q_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
emb_size: int,
|
||||||
|
head_size: int,
|
||||||
|
max_seq_len: int,
|
||||||
|
window_size: int,
|
||||||
|
rope: RoPE = None,
|
||||||
|
dropout: float = 0.1,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Инициализация слоя Grouped Query Attention (GQA).
|
||||||
|
|
||||||
|
Этот конструктор задаёт архитектуру эффективного внимания, где Q-голов может быть больше, чем KV-голов.
|
||||||
|
Это экономит память/вычисления и позволяет реализовать сдвигающееся "окно" внимания (Mistral-style).
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
num_q_heads : int
|
||||||
|
Количество Query attention heads (чаще всего кратно num_kv_heads, напр. 8/2, 12/4).
|
||||||
|
Чем больше — тем богаче контекстное окно каждой позиции.
|
||||||
|
num_kv_heads : int
|
||||||
|
Количество Key/Value attention heads (обычно 2-4, иногда меньше, чем Query).
|
||||||
|
В современных LLM принято уменьшать их число для оптимизации скорости/кэша.
|
||||||
|
emb_size : int
|
||||||
|
Размерность входного embedding (общий размер вектора на токен).
|
||||||
|
head_size : int
|
||||||
|
Размерность одной головы внимания.
|
||||||
|
Требуется: num_q_heads * head_size == emb_size (иначе ошибка).
|
||||||
|
max_seq_len : int
|
||||||
|
Максимальная поддерживаемая длина входной последовательности; определяет размер триангулярной (causal/sliding window) маски.
|
||||||
|
window_size : int
|
||||||
|
Размер "скользящего окна" истории — сколько токенов учитывается при слепом внимании (как у Mistral).
|
||||||
|
Чем меньше значение, тем локальнее работает внимание (и меньше память/время).
|
||||||
|
rope : RoPE, опционально
|
||||||
|
Если задан — применяется Rotary Positional Encoding к Q и K для относительного позиционного кодирования.
|
||||||
|
dropout : float, по умолчанию 0.1
|
||||||
|
Dropout после линейной проекции attention (обычно 0.1, помогает борьбе с переобучением).
|
||||||
|
|
||||||
|
Что создаётся внутри:
|
||||||
|
---------------------
|
||||||
|
- Линейные слои для получения Q, K, V из embedding.
|
||||||
|
- Буфер для causal/sliding window mask (матрица масок в зависимости от window_size и max_seq_len).
|
||||||
|
- Линейный слой для финального преобразования (объединение всех голов и возврат к emb_size).
|
||||||
|
- Dropout перед возвратом.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> attn = GroupedQueryAttention(
|
||||||
|
... num_q_heads=8, num_kv_heads=2, emb_size=256, head_size=32,
|
||||||
|
... max_seq_len=1024, window_size=256, dropout=0.1)
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self._num_heads = num_q_heads
|
||||||
|
self._num_kv_heads = num_kv_heads
|
||||||
|
self._head_size = head_size
|
||||||
|
self._max_seq_len = max_seq_len
|
||||||
|
self._rope = rope
|
||||||
|
self._window_size = window_size
|
||||||
|
|
||||||
|
self._q = nn.Linear(emb_size, self._num_heads * head_size)
|
||||||
|
self._k = nn.Linear(emb_size, num_kv_heads * head_size)
|
||||||
|
self._v = nn.Linear(emb_size, num_kv_heads * head_size)
|
||||||
|
|
||||||
|
# Создание causal маски
|
||||||
|
mask = self._create_sliding_window_mask(max_seq_len, self._window_size)
|
||||||
|
self.register_buffer(
|
||||||
|
"_tril_mask", mask.bool() if hasattr(torch, "bool") else mask.byte()
|
||||||
|
)
|
||||||
|
|
||||||
|
self._layer = nn.Linear(head_size * self._num_heads, emb_size)
|
||||||
|
self._dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
mask: torch.Tensor = None,
|
||||||
|
use_cache: bool = True,
|
||||||
|
cache: list = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Шаг внимания в режиме Grouped Query Attention —
|
||||||
|
реализует эффективное многооконное внимание с раздельными Q/KV и sliding/casual mask.
|
||||||
|
|
||||||
|
Что происходит в этом методе:
|
||||||
|
-----------------------------
|
||||||
|
- Преобразует входной тензор x (токеновые эмбеддинги) в Q, K, V-матрицы с учётом разного числа голов для Q и KV.
|
||||||
|
- Формирует attention "маску" для sliding window, если нужно ограничить историю.
|
||||||
|
- Применяет RoPE (если задан) к Q и K, вносит позиционную информацию.
|
||||||
|
- При работе с кэшем дополняет ключи и значения предыдущими (ускоряет генерацию).
|
||||||
|
- Повторяет K/V головы для соответствия количеству Q (чтобы на каждую Q-head приходился свой KV).
|
||||||
|
- Считает обычное scaled dot-product внимание, применяет маску (не даёт видеть будущее, как и в autoregressive).
|
||||||
|
- Softmax, смешивание V на основе attention, объединение всех голов.
|
||||||
|
- Dropout и финальное линейное преобразование обратно к emb_size.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
x : torch.Tensor
|
||||||
|
Входной тензор размера [batch, seq_len, emb_size]
|
||||||
|
mask : torch.Tensor, по умолчанию None
|
||||||
|
Матричная маска для внимания (можно передать внешнюю или использовать встроенную sliding window mask)
|
||||||
|
use_cache : bool, по умолчанию True
|
||||||
|
Нужно ли использовать/возвращать кэш KV для быстрых автогенераций.
|
||||||
|
cache : list, опционально
|
||||||
|
Ранее сохранённый кэш KV (используется для инференса по одному токену)
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
-----------
|
||||||
|
- output: torch.Tensor формы [batch, seq_len, emb_size]
|
||||||
|
- kv_cache: кэш новых KV (если use_cache=True), иначе None
|
||||||
|
|
||||||
|
Важно:
|
||||||
|
-------
|
||||||
|
- Реализует Mistral-style attention: к каждой Q-head в итоге “приписан” собственный (но потенциально дублированный) KV-head.
|
||||||
|
- Sliding window ограничивает область вижимости в attention (ускоряет генерацию на длинных последовательностях).
|
||||||
|
- Использование RoPE опционально — но необходимо для современных архитектур LLM.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> attn = GroupedQueryAttention(num_q_heads=8, num_kv_heads=2, emb_size=256, head_size=32, max_seq_len=1024, window_size=256)
|
||||||
|
>>> x = torch.randn(2, 128, 256)
|
||||||
|
>>> y, kv_cache = attn(x)
|
||||||
|
>>> print(y.shape) # torch.Size([2, 128, 256])
|
||||||
|
"""
|
||||||
|
batch_size, seq_len, emb_size = x.shape
|
||||||
|
|
||||||
|
if seq_len > self._max_seq_len:
|
||||||
|
raise ValueError(
|
||||||
|
f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Пропустите тензор x через матрицы Wq, Wk , Wv, чтобы получить матрицы запроса, ключа и значения.
|
||||||
|
k = self._k(x) # [B, T, hs]
|
||||||
|
q = self._q(x) # [B, T, hs]
|
||||||
|
v = self._v(x) # [B, T, hs]
|
||||||
|
|
||||||
|
# Шаг 2: Изменение формы для multi-head
|
||||||
|
# [batch_size, seq_len, num_heads * head_size]
|
||||||
|
# -> [batch_size, seq_len, num_heads, head_size]
|
||||||
|
# Измените форму запроса (query) на batch_size × num_q_heads × seq_len × head_size.
|
||||||
|
q = q.reshape(batch_size, seq_len, self._num_heads, self._head_size)
|
||||||
|
|
||||||
|
# Измените форму ключа (key) и значения (value) на batch_size × num_kv_heads × seq_len × head_size.
|
||||||
|
k = k.reshape(batch_size, seq_len, self._num_kv_heads, self._head_size)
|
||||||
|
v = v.reshape(batch_size, seq_len, self._num_kv_heads, self._head_size)
|
||||||
|
|
||||||
|
|
||||||
|
# 3. Transpose: [B, T, H, hs] -> [B, H, T, hs]
|
||||||
|
q = q.transpose(1, 2)
|
||||||
|
k = k.transpose(1, 2)
|
||||||
|
v = v.transpose(1, 2)
|
||||||
|
|
||||||
|
start_pos = 0
|
||||||
|
if cache is not None:
|
||||||
|
k_cache, v_cache = cache
|
||||||
|
cache_len = k_cache.shape[2]
|
||||||
|
start_pos = cache_len
|
||||||
|
|
||||||
|
# Пропустите матрицы запроса и ключа через экземпляр rope, чтобы выполнить поворот.
|
||||||
|
if self._rope is not None:
|
||||||
|
# Применяем RoPE к Q и K (НЕ к V!)
|
||||||
|
q = self._rope(q, start_pos=start_pos) # [B, T, hs]
|
||||||
|
k = self._rope(k, start_pos=start_pos) # [B, T, hs]
|
||||||
|
|
||||||
|
# Если cache пришел, то объединяем кэш и одну строку из ключа и значения. Это будут новые key и value для последующих вычислений.
|
||||||
|
# 5. Кэширование (для autoregressive generation)
|
||||||
|
if cache is not None:
|
||||||
|
k_cache, v_cache = cache
|
||||||
|
k = torch.cat([k_cache, k], dim=2) # Concat по seq_len (dim=2)
|
||||||
|
v = torch.cat([v_cache, v], dim=2)
|
||||||
|
|
||||||
|
# Если use_cache == True, то сохраните матрицы ключа и значения для кэша (это нужно сделать до дублирования голов).
|
||||||
|
#if use_cache == True:
|
||||||
|
# # Обрезаем до последних window_size токенов
|
||||||
|
# k_to_cache = k[:, :, -self._window_size:, :]
|
||||||
|
# v_to_cache = v[:, :, -self._window_size:, :]
|
||||||
|
# kv_cache = (k_to_cache, v_to_cache)
|
||||||
|
|
||||||
|
# Продублируйте головы в тензорах ключа (key) и значения (value), чтобы получился тензор размера на batch_size × num_q_heads × seq_len × head_size.
|
||||||
|
#k = self._repeat_kv_heads(k, self._num_heads, self._num_kv_heads)
|
||||||
|
#v = self._repeat_kv_heads(v, self._num_heads, self._num_kv_heads)
|
||||||
|
k_expanded = self._repeat_kv_heads(k, self._num_heads, self._num_kv_heads)
|
||||||
|
v_expanded = self._repeat_kv_heads(v, self._num_heads, self._num_kv_heads)
|
||||||
|
|
||||||
|
# Перемножим матрицы запроса и ключа (транспонированную), чтобы вычислить матрицу внимания.
|
||||||
|
# И разделить все значения в матрице внимания на корень из head_size.
|
||||||
|
scores = q @ k_expanded.transpose(-2, -1) / (self._head_size ** 0.5)
|
||||||
|
|
||||||
|
# 8. Применение маски
|
||||||
|
k_seq_len = k_expanded.size(2) # Длина K после concat с кэшем
|
||||||
|
|
||||||
|
if cache is None:
|
||||||
|
# Случай 1: Без кэша - полная квадратная маска
|
||||||
|
# scores: [B, H, seq_len, seq_len]
|
||||||
|
# Применяем маску [:seq_len, :seq_len]
|
||||||
|
scores = scores.masked_fill(
|
||||||
|
~self._tril_mask[:seq_len, :seq_len],
|
||||||
|
float("-inf")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Применить к матрице внимания (построчно) функцию Softmax.
|
||||||
|
weights = F.softmax(scores, dim=-1)
|
||||||
|
|
||||||
|
# Перемножим матрицу внимания и матрицу значения.
|
||||||
|
x_out = weights @ v_expanded # [B, T, hs]
|
||||||
|
|
||||||
|
# Измените форму тензора на batch_size × seq_len × num_heads*head_size.
|
||||||
|
# Transpose обратно и concatenate heads
|
||||||
|
x_out = x_out.transpose(1, 2) # [B, T_q, H, hs]
|
||||||
|
x_out = x_out.contiguous() # Важно для reshape!
|
||||||
|
concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size)
|
||||||
|
|
||||||
|
#concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size)
|
||||||
|
|
||||||
|
# Пропустите получившийся тензор через последний линейный слой.
|
||||||
|
# 3. Проецируем в пространство эмбеддингов
|
||||||
|
projected_output = self._layer(concatenated_attention)
|
||||||
|
|
||||||
|
# 4. Применяем dropout для регуляризации
|
||||||
|
output = self._dropout(projected_output)
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
# Обрезаем оригинальный K и V (до дублирования)
|
||||||
|
k_to_cache = k[:, :, -self._window_size:, :]
|
||||||
|
v_to_cache = v[:, :, -self._window_size:, :]
|
||||||
|
kv_cache = (k_to_cache, v_to_cache)
|
||||||
|
return output, kv_cache
|
||||||
|
else:
|
||||||
|
return output, None
|
||||||
|
|
||||||
|
def _repeat_kv_heads(
|
||||||
|
self,
|
||||||
|
kv: torch.Tensor,
|
||||||
|
num_q_heads: int,
|
||||||
|
num_kv_heads: int
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Приводит число голов K/V к числу голов Q путём поэлементного повторения (tile) KV-голов.
|
||||||
|
|
||||||
|
Зачем это нужно?
|
||||||
|
----------------
|
||||||
|
В Grouped Query Attention (Mistral, Llama-2, GPT-4 и др.) обычно num_kv_heads < num_q_heads.
|
||||||
|
Чтобы каждая Query-head могла смотреть на свою собственную (пусть и общую) KV, мы "нарезаем" или повторяем KV столько раз, сколько требуется — это экономит память и ускоряет генерацию.
|
||||||
|
|
||||||
|
Алгоритм:
|
||||||
|
---------
|
||||||
|
- kv имеет форму [batch_size, num_kv_heads, seq_len, head_size]
|
||||||
|
- Для каждого KV-head делается n_repeat = num_q_heads // num_kv_heads по head-axis (обычно целое)
|
||||||
|
- На выходе форма [batch_size, num_q_heads, seq_len, head_size], где каждый KV-head дублирован для нужного количества Q-heads.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
-----
|
||||||
|
kv : torch.Tensor
|
||||||
|
Входной тензор KV (обычно после linear layer on эмбеддинги), размер [batch_size, num_kv_heads, seq_len, head_size]
|
||||||
|
num_q_heads : int
|
||||||
|
Сколько должно быть Q-голов (их больше!)
|
||||||
|
num_kv_heads : int
|
||||||
|
Сколько KV-голов было (их меньше!)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
torch.Tensor формы [batch_size, num_q_heads, seq_len, head_size], где KV-головы повторены как требуется.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
num_q_heads = 8, num_kv_heads = 2
|
||||||
|
[KV0, KV1] -> [KV0, KV0, KV0, KV0, KV1, KV1, KV1, KV1]
|
||||||
|
# Каждый KV-head дублируется 4 раза, чтобы покрыть все 8 Q-heads.
|
||||||
|
"""
|
||||||
|
batch_size, num_kv_heads, seq_len, head_size = kv.shape
|
||||||
|
|
||||||
|
if num_q_heads == num_kv_heads:
|
||||||
|
# Нет необходимости дублировать
|
||||||
|
return kv
|
||||||
|
|
||||||
|
# Вычисляем сколько раз нужно повторить каждую голову
|
||||||
|
num_repeats = num_q_heads // num_kv_heads
|
||||||
|
|
||||||
|
# repeat_interleave дублирует каждую голову num_repeats раз
|
||||||
|
# [B, num_kv_heads, S, hs] -> [B, num_q_heads, S, hs]
|
||||||
|
# [B, num_kv_heads, S, hs] -> [B, num_kv_heads, 1, S, hs]
|
||||||
|
kv = kv.unsqueeze(2)
|
||||||
|
|
||||||
|
# [B, num_kv_heads, 1, S, hs] -> [B, num_kv_heads, num_repeats, S, hs]
|
||||||
|
kv = kv.repeat(1, 1, num_repeats, 1, 1)
|
||||||
|
|
||||||
|
# [B, num_kv_heads, num_repeats, S, hs] -> [B, num_q_heads, S, hs]
|
||||||
|
kv = kv.reshape(batch_size, num_q_heads, seq_len, head_size)
|
||||||
|
|
||||||
|
|
||||||
|
return kv
|
||||||
|
|
||||||
|
def _create_sliding_window_mask(
|
||||||
|
self,
|
||||||
|
max_seq_len: int,
|
||||||
|
window_size: int,
|
||||||
|
device: torch.device = None
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Создаёт маску для Sliding Window Attention (ограниченного окна внимания).
|
||||||
|
|
||||||
|
Зачем нужна эта маска?
|
||||||
|
----------------------
|
||||||
|
В современных LLM (например, Mistral) self-attention работает не по всей истории, а только в узком "скользящем окне":
|
||||||
|
каждый токен видит только предшествующие (или соседние) токены на расстоянии window_size.
|
||||||
|
Это ускоряет инференс на длинных текстах и экономит память, но сохраняет ключевые зависимости в пределах окна.
|
||||||
|
|
||||||
|
Как работает алгоритм:
|
||||||
|
----------------------
|
||||||
|
- Для каждого токена mask[i, j] == True только если токен j находится СЛЕВА и не дальше, чем window_size позиций (или сам i).
|
||||||
|
- Главное: mask всегда "нижнетреугольная" (causal), плюс полоса шириной window_size вдоль главной диагонали.
|
||||||
|
- Всё за пределами окна — False (attention нельзя).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
-----
|
||||||
|
max_seq_len : int
|
||||||
|
Максимальная длина последовательности (размер будущей attention-матрицы).
|
||||||
|
window_size : int
|
||||||
|
Сколько предыдущих токенов доступно для внимания у каждого шага (вкл. сам себя).
|
||||||
|
device : torch.device, опционально
|
||||||
|
На каком устройстве (cpu/gpu) создавать маску.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
torch.Tensor
|
||||||
|
Маска внимания формы [max_seq_len, max_seq_len], где True — допускается внимание (иначе False).
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> mask = create_sliding_window_mask(8, 3)
|
||||||
|
>>> print(mask.int())
|
||||||
|
tensor([[1, 0, 0, 0, 0, 0, 0, 0],
|
||||||
|
[1, 1, 0, 0, 0, 0, 0, 0],
|
||||||
|
[1, 1, 1, 0, 0, 0, 0, 0],
|
||||||
|
[0, 1, 1, 1, 0, 0, 0, 0],
|
||||||
|
[0, 0, 1, 1, 1, 0, 0, 0],
|
||||||
|
[0, 0, 0, 1, 1, 1, 0, 0],
|
||||||
|
[0, 0, 0, 0, 1, 1, 1, 0],
|
||||||
|
[0, 0, 0, 0, 0, 1, 1, 1]])
|
||||||
|
"""
|
||||||
|
row_indices = torch.arange(max_seq_len, device=device).unsqueeze(1) # [max_seq_len, 1]
|
||||||
|
col_indices = torch.arange(max_seq_len, device=device).unsqueeze(0) # [1, max_seq_len]
|
||||||
|
|
||||||
|
causal_mask = col_indices <= row_indices
|
||||||
|
|
||||||
|
window_mask = (row_indices - col_indices) <= window_size
|
||||||
|
|
||||||
|
mask = causal_mask & window_mask
|
||||||
|
|
||||||
|
return mask
|
||||||
@@ -1,121 +0,0 @@
|
|||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from math import sqrt
|
|
||||||
from .rope import RoPE
|
|
||||||
|
|
||||||
|
|
||||||
class HeadAttention(nn.Module):
|
|
||||||
"""
|
|
||||||
Одноголовый механизм внимания (scaled dot-product attention) — фундаментальный строительный блок всех современных Transformer.
|
|
||||||
|
|
||||||
Научная суть:
|
|
||||||
- Attention учит модель самостоятельно "выбирать" важные связи между словами, независимо от их положения.
|
|
||||||
- Механизм causal mask гарантирует невозможность "заглядывания в будущее" при генерации (авторегрессия).
|
|
||||||
|
|
||||||
Формула:
|
|
||||||
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) · V
|
|
||||||
(Q — запросы, K — ключи, V — значения; d_k — размерность ключа)
|
|
||||||
|
|
||||||
Поддерживает Rotary Position Encoding (RoPE) для относительного позиционного кодирования.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
emb_size (int): размер входного эмбеддинга
|
|
||||||
head_size (int): размерность attention-головы
|
|
||||||
max_seq_len (int): максимальная длина последовательности
|
|
||||||
rope (RoPE, optional): экземпляр RoPE для позиций
|
|
||||||
|
|
||||||
Примечания:
|
|
||||||
- Использует нижнетреугольную маску для предотвращения "заглядывания в будущее"
|
|
||||||
- Автоматически адаптируется к разным версиям PyTorch
|
|
||||||
- Поддерживает batch-обработку входных данных
|
|
||||||
|
|
||||||
Пример использования:
|
|
||||||
>>> attention = HeadAttention(emb_size=64, head_size=32, max_seq_len=128)
|
|
||||||
>>> x = torch.randn(1, 10, 64)
|
|
||||||
>>> output, _ = attention(x)
|
|
||||||
>>> print(output.shape) # torch.Size([1, 10, 32])
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE = None
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self._emb_size = emb_size
|
|
||||||
self._head_size = head_size
|
|
||||||
self._max_seq_len = max_seq_len
|
|
||||||
self._rope = rope
|
|
||||||
|
|
||||||
# Линейные преобразования для Q, K, V
|
|
||||||
self._k = nn.Linear(emb_size, head_size)
|
|
||||||
self._q = nn.Linear(emb_size, head_size)
|
|
||||||
self._v = nn.Linear(emb_size, head_size)
|
|
||||||
|
|
||||||
# Создание causal маски
|
|
||||||
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
|
|
||||||
self.register_buffer(
|
|
||||||
"_tril_mask", mask.bool() if hasattr(torch, "bool") else mask.byte()
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self, x: torch.Tensor, use_cache: bool = True, cache: tuple = None
|
|
||||||
) -> tuple:
|
|
||||||
"""
|
|
||||||
Прямой проход через слой внимания.
|
|
||||||
|
|
||||||
Аргументы:
|
|
||||||
x (torch.Tensor): Входной тензор формы [batch_size, seq_len, emb_size]
|
|
||||||
|
|
||||||
Возвращает:
|
|
||||||
torch.Tensor: Выходной тензор формы [batch_size, seq_len, head_size]
|
|
||||||
|
|
||||||
Исключения:
|
|
||||||
ValueError: Если длина последовательности превышает max_seq_len
|
|
||||||
|
|
||||||
Пример внутренних преобразований:
|
|
||||||
Для входа x.shape = [2, 5, 64]:
|
|
||||||
1. Q/K/V преобразования -> [2, 5, 32]
|
|
||||||
2. Scores = Q·K^T -> [2, 5, 5]
|
|
||||||
3. После маски и softmax -> [2, 5, 5]
|
|
||||||
4. Умножение на V -> [2, 5, 32]
|
|
||||||
"""
|
|
||||||
seq_len = x.shape[1]
|
|
||||||
if seq_len > self._max_seq_len:
|
|
||||||
raise ValueError(
|
|
||||||
f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}"
|
|
||||||
)
|
|
||||||
|
|
||||||
k = self._k(x) # [B, T, hs]
|
|
||||||
q = self._q(x) # [B, T, hs]
|
|
||||||
v = self._v(x) # [B, T, hs]
|
|
||||||
|
|
||||||
start_pos = 0
|
|
||||||
if cache is not None:
|
|
||||||
k_cache, v_cache = cache
|
|
||||||
cache_len = k_cache.shape[1]
|
|
||||||
start_pos = cache_len
|
|
||||||
|
|
||||||
if self._rope is not None:
|
|
||||||
# ✅ Применяем RoPE к Q и K (НЕ к V!)
|
|
||||||
q = self._rope(q, start_pos=start_pos) # [B, T, hs]
|
|
||||||
k = self._rope(k, start_pos=start_pos) # [B, T, hs]
|
|
||||||
|
|
||||||
if cache is not None:
|
|
||||||
k_cache, v_cache = cache
|
|
||||||
k = torch.cat([k_cache, k], dim=1) # [B, cache_len + T, hs]
|
|
||||||
v = torch.cat([v_cache, v], dim=1) # [B, cache_len + T, hs]
|
|
||||||
|
|
||||||
scores = q @ k.transpose(-2, -1) / sqrt(self._head_size)
|
|
||||||
|
|
||||||
if cache is None:
|
|
||||||
scores = scores.masked_fill(
|
|
||||||
~self._tril_mask[:seq_len, :seq_len], float("-inf")
|
|
||||||
)
|
|
||||||
|
|
||||||
weights = F.softmax(scores, dim=-1)
|
|
||||||
x_out = weights @ v # [B, T, hs]
|
|
||||||
|
|
||||||
if use_cache is True:
|
|
||||||
return (x_out, (k, v))
|
|
||||||
else:
|
|
||||||
return (x_out, None)
|
|
||||||
134
llm/src/llm/core/mistral_decoder.py
Normal file
134
llm/src/llm/core/mistral_decoder.py
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from llm.core.rms_norm import RMSNorm
|
||||||
|
from llm.core.swi_glu import SwiGLU
|
||||||
|
from llm.core.rope import RoPE
|
||||||
|
from llm.core.group_query_attention import GroupedQueryAttention
|
||||||
|
|
||||||
|
class MistralDecoder(nn.Module):
|
||||||
|
"""
|
||||||
|
MistralDecoder — стек декодирующих блоков, реализующий архитектуру Mistral-style Transformer.
|
||||||
|
|
||||||
|
Назначение:
|
||||||
|
-----------
|
||||||
|
Этот класс описывает один или несколько блоков декодера, включающих Grouped Query Attention (GQA),
|
||||||
|
sliding window attention и SwiGLU feed-forward, как реализовано в моделях Mistral и Llama 2.
|
||||||
|
|
||||||
|
Ключевые особенности архитектуры:
|
||||||
|
---------------------------------
|
||||||
|
- Использует GQA: для каждого токена вычисляется attention c раздельным числом Q и KV голов (сильно ускоряет LLM).
|
||||||
|
- Sliding Window Attention: внимание ограничено окном из window_size элементов (ускоряет обработку длинных текстов).
|
||||||
|
- Rotary Positional Embedding (RoPE): позиционная информация интегрируется вращением Q/K.
|
||||||
|
- RMSNorm перед и после внимания и FFN (устойчивое обучение).
|
||||||
|
- SwiGLU в качестве нелинейности вместо стандартного GELU (больше capacity в модели).
|
||||||
|
|
||||||
|
Аргументы конструктора:
|
||||||
|
-----------------------
|
||||||
|
num_layers : int — сколько блоков-декодеров в стеке
|
||||||
|
параметры GQA: num_q_heads, num_kv_heads, emb_size, head_size, max_seq_len, window_size, rope, dropout
|
||||||
|
- все они идут в каждый слой (блок) декодера
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> decoder = MistralDecoder(
|
||||||
|
... num_q_heads=8, num_kv_heads=2, emb_size=256, head_size=32,
|
||||||
|
... max_seq_len=4096, window_size=256, rope=rope, dropout=0.1)
|
||||||
|
>>> x = torch.randn(2, 512, 256)
|
||||||
|
>>> out, cache = decoder(x)
|
||||||
|
>>> print(out.shape) # torch.Size([2, 512, 256])
|
||||||
|
|
||||||
|
Подробнее:
|
||||||
|
----------
|
||||||
|
- Mistral: https://arxiv.org/abs/2310.06825
|
||||||
|
- Llama 2: https://arxiv.org/abs/2307.09288
|
||||||
|
- Open LLM обзор: https://huggingface.co/blog/mistral
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
num_q_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
emb_size: int,
|
||||||
|
head_size: int,
|
||||||
|
max_seq_len: int,
|
||||||
|
window_size: int,
|
||||||
|
rope: RoPE,
|
||||||
|
dropout: float = 0.1
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Инициализация стека декодеров MistralDecoder.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
num_layers : int
|
||||||
|
Сколько слоёв (декодеров/GQA-блоков) собрать в стек.
|
||||||
|
num_q_heads : int
|
||||||
|
Количество Query-heads в attention (их больше, экономит память).
|
||||||
|
num_kv_heads : int
|
||||||
|
Количество Key/Value-heads в attention (их меньше для быстрой генерации).
|
||||||
|
emb_size : int
|
||||||
|
Размерность embedding (должна делиться на num_q_heads без остатка).
|
||||||
|
head_size : int
|
||||||
|
Размер одного attention head.
|
||||||
|
max_seq_len : int
|
||||||
|
Максимально обрабатываемая длина последовательности.
|
||||||
|
window_size : int
|
||||||
|
Размер окна для sliding window attention.
|
||||||
|
rope : RoPE
|
||||||
|
Rotary Positional Embedding для Q/K.
|
||||||
|
dropout : float, опционально
|
||||||
|
Dropout на каждом attention/FFN (по умолчанию 0.1).
|
||||||
|
|
||||||
|
Внутри:
|
||||||
|
-------
|
||||||
|
- Собираются num_layers Sequential-блоков из GQA + SwiGLU + RMSNorm.
|
||||||
|
- Все параметры передаются в каждый слой (блок).
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self._heads = GroupedQueryAttention(
|
||||||
|
num_q_heads=num_q_heads,
|
||||||
|
num_kv_heads=num_kv_heads,
|
||||||
|
emb_size=emb_size,
|
||||||
|
head_size=head_size,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
window_size=window_size,
|
||||||
|
rope=rope,
|
||||||
|
dropout=dropout
|
||||||
|
)
|
||||||
|
self._ff = SwiGLU(emb_size=emb_size, dropout=dropout)
|
||||||
|
self._norm1 = RMSNorm(emb_size)
|
||||||
|
self._norm2 = RMSNorm(emb_size)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Прямой проход через стек MistralDecoder.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
x : torch.Tensor
|
||||||
|
Входные эмбеддинги (обычно shape [batch, seq_len, emb_size]).
|
||||||
|
use_cache : bool, по умолчанию True
|
||||||
|
Включить ли кэширование для ускорения генерации (авторегрессия).
|
||||||
|
cache : list, опционально
|
||||||
|
Предыдущий кеш attention-блоков (или None).
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
-----------
|
||||||
|
out : torch.Tensor
|
||||||
|
Тензор после декодирования (shape соответствует x).
|
||||||
|
new_cache : list (или None)
|
||||||
|
Новый кэш attention для дальнейшей генерации (или None, если use_cache=False).
|
||||||
|
|
||||||
|
"""
|
||||||
|
norm1_out = self._norm1(x)
|
||||||
|
attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
|
||||||
|
out = attention + x
|
||||||
|
|
||||||
|
norm2_out = self._norm2(out)
|
||||||
|
ffn_out = self._ff(norm2_out)
|
||||||
|
|
||||||
|
if use_cache is True:
|
||||||
|
return (ffn_out + out, kv_caches)
|
||||||
|
else:
|
||||||
|
return (ffn_out + out, None)
|
||||||
@@ -1,37 +1,70 @@
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
import torch
|
import torch
|
||||||
from .head_attention import HeadAttention
|
import torch.nn.functional as F
|
||||||
from .rope import RoPE
|
from .rope import RoPE
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadAttention(nn.Module):
|
class MultiHeadAttention(nn.Module):
|
||||||
"""
|
"""
|
||||||
Мультиголовый (многоголовый) механизм внимания — ключевой компонент любого Transformer.
|
Multi-Head Attention (Многоголовое внимание)
|
||||||
|
============================================
|
||||||
|
|
||||||
Научная суть:
|
Что такое Multi-Head Attention?
|
||||||
- Модель параллельно агрегирует информацию через несколько подпространств (головы),
|
-------------------------------
|
||||||
чтобы видеть разные связи в последовательности (разный контекст, локально/глобально).
|
Это ключевой компонент трансформеров, который позволяет "смотреть" на разные части предложения
|
||||||
- Каждый attention блок работает независимо, выход конкатенируется.
|
одновременно с нескольких независимых ракурсов ("голов"). Всё, что делает Single-Head Attention — только гораздо мощнее и глубже!
|
||||||
- Механизм предложен в статье "Attention is All You Need" (Vaswani et al., 2017).
|
|
||||||
|
|
||||||
Формула внимания для одной головы:
|
Зачем это нужно?
|
||||||
Attention(Q, K, V) = softmax(QK^T/sqrt(d_k))·V
|
----------------
|
||||||
Мультиголовый:
|
- Модель может учиться одновременно учитывать и локальные, и глобальные взаимосвязи между токенами.
|
||||||
MultiHead(Q, K, V) = Concat([head_i])*W^O
|
- Каждая attention head "ловит" свой собственный смысл/зависимости, и на выходе они объединяются.
|
||||||
|
- Это значительно улучшает понимание сложных зависимостей в тексте, особенно на длинных последовательностях.
|
||||||
|
|
||||||
Args:
|
Как работает алгоритм? (основная схема)
|
||||||
num_heads (int): количество attention "голов"
|
---------------------------------------
|
||||||
emb_size (int): размерности входа и выхода
|
1. Генерируются Q, K, V (query, key, value) — по отдельной проекции для каждой головы.
|
||||||
head_size (int): размер одной attention-головы (emb_size/num_heads)
|
2. Для каждой головы: attention(Q, K, V) = softmax(Q·K^T / sqrt(d)) · V
|
||||||
max_seq_len (int): максимальная длина последовательности
|
3. Все головы "склеиваются" (concatenate) и прогоняются через общий финальный линейный слой.
|
||||||
rope (RoPE, optional): если задан, используется Rotary Positional Encoding
|
|
||||||
dropout (float): вероятность регуляризации
|
Почему это работает?
|
||||||
|
--------------------
|
||||||
|
- Даёт трансформеру многомерное восприятие текста.
|
||||||
|
- Позволяет эффективно обучаться на задачах, где порядок и "дальние" связи важнее, чем простое соседство.
|
||||||
|
|
||||||
|
Что принимается на вход:
|
||||||
|
------------------------
|
||||||
|
- x: shape [batch, seq_len, embed_dim] — обычный batched-embed тензор.
|
||||||
|
- mask (опционально): shape [seq_len, seq_len] — маска для автогерерации или causal attention.
|
||||||
|
|
||||||
|
Какие параметры важны:
|
||||||
|
----------------------
|
||||||
|
- num_heads: сколько attention heads внутри (обычно 4, 8, 16...).
|
||||||
|
- embed_dim: исходная размерность входного тензора.
|
||||||
|
- head_size: размер одной attention-head (обычно embed_dim // num_heads).
|
||||||
|
- max_seq_len: максимальная длина последовательности для маски.
|
||||||
|
|
||||||
|
Что возвращает:
|
||||||
|
---------------
|
||||||
|
- output: shape [batch, seq_len, embed_dim] — результат применения всех attention heads.
|
||||||
|
- (опционально) cache: кэш для Q/K/V (нужно для генерации по одному токену).
|
||||||
|
|
||||||
|
Особенности реализации:
|
||||||
|
-----------------------
|
||||||
|
- Оптимизированно работает через матричные умножения (без python for циклов!).
|
||||||
|
- Включена поддержка causal attention (маска, предотвращающая «заглядывание в будущее»).
|
||||||
|
- Является ядром любого трансформера (и LLM!).
|
||||||
|
|
||||||
Пример использования:
|
Пример использования:
|
||||||
>>> mha = MultiHeadAttention(num_heads=8, emb_size=512, head_size=64, max_seq_len=1024)
|
---------------------
|
||||||
>>> x = torch.randn(2, 50, 512)
|
>>> attn = MultiHeadAttention(num_heads=8, embed_dim=256, head_size=32, max_seq_len=1024)
|
||||||
>>> out, cache = mha(x)
|
>>> x = torch.randn(2, 128, 256) # [batch, seq_len, embed_dim]
|
||||||
>>> print(out.shape)
|
>>> context, _ = attn(x)
|
||||||
|
>>> print(context.shape) # torch.Size([2, 128, 256])
|
||||||
|
|
||||||
|
Где прочитать подробнее:
|
||||||
|
-------------------------
|
||||||
|
- Attention is All You Need (Vaswani et al, 2017): https://arxiv.org/abs/1706.03762
|
||||||
|
- Illustrated Transformer (blog): https://jalammar.github.io/illustrated-transformer/
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -44,32 +77,59 @@ class MultiHeadAttention(nn.Module):
|
|||||||
dropout: float = 0.1,
|
dropout: float = 0.1,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Инициализация многоголового внимания.
|
Конструктор многоголового внимания (MultiHeadAttention).
|
||||||
|
|
||||||
Параметры:
|
Здесь создаются все параметры и внутренние слои для эффективного параллельного внимания (attention) сразу из нескольких "голов".
|
||||||
num_heads (int): Количество голов внимания. Типичные значения: 4-16
|
|
||||||
emb_size (int): Размерность входных и выходных эмбеддингов
|
|
||||||
head_size (int): Размерность каждой головы внимания (обычно emb_size // num_heads)
|
|
||||||
max_seq_len (int): Максимальная длина последовательности
|
|
||||||
dropout (float): Вероятность dropout (по умолчанию 0.1)
|
|
||||||
|
|
||||||
Контрольные значения:
|
Аргументы:
|
||||||
- num_heads * head_size должно равняться emb_size
|
----------
|
||||||
- head_size обычно выбирают 32-128
|
num_heads : int
|
||||||
- max_seq_len зависит от задачи (512 для BERT, 2048 для GPT-3)
|
Сколько attention-heads будет внутри слоя.
|
||||||
|
Каждая “голова” учится видеть уникальные зависимости в тексте. Обычно это 4, 8, 16 и т.п.
|
||||||
|
Чем больше голов — тем богаче контекст, но и больше памяти.
|
||||||
|
emb_size : int
|
||||||
|
Сколько float-значений в каждом входном векторе (размерность embedding).
|
||||||
|
Обычно это 256, 512, 768, 1024 и т.д.
|
||||||
|
head_size : int
|
||||||
|
Сколько компонент будет у каждой головы внимания.
|
||||||
|
Важно: num_heads * head_size должно ровно совпадать с emb_size!
|
||||||
|
Обычно head_size = emb_size // num_heads.
|
||||||
|
max_seq_len : int
|
||||||
|
Максимально допустимая длина последовательности для attention/маски/генерации.
|
||||||
|
Определяет размер буферов для causal mask.
|
||||||
|
rope : RoPE, по умолчанию None
|
||||||
|
Объект Rotary Positional Encoding (если хотите привнести продвинутое позиционное кодирование в attention).
|
||||||
|
Не обязателен, но нужен для современных LLM (Llama, Mistral и пр.).
|
||||||
|
dropout : float, по умолчанию 0.1
|
||||||
|
Величина dropout (регуляризации) — помогает борьбе с переобучением. Чем больше, тем сильнее регуляризация.
|
||||||
|
|
||||||
|
Внутри конструктора происходит:
|
||||||
|
-------------------------------
|
||||||
|
- Создаются три линейных слоя для Q, K, V (“где смотреть” и “что вытаскивать” в attention).
|
||||||
|
- Генерируется нижнетреугольная causal-маска (запрещает видеть будущее для автогерерации).
|
||||||
|
- Создаётся финальный линейный слой для склейки всех голов в одно пространство emb_size.
|
||||||
|
- Вводится dropout (случайное зануление, чтобы не было сильной зависимости внимания к отдельным "плейсам").
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> attn = MultiHeadAttention(num_heads=8, emb_size=256, head_size=32, max_seq_len=1024)
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._heads = nn.ModuleList(
|
self._num_heads = num_heads
|
||||||
[
|
self._head_size = head_size
|
||||||
HeadAttention(
|
self._max_seq_len = max_seq_len
|
||||||
emb_size=emb_size,
|
self._rope = rope
|
||||||
head_size=head_size,
|
|
||||||
max_seq_len=max_seq_len,
|
self._q = nn.Linear(emb_size, num_heads * head_size)
|
||||||
rope=rope,
|
self._k = nn.Linear(emb_size, num_heads * head_size)
|
||||||
)
|
self._v = nn.Linear(emb_size, num_heads * head_size)
|
||||||
for _ in range(num_heads)
|
|
||||||
]
|
# Создание causal маски
|
||||||
|
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
|
||||||
|
self.register_buffer(
|
||||||
|
"_tril_mask", mask.bool() if hasattr(torch, "bool") else mask.byte()
|
||||||
)
|
)
|
||||||
|
|
||||||
self._layer = nn.Linear(head_size * num_heads, emb_size)
|
self._layer = nn.Linear(head_size * num_heads, emb_size)
|
||||||
self._dropout = nn.Dropout(dropout)
|
self._dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
@@ -81,61 +141,116 @@ class MultiHeadAttention(nn.Module):
|
|||||||
cache: list = None,
|
cache: list = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Прямой проход (forward):
|
Основной шаг \"многоголового внимания\": находит взаимосвязи между токенами
|
||||||
Для каждого токена оценивает "важность" остальных токенов сразу через несколько attention-блоков.
|
в последовательности сразу из нескольких “ракурсов” (attention heads).
|
||||||
|
|
||||||
Подробное описание преобразований тензоров:
|
Что делает этот метод:
|
||||||
1. Входной тензор [batch_size, seq_len, emb_size] разделяется на N голов:
|
----------------------
|
||||||
- Каждая голова получает тензор [batch_size, seq_len, head_size]
|
- Для каждого токена сравнивает его с остальными во входной последовательности.
|
||||||
2. Каждая голова вычисляет attention:
|
- Делает это одновременно через несколько attention heads (каждая head видит текст по-своему).
|
||||||
- Вход: [batch_size, seq_len, head_size]
|
- Итоговое “внимание” — это взвешенная сумма других токенов (контекста) для каждого токена.
|
||||||
- Выход: [batch_size, seq_len, head_size]
|
- Можно использовать кэш для генерации длинных последовательностей по одному токену (ускоряет инференс).
|
||||||
3. Конкатенация результатов:
|
|
||||||
- Объединенный выход: [batch_size, seq_len, num_heads * head_size]
|
|
||||||
4. Линейная проекция:
|
|
||||||
- Выход: [batch_size, seq_len, emb_size]
|
|
||||||
5. Применение dropout
|
|
||||||
|
|
||||||
Args:
|
Аргументы:
|
||||||
x (Tensor[float]): [batch, seq_len, emb_size] — вход
|
----------
|
||||||
mask (Optional[Tensor[bool]]): маска позиции [seq_len, seq_len]
|
x : torch.Tensor
|
||||||
use_cache (bool): использовать ли key-value кэш (для генерации)
|
Входной тензор формы [batch, seq_len, emb_size].
|
||||||
cache (list): предыдущие значения KV для ускорения
|
Это ваши входные эмбеддинги (обычно после token + positional embedding).
|
||||||
|
mask : torch.Tensor, опционально
|
||||||
|
Матрица формы [seq_len, seq_len], задающая “разрешения” — кто может смотреть на кого (например, causal mask).
|
||||||
|
Если не указана — используется внутренняя маска (например, для autoregressive генерации).
|
||||||
|
use_cache : bool, по умолчанию True
|
||||||
|
Нужно ли использовать кэш для KV attention (важно для ускорения генерации по одному токену).
|
||||||
|
cache : list, опционально
|
||||||
|
Предыдущий кэш Key/Value — для генерации текста по частям.
|
||||||
|
|
||||||
Returns:
|
Возвращает:
|
||||||
out (Tensor[float]): [batch, seq_len, emb_size] — результат MHA
|
-----------
|
||||||
kv_caches (list): списки новых KV-кэшей (если используется)
|
- output: torch.Tensor формы [batch, seq_len, emb_size] — результат применения multi-head attention.
|
||||||
|
- kv_caches: список новых KV для кэширования при генерации (или None).
|
||||||
|
|
||||||
Типичный паттерн:
|
Важно:
|
||||||
Вход: [batch, seq, emb] → N голов [batch, seq, head_size] →
|
-------
|
||||||
→ concat [batch, seq, N*head_size] → проекция → dropout
|
- Shape входа всегда [batch, seq_len, emb_size], выход тот же.
|
||||||
|
- При seq_len > max_seq_len выбросит ошибку (безопасно для контроля переполнения буферов).
|
||||||
Пример преобразований для emb_size=512, num_heads=8:
|
- При использовании use_cache=True кешируется только последние токены (актуально для LLM).
|
||||||
Вход: [4, 100, 512]
|
|
||||||
-> Каждая голова: [4, 100, 64]
|
|
||||||
-> После внимания: 8 x [4, 100, 64]
|
|
||||||
-> Конкатенация: [4, 100, 512]
|
|
||||||
-> Проекция: [4, 100, 512]
|
|
||||||
-> Dropout: [4, 100, 512]
|
|
||||||
|
|
||||||
Пример:
|
Пример:
|
||||||
>>> out, caches = mha(x)
|
-------
|
||||||
>>> out.shape # [batch, seq_len, emb_size]
|
>>> attn = MultiHeadAttention(num_heads=8, emb_size=256, head_size=32, max_seq_len=1024)
|
||||||
|
>>> x = torch.randn(2, 100, 256)
|
||||||
|
>>> y, kv_cache = attn(x)
|
||||||
|
>>> print(y.shape) # torch.Size([2, 100, 256])
|
||||||
"""
|
"""
|
||||||
# 1. Вычисляем attention для каждой головы
|
batch_size, seq_len, emb_size = x.shape
|
||||||
attention_results = []
|
|
||||||
for i, head in enumerate(self._heads):
|
|
||||||
head_cache = cache[i] if cache is not None else None
|
|
||||||
result = head(x, use_cache=use_cache, cache=head_cache)
|
|
||||||
attention_results.append(result)
|
|
||||||
|
|
||||||
outputs, caches = zip(*attention_results)
|
if seq_len > self._max_seq_len:
|
||||||
attention_outputs = list(outputs)
|
raise ValueError(
|
||||||
kv_caches = list(caches)
|
f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}"
|
||||||
|
)
|
||||||
|
|
||||||
# 2. Объединяем результаты всех голов
|
# Пропустите тензор x через матрицы Wq, Wk , Wv, чтобы получить матрицы запроса, ключа и значения.
|
||||||
concatenated_attention = torch.cat(attention_outputs, dim=-1)
|
k = self._k(x) # [B, T, hs]
|
||||||
|
q = self._q(x) # [B, T, hs]
|
||||||
|
v = self._v(x) # [B, T, hs]
|
||||||
|
|
||||||
|
# Шаг 2: Изменение формы для multi-head
|
||||||
|
# [batch_size, seq_len, num_heads * head_size]
|
||||||
|
# -> [batch_size, seq_len, num_heads, head_size]
|
||||||
|
q = q.reshape(batch_size, seq_len, self._num_heads, self._head_size)
|
||||||
|
k = k.reshape(batch_size, seq_len, self._num_heads, self._head_size)
|
||||||
|
v = v.reshape(batch_size, seq_len, self._num_heads, self._head_size)
|
||||||
|
|
||||||
|
|
||||||
|
# 3. Transpose: [B, T, H, hs] -> [B, H, T, hs]
|
||||||
|
q = q.transpose(1, 2)
|
||||||
|
k = k.transpose(1, 2)
|
||||||
|
v = v.transpose(1, 2)
|
||||||
|
|
||||||
|
start_pos = 0
|
||||||
|
if cache is not None:
|
||||||
|
k_cache, v_cache = cache
|
||||||
|
cache_len = k_cache.shape[2]
|
||||||
|
start_pos = cache_len
|
||||||
|
|
||||||
|
# Пропустите матрицы запроса и ключа через экземпляр rope, чтобы выполнить поворот.
|
||||||
|
if self._rope is not None:
|
||||||
|
# ✅ Применяем RoPE к Q и K (НЕ к V!)
|
||||||
|
q = self._rope(q, start_pos=start_pos) # [B, T, hs]
|
||||||
|
k = self._rope(k, start_pos=start_pos) # [B, T, hs]
|
||||||
|
|
||||||
|
# Если cache пришел, то объединяем кэш и одну строку из ключа и значения. Это будут новые key и value для последующих вычислений.
|
||||||
|
# 5. Кэширование (для autoregressive generation)
|
||||||
|
if cache is not None:
|
||||||
|
k_cache, v_cache = cache
|
||||||
|
k = torch.cat([k_cache, k], dim=2) # Concat по seq_len (dim=2)
|
||||||
|
v = torch.cat([v_cache, v], dim=2)
|
||||||
|
|
||||||
|
# Перемножим матрицы запроса и ключа (транспонированную), чтобы вычислить матрицу внимания.
|
||||||
|
# И разделить все значения в матрице внимания на корень из head_size.
|
||||||
|
scores = q @ k.transpose(-2, -1) / (self._head_size ** 0.5)
|
||||||
|
|
||||||
|
# Если cache пришел, то маску не накладываем. Иначе наложите на матрицу внимания треугольную маску, созданную при инициализации. Все скрытые значения должны быть приведены к минус бесконечности: float('-inf').
|
||||||
|
if cache is None:
|
||||||
|
scores = scores.masked_fill(
|
||||||
|
~self._tril_mask[:seq_len, :seq_len], float("-inf")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Применить к матрице внимания (построчно) функцию Softmax.
|
||||||
|
weights = F.softmax(scores, dim=-1)
|
||||||
|
|
||||||
|
# Перемножим матрицу внимания и матрицу значения.
|
||||||
|
x_out = weights @ v # [B, T, hs]
|
||||||
|
|
||||||
|
# Измените форму тензора на batch_size × seq_len × num_heads*head_size.
|
||||||
|
# Transpose обратно и concatenate heads
|
||||||
|
x_out = x_out.transpose(1, 2) # [B, T_q, H, hs]
|
||||||
|
x_out = x_out.contiguous() # Важно для reshape!
|
||||||
|
concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size)
|
||||||
|
|
||||||
|
#concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size)
|
||||||
|
|
||||||
|
# Пропустите получившийся тензор через последний линейный слой.
|
||||||
# 3. Проецируем в пространство эмбеддингов
|
# 3. Проецируем в пространство эмбеддингов
|
||||||
projected_output = self._layer(concatenated_attention)
|
projected_output = self._layer(concatenated_attention)
|
||||||
|
|
||||||
@@ -143,6 +258,6 @@ class MultiHeadAttention(nn.Module):
|
|||||||
final_output = self._dropout(projected_output)
|
final_output = self._dropout(projected_output)
|
||||||
|
|
||||||
if use_cache is True:
|
if use_cache is True:
|
||||||
return (final_output, kv_caches)
|
return (final_output, (k, v))
|
||||||
else:
|
else:
|
||||||
return (final_output, None)
|
return (final_output, None)
|
||||||
|
|||||||
@@ -4,35 +4,67 @@ from torch import nn, Tensor
|
|||||||
|
|
||||||
class PositionalEmbeddings(nn.Module):
|
class PositionalEmbeddings(nn.Module):
|
||||||
"""
|
"""
|
||||||
Обучаемые позиционные эмбеддинги (learnable positional embeddings).
|
PositionalEmbeddings — классические позиционные эмбеддинги для трансформеров (absolute sinusoidal or learned).
|
||||||
|
|
||||||
Позиционные эмбеддинги используются в нейросетях для передачи информации
|
Назначение:
|
||||||
о позиции элементов в последовательности (например, в Transformer).
|
-----------
|
||||||
|
- Добавляет или конкатенирует форму позиционной информации к каждому входному токену (since Transformer cannot distinguish positions otherwise).
|
||||||
|
- Используется во всех \"ранних\" трансформерах (GPT, BERT, T5), чаще всего в виде learnable или синусоидальных embeddings.
|
||||||
|
|
||||||
Научная суть:
|
Архитектурные варианты:
|
||||||
- Трансформеры не используют рекуррентность, а значит сами по себе не различают порядок слов.
|
-----------------------
|
||||||
- Позиционные эмбеддинги добавляются к токеновым, чтобы сеть понимала, в каком месте последовательности находится каждый токен.
|
- Learnable positional embeddings (как в GPT-2): обычный nn.Embedding инициализируется случайно, и веса учатся вместе с моделью.
|
||||||
- Обычно реализуются как отдельная матрица (nn.Embedding), которая обучается вместе с моделью (это learnable вариант, как в GPT и BERT).
|
- Sinusoidal positional encoding (как в оригинальном Transformer): не имеет параметров, а создаётся по заданной формуле sin/cos(ω*x).
|
||||||
|
|
||||||
Args:
|
Принцип работы:
|
||||||
max_seq_len (int): максимальная длина последовательности
|
---------------
|
||||||
emb_size (int): размер вектора позиции
|
- Для каждой позиции t заполняется вектор emb_size длиной по формуле (или выбирается из weight matrix).
|
||||||
|
- Эти вектора можно либо складывать с токеновыми эмбеддингами, либо конкатенировать.
|
||||||
|
- Позволяет attention-механизму \"понимать\" порядок токенов/слов в последовательности.
|
||||||
|
|
||||||
Пример использования:
|
Формулы (Or: Vaswani et al., 2017):
|
||||||
>>> pos_encoder = PositionalEmbeddings(max_seq_len=100, emb_size=256)
|
------------------------------------
|
||||||
>>> # Получить эмбеддинги для последовательности из 10 элементов
|
PE(pos, 2i) = sin(pos / 10000^{2i/d})
|
||||||
>>> embeddings = pos_encoder(10) # Tensor shape: [10, 256]
|
PE(pos, 2i+1) = cos(pos / 10000^{2i/d})
|
||||||
>>> # Использование в модели
|
где d = emb_size, pos = позиция (int), i = индекс пары компонент.
|
||||||
>>> class MyModel(nn.Module):
|
|
||||||
... def __init__(self):
|
Аргументы конструктора:
|
||||||
... super().__init__()
|
-----------------------
|
||||||
... self.pos_emb = PositionalEmbeddings(100, 256)
|
max_seq_len: int — максимально поддерживаемая длина последовательности
|
||||||
... def forward(self, x):
|
emb_size: int — размер возвращаемого positional vector для каждой позиции
|
||||||
... pos = self.pos_emb(x.size(1))
|
(иногда выбирается вариант — learnable или фиксация через sin/cos)
|
||||||
... return x + pos # Добавляем позиционную информацию
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> pos = PositionalEmbeddings(max_seq_len=1024, emb_size=256)
|
||||||
|
>>> p = pos(32) # Получить positional embeddings для 32 позиций
|
||||||
|
>>> p.shape # torch.Size([32, 256])
|
||||||
|
>>> token_emb = ... # [batch, seq_len, emb_size]
|
||||||
|
>>> encoded = token_emb + p.unsqueeze(0) # Broadcast add
|
||||||
|
|
||||||
|
References:
|
||||||
|
-----------
|
||||||
|
- Vaswani et al., \"Attention is All You Need\", 2017: https://arxiv.org/abs/1706.03762
|
||||||
|
- GPT-2 implementation: https://github.com/openai/gpt-2
|
||||||
|
- Почему positional encoding важен: https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, max_seq_len: int, emb_size: int):
|
def __init__(self, max_seq_len: int, emb_size: int):
|
||||||
|
"""
|
||||||
|
Инициализация позиционного энкодера.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
max_seq_len : int
|
||||||
|
Максимальная длина последовательности (builds buffer for sin/cos or embedding)
|
||||||
|
emb_size : int
|
||||||
|
Длина позиционного вектора
|
||||||
|
|
||||||
|
Внутри:
|
||||||
|
-------
|
||||||
|
- Если используется learned embedding: создаётся nn.Embedding (можно легко менять в будущем).
|
||||||
|
- Если fixed (sin/cos): вычисляется и хранится буфер (max_seq_len, emb_size).
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.max_seq_len = max_seq_len
|
self.max_seq_len = max_seq_len
|
||||||
self.emb_size = emb_size
|
self.emb_size = emb_size
|
||||||
@@ -42,20 +74,23 @@ class PositionalEmbeddings(nn.Module):
|
|||||||
|
|
||||||
def forward(self, seq_len: int, start_pos: int = 0) -> Tensor:
|
def forward(self, seq_len: int, start_pos: int = 0) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Возвращает позиционные эмбеддинги для заданной длины последовательности.
|
Получить positional embeddings для последовательности длиной seq_len.
|
||||||
|
|
||||||
Args:
|
Аргументы:
|
||||||
seq_len (int): Длина последовательности (1 <= seq_len <= max_seq_len)
|
----------
|
||||||
|
seq_len : int
|
||||||
|
Сколько позиций сгенерировать (обычно == входная длина x)
|
||||||
|
start_pos : int, по умолчанию 0
|
||||||
|
Возможность выдать positional embeddings \"с середины\" (для autoregressive генерации)
|
||||||
|
|
||||||
Returns:
|
Возвращает:
|
||||||
Tensor: Тензор позиционных эмбеддингов формы [seq_len, emb_size]
|
-----------
|
||||||
|
torch.Tensor — positional embeddings формы [seq_len, emb_size]
|
||||||
Raises:
|
|
||||||
IndexError: Если seq_len выходит за допустимые границы
|
|
||||||
|
|
||||||
Пример:
|
Пример:
|
||||||
>>> pos_encoder = PositionalEmbeddings(100, 64)
|
-------
|
||||||
>>> emb = pos_encoder(10) # Тензор 10x64
|
>>> pos = PositionalEmbeddings(512, 128)
|
||||||
|
>>> p = pos(10) # [10, 128]
|
||||||
"""
|
"""
|
||||||
if seq_len < 1 or seq_len > self.max_seq_len:
|
if seq_len < 1 or seq_len > self.max_seq_len:
|
||||||
raise IndexError(f"Длина {seq_len} должна быть от 1 до {self.max_seq_len}")
|
raise IndexError(f"Длина {seq_len} должна быть от 1 до {self.max_seq_len}")
|
||||||
|
|||||||
@@ -24,35 +24,63 @@ from typing import Optional
|
|||||||
|
|
||||||
class RMSNorm(nn.Module):
|
class RMSNorm(nn.Module):
|
||||||
"""
|
"""
|
||||||
RMS Normalization (Root Mean Square Layer Normalization).
|
RMSNorm (Root Mean Square Layer Normalization) — простая и эффективная альтернатива LayerNorm.
|
||||||
|
|
||||||
Нормализует входные данные по последнему измерению используя среднеквадратичное
|
Назначение:
|
||||||
значение вместо среднего, как в стандартном LayerNorm.
|
-----------
|
||||||
|
- Нормализует входной тензор по последнему измерению только с помощью RMS (root mean square), без вычитания среднего.
|
||||||
|
- Используется в LLaMA, PaLM и других крупных языковых моделях для лучшей стабильности и ускорения обучения.
|
||||||
|
- В отличие от LayerNorm, не центрирует значения, что особенно полезно для автогерессивных трансформеров с residual-связями.
|
||||||
|
|
||||||
Научная суть:
|
Мотивация и математика:
|
||||||
- Упрощенный вариант LayerNorm без вычисления среднего, только деление на rms.
|
-----------------------
|
||||||
- Лучшая численная стабильность на больших моделях, меньше вычислений.
|
- Формула для одного слоя и вектора x:
|
||||||
- Применяется в LLaMA, PaLM и др.
|
rms = sqrt( mean( x ** 2 ) + eps )
|
||||||
|
out = w * ( x / rms )
|
||||||
|
где w — learnable scale, eps — небольшая константа для численной устойчивости.
|
||||||
|
- Нет смещения/вычитания среднего — сигнал сохраняет абсолютные значения, меньше “искажает” автоагрегатные значения на накопленных резидуалах.
|
||||||
|
|
||||||
Формула:
|
Аргументы конструктора:
|
||||||
RMSNorm(x) = (x / sqrt(mean(x²) + eps)) * w (w — обучаемый вектор)
|
-----------------------
|
||||||
|
dim : int
|
||||||
|
Размер последнего нормализуемого измерения (обычно совпадает с размером embedding/final head).
|
||||||
|
eps : float, default=1e-6
|
||||||
|
Малое значение для устойчивости (additive epsilon).
|
||||||
|
|
||||||
Args:
|
Особенности:
|
||||||
dim (int): размер последнего измерения (обычно emb_size)
|
------------
|
||||||
eps (float): для численной устойчивости
|
- Нет батч-нормализации, нет зависимости от размера батча.
|
||||||
|
- Отлично подходит для больших моделей и автогерессии — меньше шуму от residual.
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> norm = RMSNorm(emb_size=256)
|
||||||
|
>>> x = torch.randn(4, 10, 256)
|
||||||
|
>>> out = norm(x) # возвращает tensor той же формы
|
||||||
|
|
||||||
|
References:
|
||||||
|
-----------
|
||||||
|
- Zhang & Sennrich, "Root Mean Square Layer Normalization", 2019: https://arxiv.org/abs/1910.07467
|
||||||
|
- Применение в LLaMA: https://arxiv.org/abs/2302.13971
|
||||||
|
- HuggingFace implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||||
|
|
||||||
Пример:
|
|
||||||
>>> norm = RMSNorm(emb_size)
|
|
||||||
>>> out = norm(x)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dim: int, eps: float = 1e-6):
|
def __init__(self, dim: int, eps: float = 1e-6):
|
||||||
"""
|
"""
|
||||||
Инициализация RMSNorm слоя.
|
Инициализация RMSNorm.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dim: Размерность нормализуемого измерения
|
-----
|
||||||
eps: Малое значение для численной стабильности (по умолчанию 1e-6)
|
dim : int
|
||||||
|
Последнее нормализуемое измерение (обычно размерность embedding или hidden).
|
||||||
|
eps : float
|
||||||
|
Малое значение для устойчивости (по умолчанию 1e-6).
|
||||||
|
|
||||||
|
Внутри:
|
||||||
|
-------
|
||||||
|
- Создаётся обучаемый scale weight w для каждой компоненты dim.
|
||||||
|
- Сохраняется параметр eps для добавления к RMS.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._eps = eps
|
self._eps = eps
|
||||||
@@ -60,16 +88,28 @@ class RMSNorm(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Прямой проход через RMSNorm слой.
|
Прямой проход через RMSNorm.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: Входной тензор формы [..., dim]
|
-----
|
||||||
|
x : torch.Tensor
|
||||||
|
Входной тензор любого shape с последней размерностью dim.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Нормализованный тензор той же формы, что и входной
|
--------
|
||||||
|
torch.Tensor — тот же shape, что и вход x, но нормализованный по RMS на последнем измерении.
|
||||||
|
|
||||||
|
Алгоритм:
|
||||||
|
---------
|
||||||
|
- Вычислить rms = sqrt( mean( x**2, dim=-1, keepdim=True ) + eps )
|
||||||
|
- Поделить x на rms
|
||||||
|
- Помасштабировать обучаемым весом w
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> norm = RMSNorm(256)
|
||||||
|
>>> out = norm(torch.randn(2, 10, 256))
|
||||||
|
|
||||||
Формула:
|
|
||||||
output = w * (x / sqrt(mean(x²) + eps))
|
|
||||||
"""
|
"""
|
||||||
# Вычисление RMS (Root Mean Square) по последнему измерению
|
# Вычисление RMS (Root Mean Square) по последнему измерению
|
||||||
rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5
|
rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5
|
||||||
|
|||||||
@@ -1,21 +1,51 @@
|
|||||||
"""
|
"""
|
||||||
Rotary Positional Embeddings (RoPE) - ротационные позиционные эмбеддинги.
|
Rotary Positional Embeddings (RoPE)
|
||||||
|
===================================
|
||||||
|
|
||||||
Реализация ротационного позиционного кодирования, которое кодирует позиционную
|
Что такое RoPE?
|
||||||
информацию через вращение векторов запросов и ключей в комплексном пространстве.
|
----------------
|
||||||
|
RoPE — это способ "вписать" информацию о позиции токенов в скрытые вектора модели трансформера.
|
||||||
|
Вместо простого сложения с абсолютным positional embedding, RoPE использует вращения векторов (как поворот стрелки на круге) внутри каждого attention head. Каждый элемент пары (вектор четного и нечетного индекса) поворачивается на угол, зависящий от позиции токена.
|
||||||
|
|
||||||
Научная статья: "RoFormer: Enhanced Transformer with Rotary Position Embedding"
|
Зачем это?
|
||||||
https://arxiv.org/abs/2104.09864
|
-----------
|
||||||
|
- RoPE реализует **относительное позиционное кодирование**: модель может сравнивать "расстояния" между токенами, а не просто помнить положение.
|
||||||
|
- Такое кодирование **улучшает генерацию длинных последовательностей** и перенос модели на тексты большей длины, чем были в обучении.
|
||||||
|
- Форма векторов и длина (норма) НЕ искажаются.
|
||||||
|
|
||||||
Математическая основа:
|
Как это работает? (главная формула)
|
||||||
Для позиции m и измерения i:
|
-------------------------------------
|
||||||
θ_i = base^(-2i/d)
|
Для каждой позиции m и пары компонент (2i, 2i+1) внутри head применяются:
|
||||||
q'_m = q_m * cos(mθ_i) + rotate(q_m) * sin(mθ_i)
|
|
||||||
|
θ_i = base^(-2i / d)
|
||||||
|
q'_{m,2i} = q_{m,2i} * cos(m * θ_i) - q_{m,2i+1} * sin(m * θ_i)
|
||||||
|
q'_{m,2i+1} = q_{m,2i+1} * cos(m * θ_i) + q_{m,2i} * sin(m * θ_i)
|
||||||
|
|
||||||
|
где d — размерность "головы" attention (head_size), base обычно 10_000.
|
||||||
|
|
||||||
|
То есть, берём каждый "вектор" (в рамках head), делим на четные/нечетные части и поворачиваем их на уникальный угол, связанный с позицией/частотой.
|
||||||
|
|
||||||
|
Архитектурные детали:
|
||||||
|
---------------------
|
||||||
|
- Ваш тензор должен быть строго 4-мерным: [batch, num_heads, seq_len, head_size].
|
||||||
|
- Размер head_size должен быть чётным!
|
||||||
|
- RoPE применяется отдельно к **Q** и **K** в механизме внимания (но не к V).
|
||||||
|
|
||||||
|
Где об этом читать:
|
||||||
|
-------------------
|
||||||
|
- RoFormer: Enhanced Transformer with Rotary Position Embedding
|
||||||
|
https://arxiv.org/abs/2104.09864
|
||||||
|
- Llama: Open and Efficient Foundation Language Models
|
||||||
|
https://arxiv.org/abs/2302.13971
|
||||||
|
- Визуализация позиционных кодировок:
|
||||||
|
https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> rope = RoPE(head_size=64, max_seq_len=2048)
|
||||||
|
>>> x = torch.randn(2, 8, 128, 64) # [batch, num_heads, seq_len, head_size]
|
||||||
|
>>> x_enc = rope(x) # здесь вектор x обогатится позиционной информацией
|
||||||
|
|
||||||
Преимущества:
|
|
||||||
- Относительное позиционное кодирование
|
|
||||||
- Лучшая экстраполяция на длинные последовательности
|
|
||||||
- Сохранение нормы векторов
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -25,32 +55,72 @@ from typing import Optional
|
|||||||
|
|
||||||
class RoPE(nn.Module):
|
class RoPE(nn.Module):
|
||||||
"""
|
"""
|
||||||
Rotary Positional Embeddings (RoPE) для механизма внимания.
|
Реализация RoPE (Rotary Positional Embeddings) для self-attention в трансформерах.
|
||||||
|
|
||||||
Кодирует позиционную информацию через вращение векторов запросов и ключей
|
Этот слой добавляет позиционную информацию к векторам внимания (Q, K) —
|
||||||
в многомерном пространстве с использованием синусов и косинусов.
|
не с помощью простого сложения с positional embedding, а с помощью математического
|
||||||
|
вращения (как если бы вы крутили стрелку на круге) для каждой пары компонент
|
||||||
|
(even/odd) в каждом attention head.
|
||||||
|
|
||||||
Args:
|
Формула (для каждого токена и каждой пары компонент внутри head):
|
||||||
head_size: Размерность головы внимания (должен быть четным)
|
θ_i = base^(-2i / d)
|
||||||
max_seq_len: Максимальная длина последовательности
|
out_{m,2i} = x_{m,2i} * cos(m * θ_i) - x_{m,2i+1} * sin(m * θ_i)
|
||||||
base: Базовое значение для вычисления частот (по умолчанию 10000)
|
out_{m,2i+1} = x_{m,2i+1} * cos(m * θ_i) + x_{m,2i} * sin(m * θ_i)
|
||||||
|
где d — head_size, base обычно 10_000, степень i по head axis.
|
||||||
|
|
||||||
Attributes:
|
Какие входы принимает:
|
||||||
cos_matrix: Буферизованная матрица косинусов формы [max_seq_len, head_size//2]
|
----------------------
|
||||||
sin_matrix: Буферизованная матрица синусов формы [max_seq_len, head_size//2]
|
- x: обязательно размерности [batch, num_heads, seq_len, head_size]!
|
||||||
|
- head_size (размер внимания) должен быть чётным.
|
||||||
|
- start_pos: опционально, позволяет сдвигать позиционный offset для генерации с кэшем.
|
||||||
|
|
||||||
|
Что возвращает:
|
||||||
|
---------------
|
||||||
|
- Тот же тензор (x), только со встроенной позиционной информацией (“повёрнутый” RoPE-кодировкой).
|
||||||
|
- Форма и тип выходного тензора не меняются.
|
||||||
|
|
||||||
|
Где используется:
|
||||||
|
-----------------
|
||||||
|
- В любых современных LLM (Llama, Mistral, GPT-NeoX и др.) для повышения устойчивости и generalization transformer's attention.
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> rope = RoPE(head_size=64, max_seq_len=2048)
|
||||||
|
>>> x = torch.randn(2, 8, 128, 64) # (batch, num_heads, seq_len, head_size)
|
||||||
|
>>> x_encoded = rope(x)
|
||||||
|
|
||||||
|
Подробнее про математику и примеры с визуализацией:
|
||||||
|
---------------------------------------------------
|
||||||
|
- RoFormer: https://arxiv.org/abs/2104.09864
|
||||||
|
- Llama: https://arxiv.org/abs/2302.13971
|
||||||
|
- Демонстрация наглядно: https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):
|
def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):
|
||||||
"""
|
"""
|
||||||
Инициализация RoPE эмбеддингов.
|
Инициализация объекта RoPE — настраивает и предвычисляет все необходимые
|
||||||
|
параметры для ротационного позиционного кодирования.
|
||||||
|
|
||||||
Args:
|
Аргументы:
|
||||||
head_size: Размерность головы внимания (должен быть четным)
|
----------
|
||||||
max_seq_len: Максимальная поддерживаемая длина последовательности
|
head_size : int
|
||||||
base: Базовое значение для вычисления частот (типично 10000)
|
Размер одного attention head (последнего измерения вектора) — сколько компонент
|
||||||
|
(float-значений) отвечает за одну "голову". Должен быть ЧЁТНЫМ числом, иначе RoPE не применим.
|
||||||
|
Обычно head_size = embed_dim // num_heads.
|
||||||
|
max_seq_len : int
|
||||||
|
Максимальная длина последовательности, которую RoPE сможет обработать.
|
||||||
|
Если ваш текст длиннее этого числа — будет ошибка! Например, для GPT2 обычно 1024, у LLaMA — до 4096.
|
||||||
|
Это число определяет размер внутренних буферов cos/sin.
|
||||||
|
base : int, по умолчанию 10_000
|
||||||
|
База для вычисления частот вращения (θ_i) для каждой компоненты.
|
||||||
|
В оригинальных статьях почти всегда используют base=10000.
|
||||||
|
Менять этот параметр не нужно, если вы не исследуете математические детали.
|
||||||
|
|
||||||
Raises:
|
Что происходит внутри:
|
||||||
AssertionError: Если head_size не четный
|
----------------------
|
||||||
|
- Проверяется чётность head_size.
|
||||||
|
- Для каждого возможного положения в пределах max_seq_len и каждой пары component высчитываются уникальные cos/sin значения (матрицы частот).
|
||||||
|
- Эти матрицы используются далее для быстрого наложения позиционного "вращения" токенов внутри attention.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert head_size % 2 == 0, "head_size должен быть четным"
|
assert head_size % 2 == 0, "head_size должен быть четным"
|
||||||
@@ -70,28 +140,51 @@ class RoPE(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
|
def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Применение ротационного позиционного кодирования к входному тензору.
|
Применяет ротационное позиционное кодирование (RoPE) к входному тензору.
|
||||||
|
|
||||||
Args:
|
Что делает эта функция:
|
||||||
x: Входной тензор формы [batch_size, seq_len, head_size]
|
-----------------------
|
||||||
|
Для каждого токена в последовательности внутри каждого attention head
|
||||||
|
"поворачивает" его вектор в подпространстве (even/odd пар) на свой уникальный угол,
|
||||||
|
зависящий от позиции токена. Это позволяет attention "понимать расстояния" между токенами.
|
||||||
|
|
||||||
Returns:
|
Аргументы:
|
||||||
Тензор с примененным RoPE формы [batch_size, seq_len, head_size]
|
----------
|
||||||
|
x : torch.Tensor
|
||||||
|
Входной тензор строго формы [batch, num_heads, seq_len, head_size].
|
||||||
|
Это обычно либо Q, либо K из механизма внимания.
|
||||||
|
start_pos : int, по умолчанию 0
|
||||||
|
Сдвиг начала позиции (нужно при генерации с кэшем, почти всегда оставить 0 если не пишете автогенератор).
|
||||||
|
|
||||||
Алгоритм:
|
Возвращает:
|
||||||
1. Разделение векторов на четные и нечетные компоненты
|
-----------
|
||||||
2. Применение вращения через синусы и косинусы
|
torch.Tensor с теми же формой и типом, что и x, но уже с наложенным позиционным кодированием.
|
||||||
3. Объединение компонент обратно
|
|
||||||
|
Важно:
|
||||||
|
-------
|
||||||
|
- Если передан тензор не 4D, будет выброшено исключение!
|
||||||
|
- Не изменяет значения "на месте", всегда возвращает новый тензор.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> rope = RoPE(head_size=64, max_seq_len=1024)
|
||||||
|
>>> q = torch.randn(2, 8, 32, 64) # batch, num_heads, seq_len, head_size
|
||||||
|
>>> q_rope = rope(q)
|
||||||
"""
|
"""
|
||||||
batch_size, seq_len, emb_size = x.shape
|
assert x.ndim == 4, "RoPE поддерживает только 4D-вход [batch, num_heads, seq_len, head_size]"
|
||||||
|
batch_size, num_heads, seq_len, head_size = x.shape
|
||||||
|
|
||||||
# Берем нужную часть матриц и приводим к типу x
|
# Берем нужную часть матриц и приводим к типу x
|
||||||
cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
||||||
sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
||||||
|
|
||||||
# Разделяем на четные и нечетные компоненты
|
# Явное изменение формы для broadcasting
|
||||||
x_even = x[:, :, 0::2] # [batch_size, seq_len, head_size//2]
|
cos = cos.reshape(1, 1, seq_len, head_size // 2)
|
||||||
x_odd = x[:, :, 1::2] # [batch_size, seq_len, head_size//2]
|
sin = sin.reshape(1, 1, seq_len, head_size // 2)
|
||||||
|
|
||||||
|
# Разделяем на четные и нечетные компоненты по ПОСЛЕДНЕМУ измерению
|
||||||
|
x_even = x[..., 0::2] # [batch_size, num_heads, seq_len, head_size//2]
|
||||||
|
x_odd = x[..., 1::2] # [batch_size, num_heads, seq_len, head_size//2]
|
||||||
|
|
||||||
# Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ)
|
# Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ)
|
||||||
x_rotated_even = x_even * cos - x_odd * sin
|
x_rotated_even = x_even * cos - x_odd * sin
|
||||||
|
|||||||
@@ -4,18 +4,67 @@ from torch import nn
|
|||||||
|
|
||||||
class SiLU(nn.Module):
|
class SiLU(nn.Module):
|
||||||
"""
|
"""
|
||||||
SiLU (Swish) — современная активационная функция для нейросетей.
|
SiLU (Sigmoid Linear Unit, также известная как Swish) — современная функция активации для нейросетей и LLM.
|
||||||
|
|
||||||
Научная суть:
|
Назначение:
|
||||||
- Формула: $SiLU(x) = x * \sigm(x)$, где $\sigm(x)$ — сигмоида.
|
-----------
|
||||||
- Более гладкая альтернатива ReLU, улучшает поток градиентов в глубоких сетях.
|
- Формирует плавную нелинейную активацию: SiLU(x) = x * sigmoid(x).
|
||||||
- Используется во многих «state-of-the-art» архитектурах (SwiGLU, PaLM, LLaMA).
|
- Активно используется во всех новых архитектурах для больших языковых моделей (PaLM, LLaMA, Mistral, GPT-4 и др.).
|
||||||
- Также известна как Swish (Ramachandran et al, 2017).
|
- Дает лучший поток градиентов по сравнению с ReLU, SELU, GELU в глубоких слоях — позволяет делать сети больше и глубже.
|
||||||
Пример:
|
|
||||||
>>> act = SiLU()
|
Мотивация и свойства:
|
||||||
>>> x = torch.tensor([-1.0, 0.0, 1.0])
|
---------------------
|
||||||
>>> print(act(x))
|
- SiLU объединяет свойства identity (для больших x) и ReLU (для отрицательных x, где есть затухание), но более плавно.
|
||||||
|
- Позволяет проходить отрицательным значениям, а не "обрубает" как ReLU.
|
||||||
|
- Better for optimization and training dynamics in deep LLMs, приводит к более богатым аппроксимациям.
|
||||||
|
|
||||||
|
Математическая формула:
|
||||||
|
-----------------------
|
||||||
|
SiLU(x) = x * sigmoid(x)
|
||||||
|
где sigmoid(x) = 1 / (1 + exp(-x))
|
||||||
|
|
||||||
|
Сравнение с другими активациями:
|
||||||
|
--------------------------------
|
||||||
|
- ReLU(x): max(0, x) — простая отсечка
|
||||||
|
- GELU(x): плавная вероятностная активация (используется в BERT/GPT-2)
|
||||||
|
- SiLU(x): плавная альтернатива, часто лучше в современных LLM
|
||||||
|
- Swish (Ramachandran et al., 2017) = SiLU
|
||||||
|
|
||||||
|
Args:
|
||||||
|
-----
|
||||||
|
Нет learnable параметров, чисто функциональная активация.
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> silu = SiLU()
|
||||||
|
>>> x = torch.tensor([-2.0, 0.0, 2.0])
|
||||||
|
>>> print(silu(x)) # тензор с элементами [-0.2384, 0.0, 1.7616] (примерно)
|
||||||
|
|
||||||
|
References:
|
||||||
|
-----------
|
||||||
|
- Ramachandran et al., "Searching for Activation Functions", 2017: https://arxiv.org/abs/1710.05941
|
||||||
|
- LLaMA: https://arxiv.org/abs/2302.13971
|
||||||
|
- Swish в TensorFlow: https://arxiv.org/abs/1710.05941
|
||||||
|
- Сравнение всех актив. функций: https://paperswithcode.com/method/silu
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
def forward(self, x: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Применяет SiLU активацию ко всем компонентам тензора (x * sigmoid(x)).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
-----
|
||||||
|
x : torch.Tensor
|
||||||
|
Входной тензор любой формы.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
torch.Tensor — тензор той же формы, каждый элемент преобразован по формуле SiLU(x).
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> silu = SiLU()
|
||||||
|
>>> x = torch.linspace(-3, 3, 7)
|
||||||
|
>>> y = silu(x)
|
||||||
|
"""
|
||||||
return torch.sigmoid(x) * x
|
return torch.sigmoid(x) * x
|
||||||
|
|||||||
@@ -24,31 +24,55 @@ from .silu import SiLU
|
|||||||
|
|
||||||
class SwiGLU(nn.Module):
|
class SwiGLU(nn.Module):
|
||||||
"""
|
"""
|
||||||
SwiGLU (Swish-Gated Linear Unit) — современная нелинейность для архитектур LLM (LLaMA, PaLM).
|
SwiGLU (Swish-Gated Linear Unit) — эффективная feed-forward нелинейность для трансформеров (LLAMA, PaLM, Mistral).
|
||||||
|
|
||||||
Реализация SwiGLU активационной функции.
|
Назначение:
|
||||||
|
-----------
|
||||||
|
- Улучшает классический блок FeedForward (FFN) в трансформерах за счёт \"gating\" (механизма управления информационным потоком).
|
||||||
|
- Использует нелинейность SiLU (Swish) вместо ReLU или GELU, повышая capacity блока.
|
||||||
|
- Является дефолтом во всех современных LLM, начиная с PaLM, LLaMA и Mistral.
|
||||||
|
|
||||||
Состоит из трех линейных слоев и активации SiLU:
|
Формула и математика:
|
||||||
1. Gate слой + SiLU активация
|
---------------------
|
||||||
2. Up слой (линейное преобразование)
|
Пусть x — вход, then:
|
||||||
3. Element-wise multiplication gate и up
|
|
||||||
4. Down слой (линейная проекция)
|
|
||||||
|
|
||||||
Научная суть:
|
SwiGLU(x) = (SiLU(W_g x + b_g)) ⊙ (W_u x + b_u) W_d + b_d
|
||||||
- Сохраняет преимущества GLU (раздельные гейтом и телом) + мощность Swish/SiLU активации.
|
|
||||||
- Дает надежную гладкую активацию, хорошо работает на больших масштабах.
|
|
||||||
- Статья: "GLU Variants Improve Transformer" (Shazeer, 2020).
|
|
||||||
|
|
||||||
Формула:
|
Типовая реализация (как здесь, по LLAMA/Mistral):
|
||||||
SwiGLU(x) = SiLU(W_g·x) * (W_u·x)
|
gate = SiLU(Linear_gate(x)) # фитчерный \"gate\"
|
||||||
где SiLU(x) = x*sigma(x)
|
up = Linear_up(x) # пропускная ветка
|
||||||
|
mult = gate * up # поэлементное умножение (контроль информации)
|
||||||
|
out = Linear_down(mult) # финальная проекция
|
||||||
|
out = Dropout(out) # регуляризация
|
||||||
|
|
||||||
|
Почему это работает:
|
||||||
|
-------------------
|
||||||
|
- Gating позволяет информации проходить \"частично\", динамически подавляя/усиливая сигналы в hidden-space.
|
||||||
|
- SiLU обеспечивает smooth градиенты (лучше для обучения LLM).
|
||||||
|
- В экспериментах (PaLM, LLAMA) SwiGLU consistently outperforms ReLU, GELU, обычные GLU.
|
||||||
|
|
||||||
|
Параметры конструктора:
|
||||||
|
-----------------------
|
||||||
|
emb_size: int
|
||||||
|
Размерность входного (и выходного) признакового пространства.
|
||||||
|
dropout: float
|
||||||
|
Dropout после final linear (обычно около 0.1).
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> block = SwiGLU(emb_size=512, dropout=0.1)
|
||||||
|
>>> x = torch.randn(8, 16, 512)
|
||||||
|
>>> y = block(x)
|
||||||
|
>>> print(y.shape) # torch.Size([8, 16, 512])
|
||||||
|
|
||||||
|
References:
|
||||||
|
-----------
|
||||||
|
- Shazeer, \"GLU Variants Improve Transformer\", 2020: https://arxiv.org/abs/2002.05202
|
||||||
|
- PaLM: https://arxiv.org/abs/2204.02311 (Section 4.1)
|
||||||
|
- LLaMA: https://arxiv.org/abs/2302.13971
|
||||||
|
- Mistral: https://arxiv.org/abs/2310.06825
|
||||||
|
- HuggingFace discussion: https://huggingface.co/docs/transformers/main/en/model_doc/llama
|
||||||
|
|
||||||
Args:
|
|
||||||
emb_size (int): размер входов/выходов
|
|
||||||
dropout (float): после выходной проекции
|
|
||||||
Пример:
|
|
||||||
>>> ff = SwiGLU(emb_size=512, dropout=0.1)
|
|
||||||
>>> y = ff(torch.randn(2,10,512))
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, emb_size: int, dropout: float = 0.1):
|
def __init__(self, emb_size: int, dropout: float = 0.1):
|
||||||
@@ -68,19 +92,24 @@ class SwiGLU(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Прямой проход через SwiGLU слой.
|
Прямой проход через блок SwiGLU.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: Входной тензор формы [batch_size, seq_len, emb_size]
|
-----
|
||||||
|
x : torch.Tensor
|
||||||
|
Входной тензор формы [batch_size, seq_len, emb_size]
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Выходной тензор формы [batch_size, seq_len, emb_size]
|
--------
|
||||||
|
torch.Tensor той же формы
|
||||||
|
|
||||||
Алгоритм:
|
Алгоритм:
|
||||||
|
---------
|
||||||
1. gate = SiLU(linear_gate(x))
|
1. gate = SiLU(linear_gate(x))
|
||||||
2. up = linear_up(x)
|
2. up = linear_up(x)
|
||||||
3. output = linear_down(gate ⊙ up)
|
3. mult = gate * up # поэлементно
|
||||||
4. apply dropout
|
4. out = linear_down(mult)
|
||||||
|
5. out = dropout(out)
|
||||||
"""
|
"""
|
||||||
# Gate ветвь: линейное преобразование + активация
|
# Gate ветвь: линейное преобразование + активация
|
||||||
gate_out = self._gate(x) # [batch, seq, 4*emb]
|
gate_out = self._gate(x) # [batch, seq, 4*emb]
|
||||||
|
|||||||
@@ -5,63 +5,93 @@ from torch import Tensor
|
|||||||
|
|
||||||
class TokenEmbeddings(nn.Module):
|
class TokenEmbeddings(nn.Module):
|
||||||
"""
|
"""
|
||||||
Токеновые эмбеддинги — обучаемые векторные представления для каждого токена словаря.
|
TokenEmbeddings — обучаемый слой эмбеддингов для токенов (слов, сабслов, байтов и т.д.) в трансформерах.
|
||||||
|
|
||||||
Преобразует целочисленные индексы токенов в обучаемые векторные представления фиксированного размера.
|
Назначение:
|
||||||
Обычно используется как первый слой в нейронных сетях для задач NLP.
|
-----------
|
||||||
|
- Преобразует каждый целочисленный индекс-токен из словаря (vocab) в обучаемый dense-вектор фиксированной длины.
|
||||||
|
- Это "входной слой" для любой нейросетевой языковой модели: позволяет работать с текстом как с матрицей чисел, а не с индексами/категориальными значениями.
|
||||||
|
- Обеспечивает возможность end-to-end обучения embedding-матрицы совместно с целью модели.
|
||||||
|
|
||||||
Научная суть:
|
Мотивация и особенности:
|
||||||
- Первый шаг для любого NLP-модуля: вместо индекса токена подаём его dense-вектор.
|
------------------------
|
||||||
- Эти вектора изучаются в процессе обучения и отражают скрытые взаимосвязи между токенами.
|
- Каждый токен (индекс) получает свой learnable embedding (float-вектор).
|
||||||
- Позволяют обрабатывать тексты как матрицу чисел, а не как символы или индексы.
|
- Размерность слоя: [vocab_size, emb_size] (матрица эмбеддингов).
|
||||||
- Аналог словарных эмбеддингов в word2vec, но обучаются энд-ту-энд с моделью.
|
- Веса эмбеддингов инициализируются случайно и обучаются вместе с остальной моделью.
|
||||||
|
- Аналог таблицы эмбеддингов в word2vec/fastText, но управляется end-to-end.
|
||||||
|
- Могут использоваться с любым токенизатором (BPE, SentencePiece, WordPiece и др.).
|
||||||
|
|
||||||
|
Формула:
|
||||||
|
--------
|
||||||
|
emb(x) = W[x], где W — матрица размера [vocab_size, emb_dim], x — индексы shape [batch, seq_len]
|
||||||
|
На выходе: тензор [batch, seq_len, emb_dim]
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
vocab_size (int): размер словаря (количество уникальных токенов)
|
-----
|
||||||
emb_size (int): размерность эмбеддинга (длина вектора)
|
vocab_size: int — размер словаря/алфавита (количество уникальных токенов)
|
||||||
|
emb_size: int — размерность (длина) эмбеддинговых векторов (обычно 256/512/1024...)
|
||||||
Примечание:
|
|
||||||
- Индексы должны быть в диапазоне [0, vocab_size-1]
|
|
||||||
- Эмбеддинги инициализируются случайно и обучаются в процессе тренировки модели
|
|
||||||
|
|
||||||
Пример:
|
Пример:
|
||||||
>>> emb = TokenEmbeddings(vocab_size=10000, emb_size=256)
|
-------
|
||||||
>>> tokens = torch.tensor([[1, 2, 3]])
|
>>> embedding = TokenEmbeddings(vocab_size=5000, emb_size=256)
|
||||||
>>> vecs = emb(tokens)
|
>>> tokens = torch.tensor([[12, 47, 301], [6, 88, 413]])
|
||||||
>>> vecs.shape # torch.Size([1, 3, 256])
|
>>> vecs = embedding(tokens)
|
||||||
|
>>> print(vecs.shape) # torch.Size([2, 3, 256])
|
||||||
|
|
||||||
|
References:
|
||||||
|
-----------
|
||||||
|
- Mikolov et al., "Efficient Estimation of Word Representations in Vector Space (word2vec)", 2013
|
||||||
|
- Vaswani et al., "Attention is All You Need", 2017: https://arxiv.org/abs/1706.03762
|
||||||
|
- BPE, SentencePiece overviews: https://huggingface.co/docs/transformers/tokenizer_summary
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, vocab_size: int, emb_size: int):
|
def __init__(self, vocab_size: int, emb_size: int):
|
||||||
|
"""
|
||||||
|
Инициализация слоя эмбеддингов.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
-----
|
||||||
|
vocab_size: int
|
||||||
|
Размер словаря (уникальных токенов/индексов).
|
||||||
|
emb_size: int
|
||||||
|
Длина эмбеддингового вектора для каждого токена.
|
||||||
|
|
||||||
|
Внутри:
|
||||||
|
-------
|
||||||
|
- Создаёт nn.Embedding с [vocab_size, emb_size] learnable весами.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._embedding = nn.Embedding(
|
self._embedding = nn.Embedding(
|
||||||
num_embeddings=vocab_size, embedding_dim=emb_size
|
num_embeddings=vocab_size, embedding_dim=emb_size
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
Получить эмбеддинги для входных токенов.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
-----
|
||||||
|
x : torch.Tensor
|
||||||
|
Тензор shape [...], содержащий индексы токенов (каждое значение от 0 до vocab_size-1).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
--------
|
||||||
|
torch.Tensor — тензор обычной формы [..., emb_size] (на каждую позицию — свой embedding-вектор).
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> embedding = TokenEmbeddings(vocab_size=100, emb_size=64)
|
||||||
|
>>> tokens = torch.tensor([[0, 99, 5]])
|
||||||
|
>>> vecs = embedding(tokens) # [1, 3, 64]
|
||||||
|
"""
|
||||||
return self._embedding(x)
|
return self._embedding(x)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_embeddings(self) -> int:
|
def num_embeddings(self) -> int:
|
||||||
"""Возвращает размер словаря"""
|
"""Возвращает размер словаря (количество уникальных токенов)."""
|
||||||
return self._embedding.num_embeddings
|
return self._embedding.num_embeddings
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def embedding_dim(self) -> int:
|
def embedding_dim(self) -> int:
|
||||||
"""Возвращает размерность эмбеддингов"""
|
"""Возвращает размерность эмбеддингов (длина вектора каждого токена)."""
|
||||||
return self._embedding.embedding_dim
|
return self._embedding.embedding_dim
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Пример использования
|
|
||||||
embedding = TokenEmbeddings(vocab_size=100, emb_size=128)
|
|
||||||
|
|
||||||
# Создаем тензор с индексами в пределах vocab_size (0-99)
|
|
||||||
tensor = torch.tensor([[11, 45, 76, 34], [34, 67, 45, 54]])
|
|
||||||
|
|
||||||
# Проверяем индексы
|
|
||||||
if (tensor >= 100).any():
|
|
||||||
raise ValueError("Some indices are out of vocabulary range (vocab_size=100)")
|
|
||||||
|
|
||||||
output = embedding(tensor)
|
|
||||||
print("Embeddings shape:", output.shape)
|
|
||||||
print(f"{output.shape} | {output.mean().item():.11f}") # Формат как в ТЗ
|
|
||||||
|
|||||||
0
llm/src/llm/datasets/__init__.py
Normal file
0
llm/src/llm/datasets/__init__.py
Normal file
120
llm/src/llm/datasets/streaming_text_dataset.py
Normal file
120
llm/src/llm/datasets/streaming_text_dataset.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from typing import List, Any
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingTextDataset(Dataset):
|
||||||
|
"""
|
||||||
|
StreamingTextDataset — потоковый датасет для LLM/NLP на базе списка строк.
|
||||||
|
|
||||||
|
Назначение:
|
||||||
|
-----------
|
||||||
|
- Позволяет эффективно обрабатывать большие текстовые выборки, итерируя по заранее подготовленному списку строк.
|
||||||
|
- При итерации строки токенизируются на лету, превращаются в примеры фиксированной длины block_size (padding/truncation внутри класса).
|
||||||
|
- Поддерживает стандартный DataLoader PyTorch.
|
||||||
|
|
||||||
|
Ключевые особенности:
|
||||||
|
---------------------
|
||||||
|
- Не требует загрузки всей коллекции токенов в RAM: поддерживает работу с любым размером датасета, если список строк заранее подготовлен.
|
||||||
|
- Каждый пример (sample) формируется при обращении; не хранит массив батчей, не использует файлы внутри.
|
||||||
|
- Поддерживает любой токенизатор с методом encode (например, BPE, SentencePiece, HF Tokenizer).
|
||||||
|
- batch_size и параллелизм (num_workers) контролируются через DataLoader.
|
||||||
|
|
||||||
|
Аргументы конструктора:
|
||||||
|
-----------------------
|
||||||
|
texts: List[str] — список строк (предварительно загруженных обучающих примеров).
|
||||||
|
tokenizer: BaseTokenizer/Any — объект с методом encode(str, **kwargs) -> List[int].
|
||||||
|
block_size: int — длина одного выходного примера в токенах (padding/truncation если нужно).
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> texts = open("wiki_sample.txt", encoding="utf-8").read().splitlines()
|
||||||
|
>>> ds = StreamingTextDataset(texts, tokenizer=tokenizer, block_size=512)
|
||||||
|
>>> loader = torch.utils.data.DataLoader(ds, batch_size=8)
|
||||||
|
>>> for batch in loader:
|
||||||
|
... print(batch['input_ids'].shape) # torch.Size([8, 512])
|
||||||
|
|
||||||
|
Особенности:
|
||||||
|
------------
|
||||||
|
- Проектирован для бесконечного стриминга текстовых данных из больших коллекций.
|
||||||
|
- При batch_size > 1 каждый batch формируется DataLoader-ом из yield'ов этого датасета.
|
||||||
|
- Не работает с файлами напрямую, только со строками (списком).
|
||||||
|
- Подходит для обучения LLM, тестирования, дообучения, оценки на больших потоковых данных.
|
||||||
|
|
||||||
|
References:
|
||||||
|
-----------
|
||||||
|
- PyTorch IterableDataset: https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset
|
||||||
|
- HuggingFace streaming datasets: https://huggingface.co/docs/datasets/stream
|
||||||
|
- Практика масштабного обучения LLM: https://github.com/karpathy/nanoGPT/issues/182
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128):
|
||||||
|
"""
|
||||||
|
Инициализация StreamingTextDataset из списка строк.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
texts (List[str]): Список строк — текстовые обучающие примеры; весь датасет должен помещаться в этот список.
|
||||||
|
tokenizer (Any): Токенизатор с методом encode(text, **kwargs) -> List[int].
|
||||||
|
block_size (int, по умолчанию 128): Желаемая длина токенизированного примера (padding/truncation внутри класса).
|
||||||
|
|
||||||
|
Особенности:
|
||||||
|
- Поддерживает итеративную загрузку, эффективен для больших текстовых выборок.
|
||||||
|
- Каждый пример автоматически дополняется или усекается до block_size.
|
||||||
|
- Не читает данные из файла/буфера, а только из заранее подготовленного списка строк.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
>>> ds = StreamingTextDataset(texts=all_lines, tokenizer=tokenizer, block_size=256)
|
||||||
|
>>> for ex in ds:
|
||||||
|
... print(ex['input_ids'].shape) # torch.Size([256])
|
||||||
|
"""
|
||||||
|
self.texts = texts
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.block_size = block_size
|
||||||
|
|
||||||
|
# Получаем pad_token_id из токенизатора
|
||||||
|
self.pad_token_id = getattr(tokenizer, "pad_token_id", 0)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""
|
||||||
|
Возвращает количество доступных примеров в датасете.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Число примеров (равно длине исходного списка строк).
|
||||||
|
"""
|
||||||
|
return len(self.texts)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
"""
|
||||||
|
Получить обработанный пример по индексу из потокового датасета.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
idx (int): Индекс примера в исходном списке строк.
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
dict: Словарь с тензорами для обучения LLM:
|
||||||
|
- 'input_ids': torch.Tensor формы [block_size] — индексы токенов (padding/truncation выполнены)
|
||||||
|
- 'labels': torch.Tensor формы [block_size] — целевые метки (обычно совпадают с input_ids)
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
>>> item = dataset[10]
|
||||||
|
>>> assert isinstance(item, dict)
|
||||||
|
>>> assert item['input_ids'].shape == (block_size,)
|
||||||
|
>>> assert 'labels' in item
|
||||||
|
"""
|
||||||
|
text = self.texts[idx]
|
||||||
|
|
||||||
|
# Токенизация на лету
|
||||||
|
input_ids = self.tokenizer.encode(text, add_special_tokens=False)
|
||||||
|
|
||||||
|
# Обрезаем или дополняем до нужной длины
|
||||||
|
if len(input_ids) > self.block_size:
|
||||||
|
input_ids = input_ids[: self.block_size]
|
||||||
|
else:
|
||||||
|
input_ids = input_ids + [self.pad_token_id] * (
|
||||||
|
self.block_size - len(input_ids)
|
||||||
|
)
|
||||||
|
|
||||||
|
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
||||||
|
labels = input_ids.clone()
|
||||||
|
|
||||||
|
return {"input_ids": input_ids, "labels": labels}
|
||||||
112
llm/src/llm/datasets/text_dataset.py
Normal file
112
llm/src/llm/datasets/text_dataset.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from typing import List, Any
|
||||||
|
|
||||||
|
|
||||||
|
class TextDataset(Dataset):
|
||||||
|
"""
|
||||||
|
TextDataset — простой датасет для подачи обучающих токенов LLM (batch-режим или по одному примеру).
|
||||||
|
|
||||||
|
Назначение:
|
||||||
|
-----------
|
||||||
|
- Хранит последовательности текста (каждую строку или пример) в виде списка строк.
|
||||||
|
- При обращении сам токенизирует строку в последовательность индексов с помощью заданного токенизатора.
|
||||||
|
- Каждый пример автоматически усекётся или будет дополнен до фиксированной длины block_size (padding — zeros).
|
||||||
|
|
||||||
|
Формат и аргументы конструктора:
|
||||||
|
-------------------------------
|
||||||
|
texts: List[str]
|
||||||
|
Список строк, каждая из которых рассматривается как отдельный обучающий пример.
|
||||||
|
tokenizer: любой объект с методом encode(str, **kwargs) → List[int]
|
||||||
|
Обеспечивает сопоставление строки списку токенов (например, BPE, HuggingFace, SentencePiece и др.).
|
||||||
|
block_size: int, по умолчанию 128
|
||||||
|
Желаемая длина выходной последовательности (padding/truncation внутри класса).
|
||||||
|
|
||||||
|
Особенности:
|
||||||
|
------------
|
||||||
|
- Класс не работает с файлами напрямую: данные передаются готовым списком строк.
|
||||||
|
- При недостаточной длине пример дополняется паддингом (нулём или другим токеном, зависит от реализации).
|
||||||
|
- Может возвращать dict с input_ids, labels и прочими ключами (см. реализацию в функции __getitem__).
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> with open("dataset.txt", encoding="utf-8") as f:
|
||||||
|
... texts = f.read().splitlines()
|
||||||
|
>>> dataset = TextDataset(texts, tokenizer, block_size=256)
|
||||||
|
>>> from torch.utils.data import DataLoader
|
||||||
|
>>> loader = DataLoader(dataset, batch_size=4)
|
||||||
|
>>> for item in loader:
|
||||||
|
... # item['input_ids'] для обучения LLM
|
||||||
|
|
||||||
|
References:
|
||||||
|
-----------
|
||||||
|
- Torch Dataset: https://pytorch.org/docs/stable/data.html
|
||||||
|
- Примеры LLM датасетов в open-source: https://github.com/karpathy/nanoGPT/blob/master/data/shakespeare_char/tokenize.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128):
|
||||||
|
"""
|
||||||
|
Инициализация датасета из списка строк.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
texts (List[str]): Список строк — каждый элемент отдельный обучающий пример.
|
||||||
|
tokenizer (Any): Токенизатор с методом encode(str, **kwargs) -> List[int].
|
||||||
|
block_size (int, по умолчанию 128): Желаемая длина результата —
|
||||||
|
длинные последовательности будут усечены, короткие — дополнены паддингом (pad_token_id или 0).
|
||||||
|
|
||||||
|
Особенности:
|
||||||
|
- Строки не фильтруются и не изменяются внутри датасета.
|
||||||
|
- Для PAD используется pad_token_id из токенизатора (если есть) либо 0.
|
||||||
|
- Dict, возвращаемый __getitem__, содержит 'input_ids' и 'labels'.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
>>> dataset = TextDataset(["hello world", "test string"], tokenizer, block_size=16)
|
||||||
|
"""
|
||||||
|
self.examples = []
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.block_size = block_size
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
# Кодируем текст в токены
|
||||||
|
input_ids = tokenizer.encode(text, add_special_tokens=False)
|
||||||
|
|
||||||
|
# Обрезаем или дополняем до нужной длины
|
||||||
|
if len(input_ids) > block_size:
|
||||||
|
input_ids = input_ids[:block_size]
|
||||||
|
else:
|
||||||
|
# Дополняем pad_token_id
|
||||||
|
pad_token_id = getattr(tokenizer, "pad_token_id", 0)
|
||||||
|
input_ids = input_ids + [pad_token_id] * (block_size - len(input_ids))
|
||||||
|
|
||||||
|
self.examples.append(input_ids)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""
|
||||||
|
Возвращает количество примеров в датасете (длина списка текстов).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Число примеров в датасете.
|
||||||
|
"""
|
||||||
|
return len(self.examples)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
"""
|
||||||
|
Получить пример из датасета по индексу.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
idx (int): Индекс примера.
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
dict: Словарь с тензорами токенов для модели:
|
||||||
|
- 'input_ids': torch.Tensor shape [block_size], индексы токенов для входа.
|
||||||
|
- 'labels': torch.Tensor shape [block_size], метки для LM задачи (обычно совпадают с input_ids).
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
>>> item = dataset[7]
|
||||||
|
>>> assert isinstance(item, dict)
|
||||||
|
>>> assert item['input_ids'].shape == (block_size,)
|
||||||
|
>>> assert 'labels' in item
|
||||||
|
"""
|
||||||
|
input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
|
||||||
|
labels = input_ids.clone()
|
||||||
|
return {"input_ids": input_ids, "labels": labels}
|
||||||
124
llm/src/llm/datasets/text_with_special_tokens_dataset.py
Normal file
124
llm/src/llm/datasets/text_with_special_tokens_dataset.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from typing import List, Any
|
||||||
|
from llm.datasets.text_dataset import TextDataset
|
||||||
|
|
||||||
|
|
||||||
|
class TextWithSpecialTokensDataset(TextDataset):
|
||||||
|
"""
|
||||||
|
TextWithSpecialTokensDataset — датасет для языковых моделей с поддержкой специальных токенов (BOS, EOS, PAD).
|
||||||
|
|
||||||
|
Назначение:
|
||||||
|
-----------
|
||||||
|
- Работает с уже готовым списком строк (не с файлом!).
|
||||||
|
- Токенизирует строки с помощью заданного токенизатора, вручную вставляет специальные токены (BOS/ EOS/ PAD).
|
||||||
|
- Обрезает или дополняет каждую последовательность до длины block_size.
|
||||||
|
|
||||||
|
Аргументы конструктора:
|
||||||
|
-----------------------
|
||||||
|
texts (List[str]): Список обучающих строк (примеров).
|
||||||
|
tokenizer (Any): Любой токенизатор с методом encode(text, **kwargs).
|
||||||
|
block_size (int, default=128): Желаемая длина примера (padding/truncation).
|
||||||
|
add_bos (bool, default=False): Если True, добавляет BOS-токен в начало каждой последовательности.
|
||||||
|
add_eos (bool, default=False): Если True, добавляет EOS-токен в конец.
|
||||||
|
|
||||||
|
Особенности:
|
||||||
|
------------
|
||||||
|
- Если pad_token_id не задан — по умолчанию паддит нулями.
|
||||||
|
- Все returned примеры — dict с 'input_ids' и 'labels' (shape == block_size).
|
||||||
|
- Обрезание/дополнение учётное: BOS/EOS не "выдавливаются" обрезкой.
|
||||||
|
- Пример вызова:
|
||||||
|
>>> texts = ["пример текста", "ещё текст"]
|
||||||
|
>>> ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=16, add_bos=True, add_eos=True)
|
||||||
|
>>> out = ds[0]
|
||||||
|
>>> assert out['input_ids'].shape == (16,)
|
||||||
|
|
||||||
|
References:
|
||||||
|
-----------
|
||||||
|
- OpenAI GPT-2 data loader: https://github.com/openai/gpt-2/blob/master/src/encode.py
|
||||||
|
- HuggingFace data docs: https://huggingface.co/docs/transformers/pad_truncation
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
texts: List[str],
|
||||||
|
tokenizer: Any,
|
||||||
|
block_size: int = 128,
|
||||||
|
add_bos: bool = False,
|
||||||
|
add_eos: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Инициализация датасета с поддержкой специальных токенов.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts (List[str]): Список строк (все ваши обучающие примеры).
|
||||||
|
tokenizer (Any): Токенизатор с методом encode(text, **kwargs).
|
||||||
|
block_size (int): Длина выходного примера.
|
||||||
|
add_bos (bool): Добавлять ли BOS токен в начало.
|
||||||
|
add_eos (bool): Добавлять ли EOS токен в конец.
|
||||||
|
"""
|
||||||
|
self.examples = []
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.block_size = block_size
|
||||||
|
self.add_bos = add_bos
|
||||||
|
self.add_eos = add_eos
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
# Кодируем с специальными токенами
|
||||||
|
input_ids = tokenizer.encode(
|
||||||
|
text, add_special_tokens=True, add_bos_token=add_bos, add_eos_token=add_eos
|
||||||
|
)
|
||||||
|
|
||||||
|
# Учитываем специальные токены при обрезке/дополнении
|
||||||
|
effective_block_size = block_size
|
||||||
|
if add_bos:
|
||||||
|
effective_block_size -= 1
|
||||||
|
if add_eos:
|
||||||
|
effective_block_size -= 1
|
||||||
|
|
||||||
|
if len(input_ids) > effective_block_size:
|
||||||
|
input_ids = input_ids[:effective_block_size]
|
||||||
|
|
||||||
|
# Добавляем специальные токены если нужно
|
||||||
|
if (
|
||||||
|
add_bos
|
||||||
|
and hasattr(tokenizer, "bos_token_id")
|
||||||
|
and tokenizer.bos_token_id is not None
|
||||||
|
):
|
||||||
|
input_ids = [tokenizer.bos_token_id] + input_ids
|
||||||
|
if (
|
||||||
|
add_eos
|
||||||
|
and hasattr(tokenizer, "eos_token_id")
|
||||||
|
and tokenizer.eos_token_id is not None
|
||||||
|
):
|
||||||
|
input_ids = input_ids + [tokenizer.eos_token_id]
|
||||||
|
|
||||||
|
# Дополняем до полной длины
|
||||||
|
pad_token_id = getattr(tokenizer, "pad_token_id", 0)
|
||||||
|
if len(input_ids) < block_size:
|
||||||
|
input_ids = input_ids + [pad_token_id] * (block_size - len(input_ids))
|
||||||
|
|
||||||
|
self.examples.append(input_ids)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""
|
||||||
|
Возвращает количество примеров в датасете.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Размер (len(self.examples)).
|
||||||
|
"""
|
||||||
|
return len(self.examples)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
"""
|
||||||
|
Получить пример с учётом специальных токенов и паддинга.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
idx (int): Индекс в dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: {'input_ids': torch.Tensor [block_size], 'labels': torch.Tensor [block_size]}
|
||||||
|
"""
|
||||||
|
input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
|
||||||
|
labels = input_ids.clone()
|
||||||
|
return {"input_ids": input_ids, "labels": labels}
|
||||||
@@ -33,29 +33,75 @@ from llm.core.positional_embeddings import PositionalEmbeddings
|
|||||||
|
|
||||||
class GPT(BaseModel):
|
class GPT(BaseModel):
|
||||||
"""
|
"""
|
||||||
Original GPT (Generative Pre-trained Transformer) модель.
|
GPT (Generative Pretrained Transformer) — автогерессивная языковая модель по мотивам оригинального GPT/GPT-2 architecture.
|
||||||
|
|
||||||
Первая версия трансформерной архитектуры от OpenAI, предназначенная
|
Назначение:
|
||||||
для генеративного предобучения на текстовых данных.
|
-----------
|
||||||
|
- Позволяет предсказывать и генерировать последовательности текста, обучаясь на задаче language modeling (предсказывать следующий токен).
|
||||||
|
- Класс реализует архитектуру classic Transformer Decoder Stack с masked multi-head attention и token/positional embeddings.
|
||||||
|
- Используется как базовая модель для генерации, zero-/few-shot, задач обучения с подкреплением и пр.
|
||||||
|
|
||||||
Args:
|
Архитектурные особенности:
|
||||||
config: Словарь конфигурации с параметрами:
|
--------------------------
|
||||||
- vocab_size: Размер словаря токенов
|
- Embedding-слои для токенов (token_embeddings) и позиций (position_embeddings).
|
||||||
- embed_dim: Размерность векторных представлений
|
- Stack из N декодер-блоков (MultiHeadAttention + FeedForward + residual + LayerNorm).
|
||||||
- num_heads: Количество голов внимания
|
- Masked self-attention — каждый токен видит только свои и предыдущие, обеспечивая автогерессию.
|
||||||
- num_layers: Количество декодерных слоев
|
- LayerNorm до проекции на словарь (pre-LN).
|
||||||
- max_position_embeddings: Максимальная длина последовательности
|
- Поддержка efficient KV кэша — ускоряет autoregressive inference/generation.
|
||||||
- dropout: Вероятность dropout
|
|
||||||
|
|
||||||
Attributes:
|
Основные параметры:
|
||||||
_token_embeddings: Слой векторных представлений токенов
|
-------------------
|
||||||
_position_embeddings: Слой позиционных эмбеддингов
|
config: dict в формате {
|
||||||
_decoders: Список декодерных слоев
|
vocab_size, # размер словаря токенов
|
||||||
_norm: Финальный слой нормализации
|
embed_dim, # размерность эмбеддинга
|
||||||
_linear: Выходной линейный слой
|
num_heads, # количество attention heads
|
||||||
|
num_layers, # глубина модели (число блоков)
|
||||||
|
max_position_embeddings,
|
||||||
|
dropout
|
||||||
|
}
|
||||||
|
|
||||||
|
Формула и поток данных:
|
||||||
|
-----------------------
|
||||||
|
x -> token_embeddings -> + position_embeddings -> dropout ->
|
||||||
|
-> stack([DecoderBlock]) ->
|
||||||
|
-> LayerNorm ->
|
||||||
|
-> Linear(out_dim=vocab_size) -> output_logits
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> gpt = GPT({...})
|
||||||
|
>>> tokens = torch.tensor([[12, 123, 44]])
|
||||||
|
>>> logits = gpt(tokens)
|
||||||
|
>>> generated = gpt.generate(tokens, max_new_tokens=10)
|
||||||
|
|
||||||
|
References:
|
||||||
|
-----------
|
||||||
|
- Radford et al., "Improving Language Understanding by Generative Pre-Training" (GPT-1, 2018)
|
||||||
|
https://cdn.openai.com/research-covers/languageunsupervised/language_understanding_paper.pdf
|
||||||
|
- Original BPE Tokenizer code: https://github.com/openai/gpt-2/blob/master/src/encoder.py
|
||||||
|
- Формула masked self-attention: Vaswani et al., "Attention is All You Need", 2017
|
||||||
|
https://arxiv.org/abs/1706.03762
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
|
"""
|
||||||
|
Инициализация модели GPT.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
-----
|
||||||
|
config: dict
|
||||||
|
Параметры архитектуры:
|
||||||
|
vocab_size: int — размер словаря токенов
|
||||||
|
embed_dim: int — размерность эмбеддинга
|
||||||
|
num_heads: int — количество attention-heads
|
||||||
|
num_layers: int — число Transformer блоков
|
||||||
|
max_position_embeddings: int — макс. длина последовательности
|
||||||
|
dropout: float — dropout
|
||||||
|
|
||||||
|
Внутри:
|
||||||
|
-------
|
||||||
|
- Создаёт слой эмбеддингов, позиционку, стек декодеров, нормализацию, линейную проекцию.
|
||||||
|
"""
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
# Инициализация слоев
|
# Инициализация слоев
|
||||||
@@ -88,13 +134,22 @@ class GPT(BaseModel):
|
|||||||
return self._max_seq_len
|
return self._max_seq_len
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, attention_mask=None) -> torch.Tensor:
|
def forward(self, x: torch.Tensor, attention_mask=None) -> torch.Tensor:
|
||||||
"""Прямой проход через GPT
|
"""
|
||||||
|
Прямой проход для получения логитов по последовательности токенов.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: Входной тензор [batch_size, seq_len]
|
-----
|
||||||
|
x : torch.Tensor [batch, seq_len]
|
||||||
|
Индексы входных токенов.
|
||||||
|
use_cache : bool, optional
|
||||||
|
Использовать ли кэш attention (ускоряет инференс, важно для генерации)
|
||||||
|
cache : list, optional
|
||||||
|
Список старых KV (key/value)-кэшей
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Тензор логитов [batch_size, seq_len, vocab_size]
|
--------
|
||||||
|
logits: [batch, seq_len, vocab_size] (логиты для softmax по словарю)
|
||||||
|
new_cache: кэш KV после прохода
|
||||||
"""
|
"""
|
||||||
# Проверка длины последовательности
|
# Проверка длины последовательности
|
||||||
if x.size(1) > self._max_seq_len:
|
if x.size(1) > self._max_seq_len:
|
||||||
@@ -141,97 +196,53 @@ class GPT(BaseModel):
|
|||||||
attention_mask: torch.Tensor = None, # Добавляем для совместимости с HF
|
attention_mask: torch.Tensor = None, # Добавляем для совместимости с HF
|
||||||
**kwargs, # Игнорируем остальные параметры
|
**kwargs, # Игнорируем остальные параметры
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Авторегрессивная генерация текста.
|
"""
|
||||||
|
Авторегрессивная генерация текста с поддержкой жадного поиска (greedy), вероятностного сэмплирования с температурой,
|
||||||
|
top-k и nucleus (top-p) sampling.
|
||||||
|
|
||||||
Параметры:
|
Аргументы:
|
||||||
x: Входной тензор с индексами токенов формы [batch_size, seq_len],
|
x (torch.Tensor): Входной тензор с индексами токенов, форма [batch_size, seq_len].
|
||||||
где batch_size - размер батча, seq_len - длина последовательности.
|
max_new_tokens (int): Максимальное количество новых токенов для генерации.
|
||||||
max_new_tokens: Максимальное количество новых токенов для генерации.
|
do_sample (bool): Если True — вероятностное сэмплирование; если False — жадная генерация (argmax).
|
||||||
do_sample: Флаг выбора режима генерации:
|
temperature (float): Температура для управления случайностью (>0, влияет только если do_sample=True).
|
||||||
- True: вероятностное сэмплирование
|
>1.0 — более случайно, <1.0 — более детерминированно.
|
||||||
- False: жадный поиск (argmax)
|
top_k (int, опц.): При do_sample=True ограничивает выбор top_k самых вероятных токенов (top-k sampling).
|
||||||
temperature: Параметр температуры для сэмплирования:
|
top_p (float, опц.): При do_sample=True включает top-p (nucleus) sampling: кумулятивная вероятность ≤ top_p.
|
||||||
- >1.0 - более случайные результаты
|
Должно быть в (0, 1].
|
||||||
- 1.0 - нейтральное значение
|
attention_mask (torch.Tensor, опц.): Внешняя маска внимания (для совместимости с HuggingFace).
|
||||||
- <1.0 - более предсказуемые результаты
|
**kwargs: Игнорируются.
|
||||||
Должна быть > 0 (по умолчанию: 1.0)
|
|
||||||
top_k: Если задан (и do_sample=True), используется top-k сэмплирование:
|
|
||||||
- Выбираются только top_k самых вероятных токенов
|
|
||||||
- Остальным токенам устанавливается вероятность 0
|
|
||||||
- None: отключено (по умолчанию)
|
|
||||||
top_p: Если задан (и do_sample=True), используется nucleus (top-p) сэмплирование:
|
|
||||||
- Выбираются токены с кумулятивной вероятностью ≤ top_p
|
|
||||||
- Гарантируется, что хотя бы один токен остаётся (даже если его вероятность > top_p)
|
|
||||||
- None: отключено (по умолчанию)
|
|
||||||
- Должен быть в диапазоне (0, 1]
|
|
||||||
|
|
||||||
Возвращает:
|
Возвращает:
|
||||||
torch.Tensor: Тензор с расширенной последовательностью токенов формы
|
torch.Tensor: Последовательность токенов [batch_size, seq_len + max_new_tokens].
|
||||||
[batch_size, seq_len + max_new_tokens]
|
|
||||||
|
|
||||||
Исключения:
|
Исключения:
|
||||||
ValueError: Если входная последовательность длиннее max_seq_len
|
ValueError: Если x длиннее max_seq_len модели.
|
||||||
ValueError: Если temperature <= 0
|
ValueError: Если temperature ≤ 0.
|
||||||
ValueError: Если одновременно заданы top_k и top_p
|
ValueError: Если одновременно заданы top_k и top_p.
|
||||||
ValueError: Если top_k задан и ≤ 0
|
ValueError: Если top_k ≤ 0.
|
||||||
ValueError: Если top_p задан и не в диапазоне (0, 1]
|
ValueError: Если top_p вне диапазона (0, 1].
|
||||||
|
|
||||||
Примеры:
|
Примеры:
|
||||||
>>> # Жадная генерация
|
>>> # Жадная (детерминированная) генерация
|
||||||
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=False)
|
>>> output = model.generate(input_ids, max_new_tokens=12, do_sample=False)
|
||||||
>>>
|
>>> # Вероятностная генерация с температурой
|
||||||
>>> # Вероятностная генерация с top-k
|
>>> output = model.generate(input_ids, max_new_tokens=12, do_sample=True, temperature=0.8)
|
||||||
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=True, top_k=50)
|
>>> # Top-k сэмплирование
|
||||||
>>>
|
>>> output = model.generate(input_ids, max_new_tokens=12, do_sample=True, top_k=50)
|
||||||
>>> # Nucleus sampling (top-p)
|
>>> # Top-p (nucleus) sampling
|
||||||
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=True, top_p=0.9)
|
>>> output = model.generate(input_ids, max_new_tokens=12, do_sample=True, top_p=0.92)
|
||||||
>>>
|
|
||||||
>>> # Комбинация температуры и top-k
|
>>> # Комбинация температуры и top-k
|
||||||
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=True,
|
>>> output = model.generate(input_ids, max_new_tokens=12, do_sample=True, temperature=1.0, top_k=100)
|
||||||
... temperature=0.7, top_k=50)
|
|
||||||
|
|
||||||
Примечания:
|
Примечания:
|
||||||
1. Для детерминированных результатов в режиме сэмплирования
|
- Для детерминированных выборок зафиксируйте random seed через torch.manual_seed.
|
||||||
зафиксируйте random seed (torch.manual_seed).
|
- Параметры temperature, top_k, top_p применимы только если do_sample=True.
|
||||||
2. Температура влияет только на режим сэмплирования (do_sample=True).
|
- Одновременное использование top_k и top_p не допускается.
|
||||||
3. Одновременное использование top_k и top_p запрещено.
|
- Модель всегда возвращает тензор индексов токенов; для получения логитов используйте прямой вызов forward.
|
||||||
4. При do_sample=False параметры top_k, top_p и temperature игнорируются.
|
|
||||||
|
|
||||||
Args:
|
Ссылки:
|
||||||
x (torch.Tensor): Входной тензор с индексами токенов формы [batch_size, seq_len],
|
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus sampling): https://arxiv.org/abs/1904.09751
|
||||||
где batch_size - размер батча, seq_len - длина последовательности.
|
- Оригинальный GPT-2: https://cdn.openai.com/better-language-models/language-models.pdf
|
||||||
max_new_tokens (int): Максимальное количество новых токенов для генерации.
|
|
||||||
do_sample (bool): Флаг выбора режима генерации:
|
|
||||||
- True: вероятностное сэмплирование
|
|
||||||
- False: жадный поиск (argmax)
|
|
||||||
temperature (float): Параметр температуры для сэмплирования:
|
|
||||||
- >1.0 - более случайные результаты
|
|
||||||
- 1.0 - нейтральное значение
|
|
||||||
- <1.0 - более предсказуемые результаты
|
|
||||||
Должна быть > 0 (по умолчанию: 1.0)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: Тензор с расширенной последовательностью токенов формы
|
|
||||||
[batch_size, seq_len + max_new_tokens]
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: Если входная последовательность длиннее max_seq_len
|
|
||||||
ValueError: Если temperature <= 0
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> # Жадная генерация
|
|
||||||
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=False)
|
|
||||||
>>>
|
|
||||||
>>> # Вероятностная генерация с температурой
|
|
||||||
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=True, temperature=0.7)
|
|
||||||
>>>
|
|
||||||
>>> # Более случайная генерация
|
|
||||||
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=True, temperature=1.5)
|
|
||||||
|
|
||||||
Note:
|
|
||||||
Для детерминированных результатов в режиме сэмплирования
|
|
||||||
зафиксируйте random seed (torch.manual_seed).
|
|
||||||
Температура влияет только на режим сэмплирования (do_sample=True).
|
|
||||||
"""
|
"""
|
||||||
for _ in range(max_new_tokens):
|
for _ in range(max_new_tokens):
|
||||||
# 1. Обрезаем вход, если последовательность слишком длинная
|
# 1. Обрезаем вход, если последовательность слишком длинная
|
||||||
|
|||||||
@@ -30,23 +30,69 @@ from llm.core.feed_forward import FeedForward
|
|||||||
|
|
||||||
class GPT2(BaseModel):
|
class GPT2(BaseModel):
|
||||||
"""
|
"""
|
||||||
GPT2 — автогерессивная языковая модель, архитектура Transformer, предложенная OpenAI.
|
GPT-2 — масштабируемый автогерессивный языковой трансформер второго поколения от OpenAI (2019).
|
||||||
|
|
||||||
Научная суть:
|
Назначение:
|
||||||
- Масштабируемый автогерессивный трансформер для предсказания токенов слева направо.
|
-----------
|
||||||
- Главное отличие от классической GPT: порядок layer normalization ПЕРЕД attention и FFN.
|
- Позволяет предсказывать и порождать последовательности текста по одному токену, будучи обученным на задаче language modeling.
|
||||||
- Используется GELU, efficient KV-cache, несет наследие классической GPT, но делает архитектуру глубже/шире.
|
- Модель реализует архитектуру decoder-only Transformer с Pre-LN (LayerNorm перед attention и FFN).
|
||||||
|
- Используется для генерации, обучения с подкреплением для RLHF, zero/few-shot inference, чат-ботов и др.
|
||||||
|
|
||||||
Args:
|
Архитектурные особенности:
|
||||||
config (dict): параметры архитектуры (vocab_size, embed_dim, num_heads, num_layers, max_position_embeddings, dropout)
|
--------------------------
|
||||||
|
- Token и positional embeddings (learnable, как в GPT-2 оригинале).
|
||||||
|
- Stack из N блоков Decoder (MultiHeadAttention с causal mask, Residual, Pre-LayerNorm, GELU FFN).
|
||||||
|
- KV attention-кэш (ускоряет autoregressive generation, критически важно для LLM).
|
||||||
|
- Использует GELU как функцию активации.
|
||||||
|
- Поддержка dropout на каждом этапе.
|
||||||
|
|
||||||
|
Основные параметры:
|
||||||
|
-------------------
|
||||||
|
config: dict — параметры модели:
|
||||||
|
vocab_size, # размер словаря токенов
|
||||||
|
embed_dim, # размерность эмбеддинга
|
||||||
|
num_heads, # количество attention голов
|
||||||
|
num_layers, # глубина модели (число блоков)
|
||||||
|
max_position_embeddings,
|
||||||
|
dropout
|
||||||
|
|
||||||
|
Процессинг:
|
||||||
|
-----------
|
||||||
|
x (индексы токенов) → token_embeddings + position_embeddings → dropout
|
||||||
|
→ stack Decoder blocks (masked attention, pre-LN)
|
||||||
|
→ LayerNorm
|
||||||
|
→ Linear(out_dim=vocab_size) → выходные логиты
|
||||||
|
|
||||||
Пример использования:
|
Пример использования:
|
||||||
>>> model = GPT2({"vocab_size": 50257, ...})
|
---------------------
|
||||||
>>> logits = model(input_ids)
|
>>> gpt2 = GPT2({...})
|
||||||
>>> out = model.generate(input_ids, max_length=20)
|
>>> logits = gpt2(input_ids)
|
||||||
|
>>> output = gpt2.generate(input_ids, max_new_tokens=20, do_sample=True)
|
||||||
|
|
||||||
|
References:
|
||||||
|
-----------
|
||||||
|
- Radford et al., "Language Models are Unsupervised Multitask Learners" (GPT-2, 2019): https://cdn.openai.com/better-language-models/language-models.pdf
|
||||||
|
- HuggingFace GPT-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
|
||||||
|
- Репликация в NanoGPT: https://github.com/karpathy/nanoGPT
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
|
"""
|
||||||
|
Инициализация GPT-2.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (dict): Параметры архитектуры:
|
||||||
|
vocab_size: int — размер словаря
|
||||||
|
embed_dim: int — размерность эмбеддинга
|
||||||
|
num_heads: int — количество attention-голов
|
||||||
|
num_layers: int — количество декодер-блоков
|
||||||
|
max_position_embeddings: максимальная длина последовательности
|
||||||
|
dropout: float — dropout
|
||||||
|
|
||||||
|
Внутри:
|
||||||
|
-------
|
||||||
|
- Создаёт токеновые и позиционные эмбеддинги, стек декодеров, финальный LayerNorm и линейную проекцию в словарь.
|
||||||
|
"""
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
# Инициализация слоев
|
# Инициализация слоев
|
||||||
@@ -83,18 +129,19 @@ class GPT2(BaseModel):
|
|||||||
self, x: torch.Tensor, use_cache: bool = True, cache: list = None
|
self, x: torch.Tensor, use_cache: bool = True, cache: list = None
|
||||||
) -> tuple:
|
) -> tuple:
|
||||||
"""
|
"""
|
||||||
Прямой проход GPT2:
|
Прямой проход для batch of sequences (получение логитов по токенам).
|
||||||
- Все слои работают как autoregressive transformer (masked self-attention).
|
|
||||||
- При use_cache=True возвращает также новый кэш KV attention (ускоряет генерацию).
|
|
||||||
Args:
|
Args:
|
||||||
x (Tensor): Входные индексы токенов [batch, seq_len]
|
x (torch.Tensor): Входной тензор с токенами [batch, seq_len]
|
||||||
use_cache (bool): Кэшировать KV attention для ускорения autoregressive генерации
|
use_cache (bool): Использовать/возвращать кэш KV attention (ускоряет генерацию)
|
||||||
cache (list|None): Список KV-кэшей от предыдущих шагов (или None)
|
cache (list / None): Внешний кэш KV attention (передаётся при генерации)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
logits (Tensor): [batch, seq_len, vocab_size]
|
logits: torch.Tensor [batch, seq_len, vocab_size]
|
||||||
cache (list): новый кэш если use_cache=True, иначе None
|
new_cache: новый кэш KV attention (или None)
|
||||||
|
|
||||||
Пример:
|
Пример:
|
||||||
>>> logits, cache = model.forward(x, use_cache=True)
|
>>> logits, cache = gpt2(x, use_cache=True)
|
||||||
"""
|
"""
|
||||||
# Проверка длины последовательности (только при отсутствии кэша)
|
# Проверка длины последовательности (только при отсутствии кэша)
|
||||||
if cache is None and x.size(1) > self._max_seq_len:
|
if cache is None and x.size(1) > self._max_seq_len:
|
||||||
@@ -104,12 +151,20 @@ class GPT2(BaseModel):
|
|||||||
|
|
||||||
# Вычисление start_pos из кэша (если кэш передан)
|
# Вычисление start_pos из кэша (если кэш передан)
|
||||||
if cache is not None:
|
if cache is not None:
|
||||||
# При кэше обрабатываем только один токен (последний)
|
|
||||||
seq_len = 1
|
seq_len = 1
|
||||||
# Вычисляем start_pos из самого нижнего уровня кэша
|
# Безопасно извлекаем key_cache для вычисления start_pos
|
||||||
if cache and cache[0] and cache[0][0]:
|
if (
|
||||||
key_cache, _ = cache[0][0] # Первый декодер, первая голова
|
isinstance(cache, (list, tuple))
|
||||||
start_pos = key_cache.size(1) # cache_len
|
and len(cache) > 0
|
||||||
|
and cache[0] is not None
|
||||||
|
and isinstance(cache[0], (list, tuple))
|
||||||
|
and len(cache[0]) > 0
|
||||||
|
and cache[0][0] is not None
|
||||||
|
and isinstance(cache[0][0], (tuple, list))
|
||||||
|
and len(cache[0][0]) > 0
|
||||||
|
):
|
||||||
|
key_cache, _ = cache[0][0]
|
||||||
|
start_pos = key_cache.size(1)
|
||||||
else:
|
else:
|
||||||
start_pos = 0
|
start_pos = 0
|
||||||
else:
|
else:
|
||||||
@@ -161,22 +216,56 @@ class GPT2(BaseModel):
|
|||||||
use_cache: bool = True,
|
use_cache: bool = True,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Генерация текста с использованием autoregressive трансформера (GPT2).
|
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k, top-p sampling и KV-кэша.
|
||||||
Поддерживаются greedy, sampling, top-k/top-p (nucleus sampling) режимы.
|
|
||||||
Args:
|
Аргументы:
|
||||||
x (Tensor[int]): начальная последовательность [batch, seq_len]
|
x (torch.Tensor): Входной тензор с индексами токенов [batch_size, seq_len].
|
||||||
max_new_tokens (int): сколько токенов сгенерировать
|
max_new_tokens (int): Максимальное количество новых токенов для генерации.
|
||||||
do_sample (bool): использовать стохастическое сэмплирование вместо жадного выбора
|
do_sample (bool): Режим генерации:
|
||||||
temperature (float): коэффициент сглаживания логитов (низкое — более консервативно)
|
- True: вероятностное сэмплирование (random sampling)
|
||||||
top_k (int|None): ограничить выбор top-k наиболее вероятных токенов
|
- False: жадный (greedy) поиск (выбор argmax на каждом шаге)
|
||||||
top_p (float|None): ограничить суммарную вероятность (nucleus sampling)
|
temperature (float): Температура распределения (>0, по умолчанию 1.0).
|
||||||
use_cache (bool): ускорять autoregressive инференс
|
- >1.0 — генерация более "творческая"/приподнятая вероятность "редких" токенов;
|
||||||
Returns:
|
- <1.0 — более предсказуемый и суженный выбор.
|
||||||
output (Tensor[int]): сгенерированный тензор токенов [batch, seq_len + max_new_tokens]
|
top_k (int, опционально): Если задан, sampling только из top_k самых вероятных токенов (top-k sampling).
|
||||||
Пример:
|
top_p (float, опционально): Если задан, sampling только из токенов, кумулятивная вероятность которых ≤ top_p (nucleus/top-p sampling, см. Holtzman et al., 2019).
|
||||||
>>> prompt = tokenizer.encode('Привет', return_tensors="pt")
|
use_cache (bool, по умолчанию True): Использовать кэш attention KV для ускорения авторегрессии.
|
||||||
>>> output = model.generate(prompt, max_new_tokens=20, do_sample=True)
|
|
||||||
>>> print(tokenizer.decode(output[0]))
|
Возвращает:
|
||||||
|
torch.Tensor: Тензор индексов токенов [batch_size, seq_len + max_new_tokens].
|
||||||
|
|
||||||
|
Исключения:
|
||||||
|
ValueError: Если x длиннее максимальной длины (max_seq_len).
|
||||||
|
ValueError: Если temperature ≤ 0.
|
||||||
|
ValueError: Если одновременно заданы top_k и top_p.
|
||||||
|
ValueError: Если top_k ≤ 0.
|
||||||
|
ValueError: Если top_p не в диапазоне (0, 1].
|
||||||
|
|
||||||
|
Примеры использования:
|
||||||
|
>>> # Жадная генерация
|
||||||
|
>>> output = model.generate(input_ids, max_new_tokens=20, do_sample=False)
|
||||||
|
|
||||||
|
>>> # Сэмплирование с температурой
|
||||||
|
>>> output = model.generate(input_ids, max_new_tokens=20, do_sample=True, temperature=0.8)
|
||||||
|
|
||||||
|
>>> # Top-k sampling
|
||||||
|
>>> output = model.generate(input_ids, max_new_tokens=20, do_sample=True, top_k=50)
|
||||||
|
|
||||||
|
>>> # Top-p (nucleus) sampling
|
||||||
|
>>> output = model.generate(input_ids, max_new_tokens=20, do_sample=True, top_p=0.92)
|
||||||
|
|
||||||
|
>>> # Комбинация температуры и top-k
|
||||||
|
>>> output = model.generate(input_ids, max_new_tokens=20, do_sample=True, temperature=0.7, top_k=40)
|
||||||
|
|
||||||
|
Примечания:
|
||||||
|
- Для детерминированных результатов используйте torch.manual_seed.
|
||||||
|
- temperature, top_k, top_p работают только при do_sample=True.
|
||||||
|
- Только один из top_k/top_p может быть задан одновременно.
|
||||||
|
- Метод всегда возвращает индексы токенов (ids); для получения логитов используйте forward.
|
||||||
|
|
||||||
|
Ссылки:
|
||||||
|
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus sampling): https://arxiv.org/abs/1904.09751
|
||||||
|
- Оригинальная статья GPT-2: https://cdn.openai.com/better-language-models/language-models.pdf
|
||||||
"""
|
"""
|
||||||
cache = None
|
cache = None
|
||||||
|
|
||||||
@@ -211,10 +300,9 @@ class GPT2(BaseModel):
|
|||||||
vocab_size = logits_scaled.size(-1)
|
vocab_size = logits_scaled.size(-1)
|
||||||
|
|
||||||
# создаём маску: 1, если токен НЕ в topk_indices
|
# создаём маску: 1, если токен НЕ в topk_indices
|
||||||
mask = torch.ones_like(logits_scaled, dtype=torch.uint8)
|
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||||
mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы
|
mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы
|
||||||
masked_logits[mask.byte()] = float("-inf")
|
masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf')
|
||||||
|
|
||||||
logits_scaled = masked_logits
|
logits_scaled = masked_logits
|
||||||
|
|
||||||
if do_sample == True and top_p != None:
|
if do_sample == True and top_p != None:
|
||||||
@@ -227,16 +315,16 @@ class GPT2(BaseModel):
|
|||||||
# 3. Посчитаем кумулятивную сумму вероятностей:
|
# 3. Посчитаем кумулятивную сумму вероятностей:
|
||||||
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
|
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
|
||||||
# 4. Определим маску: оставить токены, пока сумма < top_p
|
# 4. Определим маску: оставить токены, пока сумма < top_p
|
||||||
sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size]
|
sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size]
|
||||||
# Гарантируем, что хотя бы первый токен останется
|
# Гарантируем, что хотя бы первый токен останется
|
||||||
sorted_mask[:, 0] = 1
|
sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1
|
||||||
# 5. Преобразуем маску обратно в оригинальный порядок:
|
# 5. Преобразуем маску обратно в оригинальный порядок:
|
||||||
# Создаём полную маску из 0
|
# Создаём полную маску из 0
|
||||||
mask = torch.zeros_like(probs, dtype=torch.uint8)
|
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||||
# Устанавливаем 1 в местах нужных токенов
|
# Устанавливаем 1 в местах нужных токенов
|
||||||
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
||||||
# 6. Зануляем логиты токенов вне топ-p:
|
# 6. Зануляем логиты токенов вне топ-p:
|
||||||
logits_scaled[~mask] = float("-inf")
|
logits_scaled[~mask] = float('-inf')
|
||||||
|
|
||||||
# 4. Применяем Softmax
|
# 4. Применяем Softmax
|
||||||
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
|
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
|
||||||
|
|||||||
@@ -12,24 +12,62 @@ from llm.core.cached_decoder import CachedDecoder
|
|||||||
|
|
||||||
class Llama(BaseModel):
|
class Llama(BaseModel):
|
||||||
"""
|
"""
|
||||||
LLaMA (Large Language Model Meta AI) — высокоэффективная масштабируемая языковая модель, разработанная Meta AI Research.
|
LLaMA — автогерессивная большая языковая модель (Large Language Model from Meta, 2023).
|
||||||
|
|
||||||
Ключевые идеи:
|
Назначение:
|
||||||
- Rotary Positional Encoding (RoPE) вместо стандартных позиционных эмбеддингов
|
-----------
|
||||||
- RMSNorm (Root Mean Square LayerNorm) вместо LayerNorm
|
- Модель реализует архитектуру decoder-only Transformer с современными "индустриальными" трюками (RMSNorm, SwiGLU, RoPE, GQA).
|
||||||
- SwiGLU как нелинейность вместо ReLU/GELU (больше экспрессивности)
|
- Предназначена для генерации текста, чат-ботов, zero-/few-shot вывода, fine-tune в стиле RLHF, transfer learning и исследований в LLM.
|
||||||
- Глубокая оптимизация inference (большая экономия памяти и FLOPs)
|
|
||||||
Подробнее: https://arxiv.org/abs/2302.13971
|
Архитектурные особенности:
|
||||||
|
--------------------------
|
||||||
|
- Токеновые эмбеддинги и позиционное кодирование с помощью Rotary Position Embedding (RoPE, https://arxiv.org/abs/2104.09864).
|
||||||
|
- Stack из num_layers современных декодеров с Grouped Query Attention (GQA: num_q_heads > num_kv_heads) для эффективной генерации.
|
||||||
|
- FeedForward блоки с SwiGLU (см. https://arxiv.org/abs/2002.05202).
|
||||||
|
- Нормализация RMSNorm перед каждым sub-layer (вот почему "Pre-RMSNorm").
|
||||||
|
- Кэширование attention (KV cache) для быстрой autoregressive генерации.
|
||||||
|
- Нет bias в Linear слоях, нет Dropout внутри attention.
|
||||||
|
|
||||||
|
Аргументы конструктора:
|
||||||
|
-----------------------
|
||||||
|
config: dict с требуемыми ключами:
|
||||||
|
vocab_size: int — размер словаря токенов
|
||||||
|
embed_dim: int — размерность эмбеддингов
|
||||||
|
num_q_heads: int — количество query-голов в attention (обычно больше num_kv_heads)
|
||||||
|
num_kv_heads: int — количество key/value-голов
|
||||||
|
num_layers: int — число слоёв-декодеров
|
||||||
|
max_position_embeddings: int — максимальная длина последовательности
|
||||||
|
window_size: int (optional) — размер sliding window для attention
|
||||||
|
dropout: float (обычно 0.0 или очень мал)
|
||||||
|
...
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> llama = LLaMA({...})
|
||||||
|
>>> tokens = torch.tensor([[100, 56, 8]])
|
||||||
|
>>> logits = llama(tokens)
|
||||||
|
>>> out = llama.generate(tokens, max_new_tokens=10, do_sample=True, top_k=50)
|
||||||
|
|
||||||
|
References:
|
||||||
|
-----------
|
||||||
|
- "LLaMA: Open and Efficient Foundation Language Models" (Touvron et al., 2023): https://arxiv.org/abs/2302.13971
|
||||||
|
- "Grouped-Query Attention": https://arxiv.org/abs/2307.09288
|
||||||
|
- "RoFormer: Enhanced Transformer with Rotary Position Embedding": https://arxiv.org/abs/2104.09864
|
||||||
|
- Discussion of efficient LLMs: https://huggingface.co/blog/mistral
|
||||||
|
|
||||||
Args:
|
|
||||||
config (dict): параметры архитектуры (vocab_size, embed_dim, num_heads, num_layers, max_position_embeddings, dropout)
|
|
||||||
Пример:
|
|
||||||
>>> model = Llama({...})
|
|
||||||
>>> logits, cache = model(input_ids, use_cache=True)
|
|
||||||
>>> out = model.generate(input_ids, max_new_tokens=20)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
|
"""
|
||||||
|
Инициализация LLaMA.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (dict): Параметры архитектуры, см. docstring класса.
|
||||||
|
Внутри:
|
||||||
|
-------
|
||||||
|
- Создаёт Embedding-слой, Rotary Position Embeddings (RoPE), стек слоёв с GQA, RMSNorm, SwiGLU.
|
||||||
|
- Финальный слой нормализации и проекции на vocabulary.
|
||||||
|
"""
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
# Инициализация слоев
|
# Инициализация слоев
|
||||||
@@ -68,17 +106,16 @@ class Llama(BaseModel):
|
|||||||
self, x: torch.Tensor, use_cache: bool = True, cache: list = None
|
self, x: torch.Tensor, use_cache: bool = True, cache: list = None
|
||||||
) -> tuple:
|
) -> tuple:
|
||||||
"""
|
"""
|
||||||
Прямой проход через LLaMA (inference/train): авторегрессионное предсказание токенов.
|
Прямой проход: возвращает logits (и возможно обновлённый cache) по входным токенам.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (Tensor[int]): входные токены [batch, seq_len]
|
x (torch.Tensor): [batch, seq_len] — индексы токенов, shape [batch, seq_len]
|
||||||
use_cache (bool): использовать ли кэш (ускоряет генерацию)
|
use_cache (bool): использовать механизм KV cache (ускоряет autoregressive generation)
|
||||||
cache (list|None): ключи и значения attention для autoregressive режима
|
cache (list or None): предыдущий кэш, если нужен
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
logits (Tensor): [batch, seq_len, vocab_size]
|
logits: torch.Tensor [batch, seq_len, vocab_size]
|
||||||
new_cache (list|None): новый кэш attention (если use_cache)
|
new_cache: новый кэш attention (или None)
|
||||||
Пример:
|
|
||||||
>>> logits, cache = model.forward(x, use_cache=True)
|
|
||||||
"""
|
"""
|
||||||
# Проверка длины последовательности (только при отсутствии кэша)
|
# Проверка длины последовательности (только при отсутствии кэша)
|
||||||
if cache is None and x.size(1) > self._max_seq_len:
|
if cache is None and x.size(1) > self._max_seq_len:
|
||||||
@@ -141,25 +178,50 @@ class Llama(BaseModel):
|
|||||||
use_cache: bool = True,
|
use_cache: bool = True,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Генерация текста c помощью LLaMA (autoregressive Transformer).
|
Авторегрессивная генерация последовательностей на основе LLaMA (greedy, temperature, top-k, top-p/nucleus, поддержка KV-кэша).
|
||||||
Поддерживается:
|
|
||||||
- greedy и вероятностное сэмплирование (top-k, top-p, temperature)
|
|
||||||
- кэш attention для ускорения генерации длинных последовательностей
|
|
||||||
|
|
||||||
Args:
|
Аргументы:
|
||||||
x (Tensor[int]): начальная последовательность [batch, seq_len]
|
x (torch.Tensor): Входной тензор с токенами shape [batch_size, seq_len].
|
||||||
max_new_tokens (int): сколько новых токенов сгенерировать
|
max_new_tokens (int): Максимальное количество новых токенов для генерации.
|
||||||
do_sample (bool): использовать стохастику (True) или жадный выбор (False)
|
do_sample (bool): Использовать вероятностное сэмплирование (True) или жадный режим (False, argmax).
|
||||||
temperature (float): масштаб для softmax (важно для sampling)
|
temperature (float): Температура (сглаживание распределения вероятностей, >0; по умолчанию 1.0).
|
||||||
top_k (int|None): ограничение на количество кандидатов (top-k sampling)
|
>1.0 — менее предсказуемые, более разнообразные выборки.
|
||||||
top_p (float|None): nucleus sampling
|
<1.0 — более строгие, консервативные выборки.
|
||||||
use_cache (bool): ускоряет autoregressive при длинной генерации
|
top_k (int, опционально): Top-k сэмплирование (ограничение выбора k самыми вероятными токенами).
|
||||||
Returns:
|
top_p (float, опционально): Nucleus (top-p) sampling (срез по кумулятивной вероятности ≤ top_p, см. Holtzman et al., 2019).
|
||||||
output (Tensor[int]): [batch, seq_len + max_new_tokens]
|
use_cache (bool, по умолчанию True): Использовать KV-кэш для ускорения генерации.
|
||||||
Пример:
|
|
||||||
>>> prompt = tokenizer.encode('Meta AI', return_tensors="pt")
|
Возвращает:
|
||||||
>>> generated = model.generate(prompt, max_new_tokens=30, do_sample=True)
|
torch.Tensor: Последовательность токенов shape [batch_size, seq_len + max_new_tokens].
|
||||||
>>> print(tokenizer.decode(generated[0]))
|
|
||||||
|
Исключения:
|
||||||
|
ValueError: Если x длиннее максимально допустимой длины (max_seq_len модели).
|
||||||
|
ValueError: Если temperature ≤ 0.
|
||||||
|
ValueError: Если одновременно заданы top_k и top_p.
|
||||||
|
ValueError: Если top_k ≤ 0.
|
||||||
|
ValueError: Если top_p не в диапазоне (0, 1].
|
||||||
|
|
||||||
|
Примеры:
|
||||||
|
>>> # Строго жадная генерация
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=False)
|
||||||
|
>>> # Вероятностная генерация с температурой
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=0.7)
|
||||||
|
>>> # Top-k sampling
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_k=50)
|
||||||
|
>>> # Top-p (nucleus)
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_p=0.92)
|
||||||
|
>>> # Комбинация температуры и top-k
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=1.0, top_k=100)
|
||||||
|
|
||||||
|
Примечания:
|
||||||
|
- temperature, top_k, top_p применяются только если do_sample=True.
|
||||||
|
- Одновременное использование top_k и top_p запрещено.
|
||||||
|
- Для воспроизводимых результатов зафиксируйте seed через torch.manual_seed.
|
||||||
|
- Возвращается только индексы токенов; для получения вероятностей используйте forward.
|
||||||
|
|
||||||
|
Ссылки:
|
||||||
|
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus/top-p): https://arxiv.org/abs/1904.09751
|
||||||
|
- LLaMA: https://arxiv.org/abs/2302.13971
|
||||||
"""
|
"""
|
||||||
cache = None
|
cache = None
|
||||||
|
|
||||||
@@ -194,10 +256,9 @@ class Llama(BaseModel):
|
|||||||
vocab_size = logits_scaled.size(-1)
|
vocab_size = logits_scaled.size(-1)
|
||||||
|
|
||||||
# создаём маску: 1, если токен НЕ в topk_indices
|
# создаём маску: 1, если токен НЕ в topk_indices
|
||||||
mask = torch.ones_like(logits_scaled, dtype=torch.uint8)
|
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||||
mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы
|
mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы
|
||||||
masked_logits[mask.byte()] = float("-inf")
|
masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf')
|
||||||
|
|
||||||
logits_scaled = masked_logits
|
logits_scaled = masked_logits
|
||||||
|
|
||||||
if do_sample == True and top_p != None:
|
if do_sample == True and top_p != None:
|
||||||
@@ -210,16 +271,16 @@ class Llama(BaseModel):
|
|||||||
# 3. Посчитаем кумулятивную сумму вероятностей:
|
# 3. Посчитаем кумулятивную сумму вероятностей:
|
||||||
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
|
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
|
||||||
# 4. Определим маску: оставить токены, пока сумма < top_p
|
# 4. Определим маску: оставить токены, пока сумма < top_p
|
||||||
sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size]
|
sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size]
|
||||||
# Гарантируем, что хотя бы первый токен останется
|
# Гарантируем, что хотя бы первый токен останется
|
||||||
sorted_mask[:, 0] = 1
|
sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1
|
||||||
# 5. Преобразуем маску обратно в оригинальный порядок:
|
# 5. Преобразуем маску обратно в оригинальный порядок:
|
||||||
# Создаём полную маску из 0
|
# Создаём полную маску из 0
|
||||||
mask = torch.zeros_like(probs, dtype=torch.uint8)
|
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||||
# Устанавливаем 1 в местах нужных токенов
|
# Устанавливаем 1 в местах нужных токенов
|
||||||
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
||||||
# 6. Зануляем логиты токенов вне топ-p:
|
# 6. Зануляем логиты токенов вне топ-p:
|
||||||
logits_scaled[~mask] = float("-inf")
|
logits_scaled[~mask] = float('-inf')
|
||||||
|
|
||||||
# 4. Применяем Softmax
|
# 4. Применяем Softmax
|
||||||
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
|
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
|
||||||
|
|||||||
3
llm/src/llm/models/mistral/__init__.py
Normal file
3
llm/src/llm/models/mistral/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .mistral import Mistral
|
||||||
|
|
||||||
|
__all__ = ["Mistral"]
|
||||||
273
llm/src/llm/models/mistral/mistral.py
Normal file
273
llm/src/llm/models/mistral/mistral.py
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch import Tensor
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from math import sqrt
|
||||||
|
from llm.core.base_model import BaseModel
|
||||||
|
from llm.core.token_embeddings import TokenEmbeddings
|
||||||
|
from llm.core.rms_norm import RMSNorm
|
||||||
|
from llm.core.rope import RoPE
|
||||||
|
from llm.core.mistral_decoder import MistralDecoder
|
||||||
|
|
||||||
|
|
||||||
|
class Mistral(BaseModel):
|
||||||
|
"""
|
||||||
|
Mistral — автогерессивная языковая LLM-архитектура (2023, Mistral AI) для быстрого и качественного моделирования текста.
|
||||||
|
|
||||||
|
Назначение:
|
||||||
|
-----------
|
||||||
|
- Модель построена на базе decoder-only Transformer с важными оптимизациями: GQA (Grouped Query Attention), RoPE, SwiGLU, RMSNorm, sliding window attention.
|
||||||
|
- Поддерживает autoregressive generation (step-by-step текст), обучение и inference на длинных последовательностях.
|
||||||
|
- Используется в современных open-source LLM: Mistral-7B, Mixtral-8x7B и др.
|
||||||
|
|
||||||
|
Архитектурные особенности:
|
||||||
|
--------------------------
|
||||||
|
- Токеновые эмбеддинги (TokenEmbeddings) и позиционное кодирование через RoPE (rotary position embedding).
|
||||||
|
- Stack из num_layers декодеров с Grouped Query Attention (раздельное число query/key heads для оптимизации памяти).
|
||||||
|
- Sliding Window Attention Mask — позволяет ускорять обработку длинных текстов, ограничивая область внимания для каждого токена (как в оригинальном Mistral).
|
||||||
|
- SwiGLU FeedForward-блоки и RMSNorm.
|
||||||
|
- Dropout (регуляризация).
|
||||||
|
- Кэширование attention (KV cache) для быстрой генерации токенов по одному.
|
||||||
|
|
||||||
|
Аргументы конструктора:
|
||||||
|
-----------------------
|
||||||
|
config (dict): параметры модели (см. документацию Mistral):
|
||||||
|
vocab_size: int — размер словаря токенов
|
||||||
|
embed_dim: int — размерность эмбеддингов
|
||||||
|
num_q_heads: int — количество query-голов (обычно больше num_kv_heads)
|
||||||
|
num_kv_heads: int — количество key/value attention-голов
|
||||||
|
num_layers: int — число слоёв-декодеров
|
||||||
|
max_position_embeddings: int — максимальная длина последовательности
|
||||||
|
window_size: int — размер sliding window attention
|
||||||
|
dropout: float — dropout (обычно очень мал или 0)
|
||||||
|
...
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> model = Mistral({...})
|
||||||
|
>>> tokens = torch.tensor([[100, 56, 8]])
|
||||||
|
>>> logits = model(tokens)
|
||||||
|
>>> generated = model.generate(tokens, max_new_tokens=16, do_sample=True, top_k=50)
|
||||||
|
|
||||||
|
References:
|
||||||
|
-----------
|
||||||
|
- "Mistral: Fast and Efficient Dense and Mixture of Experts Transformer Models" (2023): https://arxiv.org/abs/2310.06825
|
||||||
|
- LLaMA v2 & Grouped-Query Attention: https://arxiv.org/abs/2307.09288
|
||||||
|
- Оригинальное обсуждение архитектуры: https://huggingface.co/blog/mistral
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
self._max_seq_len = config["max_position_embeddings"]
|
||||||
|
# Инициализация слоев
|
||||||
|
self._token_embeddings = TokenEmbeddings(
|
||||||
|
vocab_size=config["vocab_size"],
|
||||||
|
emb_size=config["embed_dim"]
|
||||||
|
)
|
||||||
|
self._position_embeddings = RoPE(
|
||||||
|
head_size=config["embed_dim"] // config["num_q_heads"],
|
||||||
|
max_seq_len=config["max_position_embeddings"]
|
||||||
|
)
|
||||||
|
#self._position_embeddings = PositionalEmbeddings(
|
||||||
|
# max_seq_len=max_seq_len,
|
||||||
|
# emb_size=emb_size
|
||||||
|
#)
|
||||||
|
self._dropout = nn.Dropout(config["dropout"])
|
||||||
|
self._decoders = nn.ModuleList([MistralDecoder(
|
||||||
|
num_q_heads=config["num_q_heads"],
|
||||||
|
num_kv_heads=config["num_kv_heads"],
|
||||||
|
emb_size=config["embed_dim"],
|
||||||
|
head_size=config["embed_dim"] // config["num_q_heads"],
|
||||||
|
max_seq_len=config["max_position_embeddings"],
|
||||||
|
window_size=config["window_size"],
|
||||||
|
rope=self._position_embeddings,
|
||||||
|
dropout=config["dropout"]
|
||||||
|
) for _ in range(config["num_layers"])])
|
||||||
|
self._norm = RMSNorm(config["embed_dim"])
|
||||||
|
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
|
||||||
|
"""
|
||||||
|
Прямой проход (forward) через всю модель Mistral: возвращает логиты для токенов и (опционально) кэш attention для ускорения autoregressive генерации.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
x (torch.Tensor): Входной тензор с токенами (shape [batch_size, seq_len]), где значения — индексы токенов.
|
||||||
|
use_cache (bool, по умолчанию True): Возвращать ли новый KV attention-кэш для последующей генерации.
|
||||||
|
cache (list or None): Предыдущий кэш attention (или None для полного прохода без накопления кэша).
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
logits (torch.Tensor): Тензор логитов shape [batch_size, seq_len, vocab_size] — вероятностное распределение по словарю для каждого токена.
|
||||||
|
new_cache (list or None): Новый кэш KV attention-слоев (или None, если use_cache=False).
|
||||||
|
|
||||||
|
Исключения:
|
||||||
|
ValueError: Если длина последовательности превышает максимальную (max_seq_len), когда не используется кэш.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
>>> logits, cache = model.forward(input_ids, use_cache=True)
|
||||||
|
>>> probabilities = torch.softmax(logits, dim=-1)
|
||||||
|
"""
|
||||||
|
# Проверка длины последовательности (только при отсутствии кэша)
|
||||||
|
if cache is None and x.size(1) > self._max_seq_len:
|
||||||
|
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
|
||||||
|
|
||||||
|
# Эмбеддинги токенов и позиций
|
||||||
|
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
|
||||||
|
#pos_out = self._position_embeddings(x) # [batch, seq_len, emb_size]
|
||||||
|
|
||||||
|
# Комбинирование
|
||||||
|
out = self._dropout(tok_out) # [batch, seq_len, emb_size]
|
||||||
|
|
||||||
|
# Стек декодеров с передачей кэша
|
||||||
|
new_cache = []
|
||||||
|
for i, decoder in enumerate(self._decoders):
|
||||||
|
decoder_cache = cache[i] if cache is not None else None
|
||||||
|
decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)
|
||||||
|
|
||||||
|
# Извлекаем результат из кортежа
|
||||||
|
if use_cache:
|
||||||
|
out, decoder_new_cache = decoder_result
|
||||||
|
new_cache.append(decoder_new_cache)
|
||||||
|
else:
|
||||||
|
out = decoder_result[0]
|
||||||
|
|
||||||
|
out = self._norm(out)
|
||||||
|
logits = self._linear(out)
|
||||||
|
|
||||||
|
# Возвращаем результат с учетом use_cache
|
||||||
|
if use_cache:
|
||||||
|
return (logits, new_cache)
|
||||||
|
else:
|
||||||
|
return (logits, None)
|
||||||
|
|
||||||
|
def generate(self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
max_new_tokens: int,
|
||||||
|
do_sample: bool,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
top_k: int = None,
|
||||||
|
top_p: float = None,
|
||||||
|
use_cache: bool = True
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k/top-p sampling
|
||||||
|
и ускорением через attention-кэш (KV-cache, важно для inference на длинных текстах).
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
x (torch.Tensor): Входной тензор с токенами shape [batch_size, seq_len].
|
||||||
|
max_new_tokens (int): Максимальное количество новых токенов для генерации.
|
||||||
|
do_sample (bool): Если True — вероятность/случайность (random sampling); если False — жадная генерация (argmax).
|
||||||
|
temperature (float): Температура (>0, по умолчанию 1.0); >1.0 — более случайные выборы, <1.0 — более строгие.
|
||||||
|
top_k (int, optional): top-k sampling; при сэмплировании выбираются только top_k наиболее вероятных токенов.
|
||||||
|
top_p (float, optional): nucleus (top-p) sampling; выбираются токены с накопленной вероятностью ≤ top_p.
|
||||||
|
use_cache (bool, по умолчанию True): Использовать ускорение через KV attention cache для autoregressive режима.
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
torch.Tensor: Последовательность индексов токенов shape [batch_size, seq_len + max_new_tokens].
|
||||||
|
|
||||||
|
Исключения:
|
||||||
|
ValueError: Если x длиннее max_seq_len модели.
|
||||||
|
ValueError: Если temperature ≤ 0.
|
||||||
|
ValueError: Если одновременно заданы top_k и top_p.
|
||||||
|
ValueError: Если top_k ≤ 0.
|
||||||
|
ValueError: Если top_p не в диапазоне (0, 1].
|
||||||
|
|
||||||
|
Примеры:
|
||||||
|
>>> # Жадная генерация
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=False)
|
||||||
|
>>> # Сэмплирование с температурой
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=0.8)
|
||||||
|
>>> # Top-k sampling
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_k=50)
|
||||||
|
>>> # Top-p (nucleus) sampling
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_p=0.92)
|
||||||
|
>>> # Температура + top-k
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=1.0, top_k=100)
|
||||||
|
|
||||||
|
Примечания:
|
||||||
|
- Одновременно использовать top_k и top_p нельзя.
|
||||||
|
- Параметры temperature, top_k, top_p работают только при do_sample=True.
|
||||||
|
- Для полного воспроизведения результата зафиксируйте seed через torch.manual_seed.
|
||||||
|
- Метод всегда возвращает только индексы токенов; для получения логитов используйте forward.
|
||||||
|
|
||||||
|
Ссылки:
|
||||||
|
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus/top-p sampling): https://arxiv.org/abs/1904.09751
|
||||||
|
- Mistral: https://arxiv.org/abs/2310.06825
|
||||||
|
"""
|
||||||
|
cache = None
|
||||||
|
|
||||||
|
for _ in range(max_new_tokens):
|
||||||
|
if use_cache and cache is not None:
|
||||||
|
# Используем кэш - передаем только последний токен
|
||||||
|
x_input = x[:, -1:] # [batch_size, 1]
|
||||||
|
else:
|
||||||
|
# Первая итерация или кэш отключен - передаем всю последовательность
|
||||||
|
x_input = x
|
||||||
|
|
||||||
|
# Прямой проход с кэшем
|
||||||
|
logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)
|
||||||
|
|
||||||
|
# Обновляем кэш для следующей итерации
|
||||||
|
if use_cache:
|
||||||
|
cache = new_cache
|
||||||
|
|
||||||
|
last_logits = logits[:, -1, :] # [batch_size, vocab_size]
|
||||||
|
|
||||||
|
# Масштабируем логиты температурой
|
||||||
|
if temperature > 0:
|
||||||
|
logits_scaled = last_logits / temperature
|
||||||
|
else:
|
||||||
|
logits_scaled = last_logits
|
||||||
|
|
||||||
|
if do_sample == True and top_k != None:
|
||||||
|
_, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)
|
||||||
|
|
||||||
|
# # Заменим все НЕ top-k логиты на -inf
|
||||||
|
masked_logits = logits_scaled.clone()
|
||||||
|
vocab_size = logits_scaled.size(-1)
|
||||||
|
|
||||||
|
# создаём маску: 1, если токен НЕ в topk_indices
|
||||||
|
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||||
|
mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы
|
||||||
|
masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf')
|
||||||
|
|
||||||
|
logits_scaled = masked_logits
|
||||||
|
|
||||||
|
if do_sample == True and top_p != None:
|
||||||
|
# 1. Применим softmax, чтобы получить вероятности:
|
||||||
|
probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]
|
||||||
|
# 2. Отсортируем токены по убыванию вероятностей:
|
||||||
|
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
|
||||||
|
# 3. Посчитаем кумулятивную сумму вероятностей:
|
||||||
|
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
|
||||||
|
# 4. Определим маску: оставить токены, пока сумма < top_p
|
||||||
|
sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size]
|
||||||
|
# Гарантируем, что хотя бы первый токен останется
|
||||||
|
sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1
|
||||||
|
# 5. Преобразуем маску обратно в оригинальный порядок:
|
||||||
|
# Создаём полную маску из 0
|
||||||
|
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||||
|
# Устанавливаем 1 в местах нужных токенов
|
||||||
|
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
||||||
|
# 6. Зануляем логиты токенов вне топ-p:
|
||||||
|
logits_scaled[~mask] = float('-inf')
|
||||||
|
|
||||||
|
# 4. Применяем Softmax
|
||||||
|
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
|
||||||
|
|
||||||
|
|
||||||
|
if do_sample == True:
|
||||||
|
# 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial
|
||||||
|
next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]
|
||||||
|
else:
|
||||||
|
# 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью
|
||||||
|
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
|
||||||
|
|
||||||
|
# 6. Добавляем его к последовательности
|
||||||
|
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_seq_len(self) -> int:
|
||||||
|
return self._max_seq_len
|
||||||
@@ -1,454 +0,0 @@
|
|||||||
"""
|
|
||||||
BPE (Byte Pair Encoding) токенизатор.
|
|
||||||
|
|
||||||
Реализация алгоритма BPE для токенизации текста.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import re
|
|
||||||
from collections import defaultdict, Counter
|
|
||||||
from typing import List, Dict, Tuple, Optional
|
|
||||||
from .base_tokenizer import BaseTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
class BPETokenizer(BaseTokenizer):
|
|
||||||
"""
|
|
||||||
BPE токенизатор для обработки текста.
|
|
||||||
|
|
||||||
Реализует алгоритм Byte Pair Encoding для создания субсловных токенов.
|
|
||||||
|
|
||||||
Примеры использования:
|
|
||||||
>>> tokenizer = BPETokenizer()
|
|
||||||
>>> tokenizer.train(["пример текста для обучения"], vocab_size=1000)
|
|
||||||
>>> tokens = tokenizer.encode("новый текст")
|
|
||||||
>>> text = tokenizer.decode(tokens)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.merges: Dict[Tuple[str, str], int] = {}
|
|
||||||
self.pattern = r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
|
|
||||||
self.compiled_pattern = re.compile(self.pattern, re.UNICODE)
|
|
||||||
|
|
||||||
def train(self, texts: List[str], vocab_size: int = 1000, **kwargs):
|
|
||||||
"""
|
|
||||||
Обучение BPE токенизатора на текстах.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts: Список текстов для обучения
|
|
||||||
vocab_size: Желаемый размер словаря
|
|
||||||
**kwargs: Дополнительные параметры
|
|
||||||
- min_frequency: Минимальная частота для мерджа
|
|
||||||
- special_tokens: Список специальных токенов
|
|
||||||
"""
|
|
||||||
# Инициализация базового словаря
|
|
||||||
self._initialize_vocab()
|
|
||||||
|
|
||||||
# Добавляем специальные токены если указаны
|
|
||||||
special_tokens = kwargs.get(
|
|
||||||
"special_tokens",
|
|
||||||
[self.pad_token, self.unk_token, self.bos_token, self.eos_token],
|
|
||||||
)
|
|
||||||
self.add_special_tokens(special_tokens)
|
|
||||||
|
|
||||||
# Предобработка текстов
|
|
||||||
words = self._preprocess_texts(texts)
|
|
||||||
|
|
||||||
# Получаем начальные токены
|
|
||||||
vocab = self._get_initial_vocab(words)
|
|
||||||
|
|
||||||
# Выполняем BPE мерджи
|
|
||||||
self._perform_merges(vocab, vocab_size, kwargs.get("min_frequency", 2))
|
|
||||||
|
|
||||||
# Строим финальный словарь
|
|
||||||
self._build_final_vocab()
|
|
||||||
|
|
||||||
def _initialize_vocab(self):
|
|
||||||
"""Инициализирует базовый словарь."""
|
|
||||||
self.vocab.clear()
|
|
||||||
self.inverse_vocab.clear()
|
|
||||||
self.merges.clear()
|
|
||||||
self.vocab_size = 0
|
|
||||||
|
|
||||||
def _preprocess_texts(self, texts: List[str]) -> List[List[str]]:
|
|
||||||
"""
|
|
||||||
Предобработка текстов для обучения.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts: Список текстов
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[List[str]]: Предобработанные слова
|
|
||||||
"""
|
|
||||||
words = []
|
|
||||||
for text in texts:
|
|
||||||
# Базовая нормализация
|
|
||||||
text = text.lower().strip()
|
|
||||||
# Токенизация на слова
|
|
||||||
tokens = self.compiled_pattern.findall(text)
|
|
||||||
words.append(tokens)
|
|
||||||
return words
|
|
||||||
|
|
||||||
def _get_initial_vocab(self, words: List[List[str]]) -> Dict[str, int]:
|
|
||||||
"""
|
|
||||||
Создает начальный словарь из символов.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
words: Список токенизированных текстов
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, int]: Начальный словарь частот
|
|
||||||
"""
|
|
||||||
vocab = Counter()
|
|
||||||
for word_list in words:
|
|
||||||
for word in word_list:
|
|
||||||
# Разбиваем слово на символы и добавляем специальный символ конца слова
|
|
||||||
chars = list(word) + ["</w>"]
|
|
||||||
vocab.update(["".join(chars[i : i + 1]) for i in range(len(chars))])
|
|
||||||
return vocab
|
|
||||||
|
|
||||||
def _perform_merges(
|
|
||||||
self, vocab: Dict[str, int], target_vocab_size: int, min_frequency: int
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Выполняет BPE мерджи до достижения целевого размера словаря.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
vocab: Начальный словарь
|
|
||||||
target_vocab_size: Целевой размер словаря
|
|
||||||
min_frequency: Минимальная частота для мерджа
|
|
||||||
"""
|
|
||||||
current_vocab_size = len(vocab) + len(self.vocab)
|
|
||||||
|
|
||||||
while current_vocab_size < target_vocab_size:
|
|
||||||
# Находим наиболее частую пару
|
|
||||||
pairs = self._get_stats(vocab)
|
|
||||||
if not pairs:
|
|
||||||
break
|
|
||||||
|
|
||||||
best_pair = max(pairs, key=pairs.get)
|
|
||||||
if pairs[best_pair] < min_frequency:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Выполняем мердж
|
|
||||||
vocab = self._merge_vocab(vocab, best_pair)
|
|
||||||
self.merges[best_pair] = len(self.merges)
|
|
||||||
current_vocab_size += 1
|
|
||||||
|
|
||||||
def _get_stats(self, vocab: Dict[str, int]) -> Dict[Tuple[str, str], int]:
|
|
||||||
"""
|
|
||||||
Собирает статистику по парам символов.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
vocab: Словарь токенов
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[Tuple[str, str], int]: Частоты пар
|
|
||||||
"""
|
|
||||||
pairs = defaultdict(int)
|
|
||||||
for word, freq in vocab.items():
|
|
||||||
symbols = word.split()
|
|
||||||
for i in range(len(symbols) - 1):
|
|
||||||
pairs[symbols[i], symbols[i + 1]] += freq
|
|
||||||
return pairs
|
|
||||||
|
|
||||||
def _merge_vocab(
|
|
||||||
self, vocab: Dict[str, int], pair: Tuple[str, str]
|
|
||||||
) -> Dict[str, int]:
|
|
||||||
"""
|
|
||||||
Объединяет пару символов в словаре.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
vocab: Исходный словарь
|
|
||||||
pair: Пара для объединения
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict[str, int]: Обновленный словарь
|
|
||||||
"""
|
|
||||||
new_vocab = {}
|
|
||||||
bigram = re.compile(
|
|
||||||
r"(?<!\\S)" + re.escape(pair[0]) + r" " + re.escape(pair[1]) + r"(?!\\S)"
|
|
||||||
)
|
|
||||||
replacement = pair[0] + pair[1]
|
|
||||||
|
|
||||||
for word in vocab:
|
|
||||||
new_word = bigram.sub(replacement, word)
|
|
||||||
new_vocab[new_word] = vocab[word]
|
|
||||||
|
|
||||||
return new_vocab
|
|
||||||
|
|
||||||
def _build_final_vocab(self):
|
|
||||||
"""Строит финальный словарь токенизатора."""
|
|
||||||
# Собираем все уникальные токены из мерджей
|
|
||||||
all_tokens = set()
|
|
||||||
|
|
||||||
# Добавляем специальные токены
|
|
||||||
all_tokens.update(
|
|
||||||
[self.pad_token, self.unk_token, self.bos_token, self.eos_token]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Добавляем токены из мерджей
|
|
||||||
for pair in self.merges:
|
|
||||||
all_tokens.update(pair)
|
|
||||||
|
|
||||||
# Создаем словарь
|
|
||||||
for i, token in enumerate(sorted(all_tokens)):
|
|
||||||
self.vocab[token] = i
|
|
||||||
|
|
||||||
self.inverse_vocab = {v: k for k, v in self.vocab.items()}
|
|
||||||
self.vocab_size = len(self.vocab)
|
|
||||||
|
|
||||||
# Обновляем ID специальных токенов
|
|
||||||
self.pad_token_id = self.vocab.get(self.pad_token)
|
|
||||||
self.unk_token_id = self.vocab.get(self.unk_token)
|
|
||||||
self.bos_token_id = self.vocab.get(self.bos_token)
|
|
||||||
self.eos_token_id = self.vocab.get(self.eos_token)
|
|
||||||
|
|
||||||
def encode(self, text: str, **kwargs) -> List[int]:
|
|
||||||
"""
|
|
||||||
Кодирует текст в последовательность токенов.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: Входной текст
|
|
||||||
**kwargs: Дополнительные параметры
|
|
||||||
- add_special_tokens: Добавлять специальные токены
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[int]: Список идентификаторов токенов
|
|
||||||
"""
|
|
||||||
add_special_tokens = kwargs.get("add_special_tokens", False)
|
|
||||||
|
|
||||||
# Токенизация текста
|
|
||||||
tokens = self.compiled_pattern.findall(text)
|
|
||||||
|
|
||||||
# Применяем BPE к каждому токену
|
|
||||||
bpe_tokens = []
|
|
||||||
for token in tokens:
|
|
||||||
# Преобразуем токен в BPE представление
|
|
||||||
bpe_token = self._apply_bpe(token)
|
|
||||||
bpe_tokens.extend(bpe_token)
|
|
||||||
|
|
||||||
# Конвертируем в ID
|
|
||||||
token_ids = []
|
|
||||||
for token in bpe_tokens:
|
|
||||||
token_id = self.vocab.get(token, self.unk_token_id)
|
|
||||||
if token_id is not None:
|
|
||||||
token_ids.append(token_id)
|
|
||||||
|
|
||||||
# Добавляем специальные токены если нужно
|
|
||||||
if add_special_tokens:
|
|
||||||
if self.bos_token_id is not None:
|
|
||||||
token_ids.insert(0, self.bos_token_id)
|
|
||||||
if self.eos_token_id is not None:
|
|
||||||
token_ids.append(self.eos_token_id)
|
|
||||||
|
|
||||||
return token_ids
|
|
||||||
|
|
||||||
def _apply_bpe(self, token: str) -> List[str]:
|
|
||||||
"""
|
|
||||||
Применяет BPE к одному токену.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token: Входной токен
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[str]: Список BPE токенов
|
|
||||||
"""
|
|
||||||
# Простая реализация - в реальной реализации нужно применять обученные мерджи
|
|
||||||
word = token + "</w>"
|
|
||||||
tokens = [word[i : i + 1] for i in range(len(word))]
|
|
||||||
|
|
||||||
# Применяем мерджи (упрощенная версия)
|
|
||||||
# В полной реализации нужно применять все обученные мерджи
|
|
||||||
for pair in self.merges:
|
|
||||||
i = 0
|
|
||||||
while i < len(tokens) - 1:
|
|
||||||
if tokens[i] == pair[0] and tokens[i + 1] == pair[1]:
|
|
||||||
tokens[i] = tokens[i] + tokens[i + 1]
|
|
||||||
del tokens[i + 1]
|
|
||||||
else:
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
def decode(self, tokens: List[int], **kwargs) -> str:
|
|
||||||
"""
|
|
||||||
Декодирует последовательность токенов в текст.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tokens: Список идентификаторов токенов
|
|
||||||
**kwargs: Дополнительные параметры
|
|
||||||
- skip_special_tokens: Пропускать специальные токены
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Декодированный текст
|
|
||||||
"""
|
|
||||||
skip_special_tokens = kwargs.get("skip_special_tokens", True)
|
|
||||||
|
|
||||||
# Конвертируем ID в токены
|
|
||||||
token_strings = []
|
|
||||||
for token_id in tokens:
|
|
||||||
token = self.inverse_vocab.get(token_id, self.unk_token)
|
|
||||||
|
|
||||||
# Пропускаем специальные токены если нужно
|
|
||||||
if skip_special_tokens and token in [
|
|
||||||
self.pad_token,
|
|
||||||
self.unk_token,
|
|
||||||
self.bos_token,
|
|
||||||
self.eos_token,
|
|
||||||
]:
|
|
||||||
continue
|
|
||||||
|
|
||||||
token_strings.append(token)
|
|
||||||
|
|
||||||
# Объединяем токены в текст
|
|
||||||
text = "".join(token_strings)
|
|
||||||
|
|
||||||
# Убираем маркер конца слова
|
|
||||||
text = text.replace("</w>", " ")
|
|
||||||
|
|
||||||
return text.strip()
|
|
||||||
|
|
||||||
def save(self, filepath: str):
|
|
||||||
"""
|
|
||||||
Сохраняет BPE токенизатор в файл.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filepath: Путь для сохранения
|
|
||||||
"""
|
|
||||||
import json
|
|
||||||
|
|
||||||
config = {
|
|
||||||
"vocab": self.vocab,
|
|
||||||
"merges": {f"{k[0]} {k[1]}": v for k, v in self.merges.items()},
|
|
||||||
"vocab_size": self.vocab_size,
|
|
||||||
"pad_token": self.pad_token,
|
|
||||||
"unk_token": self.unk_token,
|
|
||||||
"bos_token": self.bos_token,
|
|
||||||
"eos_token": self.eos_token,
|
|
||||||
"pattern": self.pattern,
|
|
||||||
"tokenizer_type": self.__class__.__name__,
|
|
||||||
}
|
|
||||||
|
|
||||||
with open(filepath, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(config, f, ensure_ascii=False, indent=2)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(cls, filepath: str):
|
|
||||||
"""
|
|
||||||
Загружает BPE токенизатор из файла.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filepath: Путь к файлу
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BPETokenizer: Загруженный токенизатор
|
|
||||||
"""
|
|
||||||
import json
|
|
||||||
|
|
||||||
with open(filepath, "r", encoding="utf-8") as f:
|
|
||||||
config = json.load(f)
|
|
||||||
|
|
||||||
tokenizer = cls()
|
|
||||||
tokenizer.vocab = config["vocab"]
|
|
||||||
tokenizer.vocab_size = config["vocab_size"]
|
|
||||||
tokenizer.pad_token = config["pad_token"]
|
|
||||||
tokenizer.unk_token = config["unk_token"]
|
|
||||||
tokenizer.bos_token = config["bos_token"]
|
|
||||||
tokenizer.eos_token = config["eos_token"]
|
|
||||||
tokenizer.pattern = config.get("pattern", tokenizer.pattern)
|
|
||||||
tokenizer.compiled_pattern = re.compile(tokenizer.pattern, re.UNICODE)
|
|
||||||
|
|
||||||
# Восстанавливаем мерджи
|
|
||||||
merges = config.get("merges", {})
|
|
||||||
tokenizer.merges = {}
|
|
||||||
for k, v in merges.items():
|
|
||||||
parts = k.split()
|
|
||||||
if len(parts) == 2:
|
|
||||||
tokenizer.merges[(parts[0], parts[1])] = v
|
|
||||||
|
|
||||||
# Создаем обратный словарь
|
|
||||||
tokenizer.inverse_vocab = {v: k for k, v in tokenizer.vocab.items()}
|
|
||||||
|
|
||||||
# Обновляем ID специальных токенов
|
|
||||||
tokenizer.pad_token_id = tokenizer.vocab.get(tokenizer.pad_token)
|
|
||||||
tokenizer.unk_token_id = tokenizer.vocab.get(tokenizer.unk_token)
|
|
||||||
tokenizer.bos_token_id = tokenizer.vocab.get(tokenizer.bos_token)
|
|
||||||
tokenizer.eos_token_id = tokenizer.vocab.get(tokenizer.eos_token)
|
|
||||||
|
|
||||||
return tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
# Упрощенная версия для быстрого старта
|
|
||||||
class SimpleBPETokenizer(BPETokenizer):
|
|
||||||
"""
|
|
||||||
Упрощенная версия BPE токенизатора для демонстрации.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def train(self, texts: List[str], vocab_size: int = 1000, **kwargs):
|
|
||||||
"""Упрощенное обучение для демонстрации."""
|
|
||||||
# Инициализация базового словаря
|
|
||||||
self._initialize_vocab()
|
|
||||||
|
|
||||||
# Добавляем базовые токены
|
|
||||||
special_tokens = [
|
|
||||||
self.pad_token,
|
|
||||||
self.unk_token,
|
|
||||||
self.bos_token,
|
|
||||||
self.eos_token,
|
|
||||||
]
|
|
||||||
self.add_special_tokens(special_tokens)
|
|
||||||
|
|
||||||
# Простая реализация - собираем все символы
|
|
||||||
all_chars = set()
|
|
||||||
for text in texts:
|
|
||||||
all_chars.update(text)
|
|
||||||
|
|
||||||
# Добавляем символы в словарь
|
|
||||||
for char in sorted(all_chars):
|
|
||||||
if char not in self.vocab:
|
|
||||||
self.vocab[char] = len(self.vocab)
|
|
||||||
|
|
||||||
self.inverse_vocab = {v: k for k, v in self.vocab.items()}
|
|
||||||
self.vocab_size = len(self.vocab)
|
|
||||||
|
|
||||||
# Обновляем ID специальных токенов
|
|
||||||
self.pad_token_id = self.vocab.get(self.pad_token)
|
|
||||||
self.unk_token_id = self.vocab.get(self.unk_token)
|
|
||||||
self.bos_token_id = self.vocab.get(self.bos_token)
|
|
||||||
self.eos_token_id = self.vocab.get(self.eos_token)
|
|
||||||
|
|
||||||
def encode(self, text: str, **kwargs) -> List[int]:
|
|
||||||
"""Упрощенное кодирование - разбиваем на символы."""
|
|
||||||
add_special_tokens = kwargs.get("add_special_tokens", False)
|
|
||||||
|
|
||||||
token_ids = []
|
|
||||||
for char in text:
|
|
||||||
token_id = self.vocab.get(char, self.unk_token_id)
|
|
||||||
if token_id is not None:
|
|
||||||
token_ids.append(token_id)
|
|
||||||
|
|
||||||
if add_special_tokens:
|
|
||||||
if self.bos_token_id is not None:
|
|
||||||
token_ids.insert(0, self.bos_token_id)
|
|
||||||
if self.eos_token_id is not None:
|
|
||||||
token_ids.append(self.eos_token_id)
|
|
||||||
|
|
||||||
return token_ids
|
|
||||||
|
|
||||||
def decode(self, tokens: List[int], **kwargs) -> str:
|
|
||||||
"""Упрощенное декодирование."""
|
|
||||||
skip_special_tokens = kwargs.get("skip_special_tokens", True)
|
|
||||||
|
|
||||||
chars = []
|
|
||||||
for token_id in tokens:
|
|
||||||
char = self.inverse_vocab.get(token_id, self.unk_token)
|
|
||||||
if skip_special_tokens and char in [
|
|
||||||
self.pad_token,
|
|
||||||
self.unk_token,
|
|
||||||
self.bos_token,
|
|
||||||
self.eos_token,
|
|
||||||
]:
|
|
||||||
continue
|
|
||||||
chars.append(char)
|
|
||||||
|
|
||||||
return "".join(chars)
|
|
||||||
@@ -10,16 +10,54 @@ from .base_tokenizer import BaseTokenizer
|
|||||||
|
|
||||||
class BPETokenizer(BaseTokenizer):
|
class BPETokenizer(BaseTokenizer):
|
||||||
"""
|
"""
|
||||||
BPE токенизатор для обработки текста.
|
BpeTokenizer — реализация токенизатора на алгоритме byte pair encoding (BPE).
|
||||||
|
|
||||||
Реализует алгоритм Byte Pair Encoding для создания субсловных токенов.
|
Назначение:
|
||||||
Использует вашу реализацию BPE.
|
-----------
|
||||||
|
- Преобразует открытый текст (строки, bytes) в последовательность числовых токенов для подачи в LLM и обратно.
|
||||||
|
- Разбивает текст на сабслова (байтовые пары), эффективно кодируя редкие слова длинными последовательностями, а частые — единичными токенами.
|
||||||
|
- Является стандартом де-факто в современных языковых моделях (GPT, LLaMA, BLOOM, Mistral, HuggingFace).
|
||||||
|
|
||||||
Примеры использования:
|
Как работает BPE:
|
||||||
>>> tokenizer = BPETokenizer()
|
-----------------
|
||||||
>>> tokenizer.train(["пример текста для обучения"], vocab_size=1000)
|
1. Строится словарь из наиболее популярных пар символов/субстрок.
|
||||||
>>> tokens = tokenizer.encode("новый текст")
|
2. Текст замещается наиболее длинными subword-подстроками из vocabulary (жадно).
|
||||||
|
3. Итог: многомиллионное лексическое пространство сокращается до компактного набора subword pieces.
|
||||||
|
|
||||||
|
Особенности алгоритма:
|
||||||
|
----------------------
|
||||||
|
- Отлично работает на всех языках, включая rare/compound/inflectable.
|
||||||
|
- Гибко масштабируется под размер итогового словаря/token space.
|
||||||
|
- Обычно хранит mapping (str/bytes → int и int → str/bytes) в JSON или словарном файле.
|
||||||
|
- Может использовать кастомные сепараторы, handle unknown.
|
||||||
|
|
||||||
|
Аргументы конструктора:
|
||||||
|
-----------------------
|
||||||
|
vocab_path: str
|
||||||
|
Путь к файлу BPE vocabulary (JSON, txt, в зависимости от реализации).
|
||||||
|
merges_path: str, optional
|
||||||
|
Путь к списку merge-правил (если используется блочное файловое раздельное хранение).
|
||||||
|
unk_token: str, optional
|
||||||
|
Токен для неизвестных последовательностей (по дефолту '[UNK]' или '<unk>').
|
||||||
|
pad_token, bos_token, eos_token: str, optional
|
||||||
|
Special tokens, если нужны для вашей архитектуры.
|
||||||
|
lowercase: bool, optional
|
||||||
|
Приводить ли текст к нижнему регистру перед токенизацией.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> tokenizer = BpeTokenizer(vocab_path=\"bpe_vocab.json\")
|
||||||
|
>>> tokens = tokenizer.encode(\"Hello, world!\")
|
||||||
|
>>> print(tokens) # [15496, 11, ...]
|
||||||
>>> text = tokenizer.decode(tokens)
|
>>> text = tokenizer.decode(tokens)
|
||||||
|
>>> print(text) # 'Hello, world!'
|
||||||
|
|
||||||
|
References:
|
||||||
|
-----------
|
||||||
|
- Sennrich et al, \"Neural Machine Translation of Rare Words with Subword Units\", 2015: https://arxiv.org/abs/1508.07909
|
||||||
|
- GPT-2 tokenization: https://github.com/openai/gpt-2
|
||||||
|
- HuggingFace tokenizers overview: https://huggingface.co/docs/tokenizers/index
|
||||||
|
- Visually: https://guillaume-be.github.io/2021-05-21/byte-pair-encoding/
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -107,15 +145,21 @@ class BPETokenizer(BaseTokenizer):
|
|||||||
|
|
||||||
def encode(self, text: str, **kwargs) -> List[int]:
|
def encode(self, text: str, **kwargs) -> List[int]:
|
||||||
"""
|
"""
|
||||||
Кодирует текст в последовательность токенов.
|
Токенизирует входной текст в список числовых токенов (индексов).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: Входной текст
|
-----
|
||||||
**kwargs: Дополнительные параметры
|
text: str
|
||||||
- add_special_tokens: Добавлять специальные токены
|
Входная строка/текст для токенизации.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[int]: Список идентификаторов токенов
|
--------
|
||||||
|
List[int] — последовательность индексов из vocabulary.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> ids = tokenizer.encode(\"The quick brown fox\")
|
||||||
|
>>> print(ids)
|
||||||
"""
|
"""
|
||||||
add_special_tokens = kwargs.get("add_special_tokens", False)
|
add_special_tokens = kwargs.get("add_special_tokens", False)
|
||||||
|
|
||||||
@@ -175,15 +219,22 @@ class BPETokenizer(BaseTokenizer):
|
|||||||
|
|
||||||
def decode(self, tokens: List[int], **kwargs) -> str:
|
def decode(self, tokens: List[int], **kwargs) -> str:
|
||||||
"""
|
"""
|
||||||
Декодирует последовательность токенов в текст.
|
Декодирует последовательность токенов обратно в текстовую строку.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tokens: Список идентификаторов токенов
|
-----
|
||||||
**kwargs: Дополнительные параметры
|
ids: List[int]
|
||||||
- skip_special_tokens: Пропускать специальные токены
|
Список токен-индексов для распаковки.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: Декодированный текст
|
--------
|
||||||
|
text: str
|
||||||
|
Оригинальный (или приближённый) раскодированный текст.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> tokens = [15496, 11, 318, ...]
|
||||||
|
>>> text = tokenizer.decode(tokens)
|
||||||
"""
|
"""
|
||||||
skip_special_tokens = kwargs.get("skip_special_tokens", True)
|
skip_special_tokens = kwargs.get("skip_special_tokens", True)
|
||||||
|
|
||||||
|
|||||||
@@ -1,155 +0,0 @@
|
|||||||
import torch
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
from typing import List, Any
|
|
||||||
|
|
||||||
|
|
||||||
class TextDataset(Dataset):
|
|
||||||
"""
|
|
||||||
Простой датасет для языкового моделирования (LLM).
|
|
||||||
Работает с любым токенизатором, реализующим интерфейс BaseTokenizer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128):
|
|
||||||
"""
|
|
||||||
Инициализация датасета.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts: Список текстов для обучения
|
|
||||||
tokenizer: Токенизатор с методами encode/decode
|
|
||||||
block_size: Максимальная длина последовательности
|
|
||||||
"""
|
|
||||||
self.examples = []
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.block_size = block_size
|
|
||||||
|
|
||||||
for text in texts:
|
|
||||||
# Кодируем текст в токены
|
|
||||||
input_ids = tokenizer.encode(text, add_special_tokens=False)
|
|
||||||
|
|
||||||
# Обрезаем или дополняем до нужной длины
|
|
||||||
if len(input_ids) > block_size:
|
|
||||||
input_ids = input_ids[:block_size]
|
|
||||||
else:
|
|
||||||
# Дополняем pad_token_id
|
|
||||||
pad_token_id = getattr(tokenizer, "pad_token_id", 0)
|
|
||||||
input_ids = input_ids + [pad_token_id] * (block_size - len(input_ids))
|
|
||||||
|
|
||||||
self.examples.append(input_ids)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.examples)
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
|
|
||||||
labels = input_ids.clone()
|
|
||||||
return {"input_ids": input_ids, "labels": labels}
|
|
||||||
|
|
||||||
|
|
||||||
class StreamingTextDataset(Dataset):
|
|
||||||
"""
|
|
||||||
Датасет для потоковой обработки больших текстов.
|
|
||||||
Токенизация происходит на лету, что экономит память.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128):
|
|
||||||
self.texts = texts
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.block_size = block_size
|
|
||||||
|
|
||||||
# Получаем pad_token_id из токенизатора
|
|
||||||
self.pad_token_id = getattr(tokenizer, "pad_token_id", 0)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.texts)
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
text = self.texts[idx]
|
|
||||||
|
|
||||||
# Токенизация на лету
|
|
||||||
input_ids = self.tokenizer.encode(text, add_special_tokens=False)
|
|
||||||
|
|
||||||
# Обрезаем или дополняем до нужной длины
|
|
||||||
if len(input_ids) > self.block_size:
|
|
||||||
input_ids = input_ids[: self.block_size]
|
|
||||||
else:
|
|
||||||
input_ids = input_ids + [self.pad_token_id] * (
|
|
||||||
self.block_size - len(input_ids)
|
|
||||||
)
|
|
||||||
|
|
||||||
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
|
||||||
labels = input_ids.clone()
|
|
||||||
|
|
||||||
return {"input_ids": input_ids, "labels": labels}
|
|
||||||
|
|
||||||
|
|
||||||
class TextDatasetWithSpecialTokens(TextDataset):
|
|
||||||
"""
|
|
||||||
Расширенная версия TextDataset с поддержкой специальных токенов.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
texts: List[str],
|
|
||||||
tokenizer: Any,
|
|
||||||
block_size: int = 128,
|
|
||||||
add_bos: bool = False,
|
|
||||||
add_eos: bool = False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
texts: Список текстов
|
|
||||||
tokenizer: Токенизатор
|
|
||||||
block_size: Максимальная длина
|
|
||||||
add_bos: Добавлять токен начала последовательности
|
|
||||||
add_eos: Добавлять токен конца последовательности
|
|
||||||
"""
|
|
||||||
self.examples = []
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.block_size = block_size
|
|
||||||
self.add_bos = add_bos
|
|
||||||
self.add_eos = add_eos
|
|
||||||
|
|
||||||
for text in texts:
|
|
||||||
# Кодируем с специальными токенами
|
|
||||||
input_ids = tokenizer.encode(
|
|
||||||
text, add_special_tokens=True, add_bos_token=add_bos, add_eos_token=eos
|
|
||||||
)
|
|
||||||
|
|
||||||
# Учитываем специальные токены при обрезке/дополнении
|
|
||||||
effective_block_size = block_size
|
|
||||||
if add_bos:
|
|
||||||
effective_block_size -= 1
|
|
||||||
if add_eos:
|
|
||||||
effective_block_size -= 1
|
|
||||||
|
|
||||||
if len(input_ids) > effective_block_size:
|
|
||||||
input_ids = input_ids[:effective_block_size]
|
|
||||||
|
|
||||||
# Добавляем специальные токены если нужно
|
|
||||||
if (
|
|
||||||
add_bos
|
|
||||||
and hasattr(tokenizer, "bos_token_id")
|
|
||||||
and tokenizer.bos_token_id is not None
|
|
||||||
):
|
|
||||||
input_ids = [tokenizer.bos_token_id] + input_ids
|
|
||||||
if (
|
|
||||||
add_eos
|
|
||||||
and hasattr(tokenizer, "eos_token_id")
|
|
||||||
and tokenizer.eos_token_id is not None
|
|
||||||
):
|
|
||||||
input_ids = input_ids + [tokenizer.eos_token_id]
|
|
||||||
|
|
||||||
# Дополняем до полной длины
|
|
||||||
pad_token_id = getattr(tokenizer, "pad_token_id", 0)
|
|
||||||
if len(input_ids) < block_size:
|
|
||||||
input_ids = input_ids + [pad_token_id] * (block_size - len(input_ids))
|
|
||||||
|
|
||||||
self.examples.append(input_ids)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.examples)
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
|
|
||||||
labels = input_ids.clone()
|
|
||||||
return {"input_ids": input_ids, "labels": labels}
|
|
||||||
@@ -1,9 +1,71 @@
|
|||||||
|
"""
|
||||||
|
Модуль оптимизации для обучения нейронных сетей.
|
||||||
|
|
||||||
|
В данном модуле реализована функция выбора и инициализации оптимизаторов, наиболее популярных при обучении глубоких нейросетей:
|
||||||
|
- AdamW
|
||||||
|
- Adam
|
||||||
|
- SGD
|
||||||
|
|
||||||
|
Теоретическое обоснование:
|
||||||
|
--------------------------
|
||||||
|
Задача оптимизации в обучении нейросети заключается в минимизации функции потерь (Loss) по параметрам модели W. Современные методы базируются на стохастическом градиентном спуске (SGD), а также на его адаптивных модификациях (Adam, AdamW).
|
||||||
|
|
||||||
|
**SGD** (Stochastic Gradient Descent) — стохастический градиентный спуск:
|
||||||
|
W_{t+1} = W_t - \eta \nabla_W L(W_t)
|
||||||
|
Здесь \eta — шаг обучения, \nabla_W — градиент по параметрам. SGD позволяет случайно выбирать подмножество обучающих данных для каждой итерации, что ускоряет процесс и уменьшает избыточную корреляцию между примерами.
|
||||||
|
|
||||||
|
**Adam** (Adaptive Moment Estimation) — адаптивный алгоритм, который использует скользящую среднюю не только градиентов, но и их квадратов:
|
||||||
|
m_t = \beta_1 m_{t-1} + (1-\beta_1) \nabla_W L(W_t)
|
||||||
|
v_t = \beta_2 v_{t-1} + (1-\beta_2) (\nabla_W L(W_t))^2
|
||||||
|
W_{t+1} = W_t - \eta m_t/(\sqrt{v_t}+\epsilon)
|
||||||
|
Где \beta_1, \beta_2 — коэффициенты экспоненциального сглаживания.
|
||||||
|
|
||||||
|
**AdamW** — модификация Adam, в которой weight decay (имплицитная L2-регуляризация) вводится корректно, отдельно от шага градиента, что улучшает обобщающую способность моделей:
|
||||||
|
W_{t+1} = W_t - \eta [ m_t/(\sqrt{v_t}+\epsilon) + \lambda W_t ]
|
||||||
|
Где \lambda — коэффициент weight decay.
|
||||||
|
|
||||||
|
Детальное описание: https://arxiv.org/abs/1711.05101
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> optimizer = get_optimizer(model, lr=3e-4, weight_decay=0.01, optimizer_type="adamw")
|
||||||
|
>>> for batch in dataloader:
|
||||||
|
... loss = model(batch)
|
||||||
|
... loss.backward()
|
||||||
|
... optimizer.step()
|
||||||
|
... optimizer.zero_grad()
|
||||||
|
|
||||||
|
"""
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
|
|
||||||
|
|
||||||
def get_optimizer(model, lr=3e-4, weight_decay=0.01, optimizer_type="adamw"):
|
def get_optimizer(model, lr=3e-4, weight_decay=0.01, optimizer_type="adamw"):
|
||||||
"""
|
"""
|
||||||
Возвращает оптимизатор для обучения модели.
|
Фабричная функция для создания оптимизатора PyTorch по выбранному типу.
|
||||||
|
|
||||||
|
Параметры
|
||||||
|
---------
|
||||||
|
model : torch.nn.Module
|
||||||
|
Модель, параметры которой требуется оптимизировать.
|
||||||
|
lr : float, по умолчанию 3e-4
|
||||||
|
Шаг обучения (learning rate).
|
||||||
|
weight_decay : float, по умолчанию 0.01
|
||||||
|
Коэффициент weight decay (L2-регуляризации).
|
||||||
|
optimizer_type : str, по умолчанию 'adamw'
|
||||||
|
Тип оптимизатора: 'adamw', 'adam' или 'sgd'.
|
||||||
|
|
||||||
|
Возвращаемое значение
|
||||||
|
---------------------
|
||||||
|
torch.optim.Optimizer
|
||||||
|
Объект-оптимизатор, готовый к использованию.
|
||||||
|
|
||||||
|
Исключения
|
||||||
|
----------
|
||||||
|
ValueError: Если передан неизвестный тип оптимизатора.
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> optimizer = get_optimizer(model, lr=1e-3, optimizer_type='sgd')
|
||||||
"""
|
"""
|
||||||
if optimizer_type.lower() == "adamw":
|
if optimizer_type.lower() == "adamw":
|
||||||
return optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
|
return optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
|
||||||
|
|||||||
@@ -1,18 +1,66 @@
|
|||||||
from torch.optim.lr_scheduler import LambdaLR
|
"""
|
||||||
|
Модуль для управления динамикой шага обучения (learning rate scheduling) при обучении нейронных сетей.
|
||||||
|
|
||||||
|
Теоретическое обоснование:
|
||||||
|
--------------------------
|
||||||
|
Плавная динамика шага обучения существенно влияет на сходимость и итоговое качество моделей. Введение этапа "разогрева" (warmup) — техники, при которой шаг обучения начинается с нуля и постепенно увеличивается до целевого значения, снижает вероятность неустойчивых градиентов на старте обучения. Подобная стратегия показала свою эффективность для крупных нейронных сетей, особенно в трансформерах (Vaswani et al, 2017, https://arxiv.org/abs/1706.03762).
|
||||||
|
|
||||||
|
Линейный scheduler с warmup задаёт динамику learning rate по формуле:
|
||||||
|
- если current_step < num_warmup_steps:
|
||||||
|
lr = lr_init * (current_step / num_warmup_steps)
|
||||||
|
- иначе:
|
||||||
|
lr = lr_init * max(0, (num_training_steps - current_step) / (num_training_steps - num_warmup_steps))
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> optimizer = get_optimizer(model, lr=3e-4)
|
||||||
|
>>> scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=10000)
|
||||||
|
>>> for step in range(num_training_steps):
|
||||||
|
... optimizer.step()
|
||||||
|
... scheduler.step()
|
||||||
|
"""
|
||||||
|
|
||||||
|
from torch.optim.lr_scheduler import LambdaLR
|
||||||
|
|
||||||
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
|
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
|
||||||
"""
|
"""
|
||||||
Линейный планировщик обучения с warmup.
|
Создаёт линейный планировщик изменения шага обучения (learning rate) с этапом warmup для оптимизатора PyTorch.
|
||||||
|
|
||||||
|
Аргументы
|
||||||
|
---------
|
||||||
|
optimizer : torch.optim.Optimizer
|
||||||
|
Оптимизатор, для которого применяется scheduler.
|
||||||
|
num_warmup_steps : int
|
||||||
|
Количество шагов разогрева (warmup) — начиная с нулевого шага и плавного увеличения lr до номинального значения.
|
||||||
|
num_training_steps : int
|
||||||
|
Общее количество шагов (эпох/итераций) обучения модели.
|
||||||
|
|
||||||
|
Возвращаемое значение
|
||||||
|
---------------------
|
||||||
|
torch.optim.lr_scheduler.LambdaLR
|
||||||
|
Планировщик lr, который следует вызывать после каждого optimizer.step() во время обучения.
|
||||||
|
|
||||||
|
Теоретическая справка
|
||||||
|
---------------------
|
||||||
|
Такой scheduler позволяет повысить стабильность и устойчивость обучения крупных моделей (особенно трансформеров), предотвращая резкие скачки градиентов в начале.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> optimizer = get_optimizer(model, lr=3e-4)
|
||||||
|
>>> scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=10000)
|
||||||
|
>>> for step in range(num_training_steps):
|
||||||
|
... optimizer.step()
|
||||||
|
... scheduler.step()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def lr_lambda(current_step):
|
def lr_lambda(current_step):
|
||||||
|
# Линейный рост lr на этапе разогрева
|
||||||
if current_step < num_warmup_steps:
|
if current_step < num_warmup_steps:
|
||||||
return float(current_step) / float(max(1, num_warmup_steps))
|
return float(current_step) / float(max(1, num_warmup_steps))
|
||||||
|
# Линейное затухание lr после разогрева
|
||||||
return max(
|
return max(
|
||||||
0.0,
|
0.0,
|
||||||
float(num_training_steps - current_step)
|
float(num_training_steps - current_step)
|
||||||
/ float(max(1, num_training_steps - num_warmup_steps)),
|
/ float(max(1, num_training_steps - num_warmup_steps)),
|
||||||
)
|
)
|
||||||
|
|
||||||
return LambdaLR(optimizer, lr_lambda)
|
return LambdaLR(optimizer, lr_lambda)
|
||||||
|
|||||||
@@ -1,3 +1,22 @@
|
|||||||
|
"""
|
||||||
|
Модуль для организации процесса обучения больших языковых моделей (LLM).
|
||||||
|
|
||||||
|
Научное и техническое обоснование
|
||||||
|
----------------------------------
|
||||||
|
Эффективное обучение современных трансформеров (GPT, LLaMA, Mistral и др.) опирается на принципы языкового моделирования (Language Modeling):
|
||||||
|
- Предсказание вероятности следующего токена на основе предыдущих.
|
||||||
|
- Использование функции потерь кросс-энтропии (cross-entropy) с маскированием паддингов.
|
||||||
|
- Циклы обратного распространения ошибки (backpropagation), оптимизационные алгоритмы (например, AdamW), управление шагом обучения (scheduler с warmup), обрезка градиентов (grad clipping).
|
||||||
|
|
||||||
|
Реализация объединяет лучшие практики обучения LLM, универсальный API к моделям, датасетам, оптимизаторам и lr-схемам.
|
||||||
|
|
||||||
|
Подробнее: Vaswani et al. "Attention is All You Need" (2017), Radford et al. "Language Models are Unsupervised Multitask Learners" (2019)
|
||||||
|
|
||||||
|
Пример использования
|
||||||
|
--------------------
|
||||||
|
>>> trainer = Trainer(model, train_dataset, val_dataset, lr=3e-4, batch_size=8, num_epochs=3, warmup_steps=100)
|
||||||
|
>>> trainer.train()
|
||||||
|
"""
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
@@ -8,7 +27,33 @@ from llm.training.scheduler import get_linear_schedule_with_warmup
|
|||||||
|
|
||||||
class Trainer:
|
class Trainer:
|
||||||
"""
|
"""
|
||||||
Универсальный класс обучения LLM (GPT, LLaMA, Mistral и т.д.)
|
Универсальный и расширяемый класс для обучения больших языковых моделей (Large Language Models, LLM).
|
||||||
|
|
||||||
|
Поддерживаются архитектуры семейства GPT, LLaMA, Mistral и другие автогрессивные модели.
|
||||||
|
Объединяет:
|
||||||
|
- Тренировку по задаче языкового моделирования (Causal LM)
|
||||||
|
- Cross-entropy loss с автоматическим сдвигом логитов/меток
|
||||||
|
- Поддержку Grad Clipping, Scheduler, Validation
|
||||||
|
- Унифицированный даталоадер, автоматический выбор устройства (CPU/GPU)
|
||||||
|
|
||||||
|
Атрибуты
|
||||||
|
--------
|
||||||
|
model : torch.nn.Module
|
||||||
|
Модель для обучения языковому моделированию
|
||||||
|
train_loader : torch.utils.data.DataLoader
|
||||||
|
Даталоадер обучающего набора
|
||||||
|
val_loader : torch.utils.data.DataLoader или None
|
||||||
|
Даталоадер валидационного набора (если задан)
|
||||||
|
optimizer : torch.optim.Optimizer
|
||||||
|
Оптимизатор параметров модели
|
||||||
|
scheduler : torch.optim.lr_scheduler.LambdaLR
|
||||||
|
Планировщик learning rate (инициализируется в train)
|
||||||
|
device : torch.device
|
||||||
|
Устройство (CPU или CUDA), куда помещается модель
|
||||||
|
num_epochs : int
|
||||||
|
Количество эпох обучения
|
||||||
|
warmup_steps : int
|
||||||
|
Число шагов warmup для scheduler
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -21,6 +66,26 @@ class Trainer:
|
|||||||
num_epochs=3,
|
num_epochs=3,
|
||||||
warmup_steps=100,
|
warmup_steps=100,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Инициализация обучающего класса Trainer.
|
||||||
|
|
||||||
|
Аргументы
|
||||||
|
---------
|
||||||
|
model : torch.nn.Module
|
||||||
|
Модель для обучения (например, GPT, LLaMA, Mistral).
|
||||||
|
train_dataset : torch.utils.data.Dataset
|
||||||
|
Обучающий датасет с полями input_ids и labels.
|
||||||
|
val_dataset : torch.utils.data.Dataset, optional
|
||||||
|
Валидационный датасет для контроля качества обучения.
|
||||||
|
lr : float, default=3e-4
|
||||||
|
Начальный шаг обучения.
|
||||||
|
batch_size : int, default=8
|
||||||
|
Размер обучающего мини-батча.
|
||||||
|
num_epochs : int, default=3
|
||||||
|
Количество эпох обучения.
|
||||||
|
warmup_steps : int, default=100
|
||||||
|
Количество шагов разогрева (warmup) learning rate.
|
||||||
|
"""
|
||||||
self.model = model
|
self.model = model
|
||||||
self.train_loader = DataLoader(
|
self.train_loader = DataLoader(
|
||||||
train_dataset, batch_size=batch_size, shuffle=True
|
train_dataset, batch_size=batch_size, shuffle=True
|
||||||
@@ -37,26 +102,52 @@ class Trainer:
|
|||||||
|
|
||||||
def compute_lm_loss(self, logits, labels):
|
def compute_lm_loss(self, logits, labels):
|
||||||
"""
|
"""
|
||||||
Вычисляет loss для языкового моделирования.
|
Вычисляет функцию потерь (loss) для задачи автогрессивного языкового моделирования.
|
||||||
Сдвигает логиты и метки для предсказания следующего токена.
|
|
||||||
|
Производит сдвиг логитов и меток: предсказания делаются для следующего токена.
|
||||||
|
Используется кросс-энтропия (CrossEntropyLoss), что соответствует максимизации логарифма правдоподобия:
|
||||||
|
L = -log P(w_{t+1} | w_1,...,w_t)
|
||||||
|
|
||||||
|
Аргументы
|
||||||
|
---------
|
||||||
|
logits : torch.Tensor
|
||||||
|
Логиты модели: (batch_size, seq_len, vocab_size)
|
||||||
|
labels : torch.Tensor
|
||||||
|
Правильные метки: (batch_size, seq_len)
|
||||||
|
Возвращаемое значение
|
||||||
|
---------------------
|
||||||
|
loss : torch.Tensor
|
||||||
|
Средний loss по batch.
|
||||||
"""
|
"""
|
||||||
# Сдвигаем логиты и метки для языкового моделирования
|
# Сдвигаем логиты и метки для языкового моделирования (автогрессия)
|
||||||
shift_logits = logits[..., :-1, :].contiguous()
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
|
||||||
# Вычисляем cross-entropy loss
|
# CrossEntropyLoss (игнорируем паддинги: ignore_index=-100)
|
||||||
loss = F.cross_entropy(
|
loss = F.cross_entropy(
|
||||||
shift_logits.view(-1, shift_logits.size(-1)),
|
shift_logits.view(-1, shift_logits.size(-1)),
|
||||||
shift_labels.view(-1),
|
shift_labels.view(-1),
|
||||||
ignore_index=-100, # Игнорируем padding tokens
|
ignore_index=-100, # Padding токены не участвуют в loss
|
||||||
)
|
)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
|
"""
|
||||||
|
Запускает процесс обучения модели по заданному числу эпох.
|
||||||
|
|
||||||
|
В процессе:
|
||||||
|
- Применяет optimizer, scheduler с warmup и decay, grad clipping (обрезка градиентов)
|
||||||
|
- Вызывает функцию потерь для языкового моделирования
|
||||||
|
- Показывает динамику процесса (tqdm)
|
||||||
|
- После каждой эпохи возможно проведение валидации
|
||||||
|
|
||||||
|
Параметры задаются на этапе инициализации Trainer.
|
||||||
|
"""
|
||||||
total_steps = len(self.train_loader) * self.num_epochs
|
total_steps = len(self.train_loader) * self.num_epochs
|
||||||
self.scheduler = get_linear_schedule_with_warmup(
|
self.scheduler = get_linear_schedule_with_warmup(
|
||||||
self.optimizer, self.warmup_steps, total_steps
|
self.optimizer, self.warmup_steps, total_steps
|
||||||
)
|
)
|
||||||
|
self.loss_history = [] # добавлено: лог средних потерь
|
||||||
|
|
||||||
for epoch in range(self.num_epochs):
|
for epoch in range(self.num_epochs):
|
||||||
self.model.train()
|
self.model.train()
|
||||||
@@ -71,14 +162,14 @@ class Trainer:
|
|||||||
input_ids = batch["input_ids"].to(self.device)
|
input_ids = batch["input_ids"].to(self.device)
|
||||||
labels = batch["labels"].to(self.device)
|
labels = batch["labels"].to(self.device)
|
||||||
|
|
||||||
# Универсально обрабатываем выход (tuple/logits)
|
# Универсально обрабатываем выходы модели: tuple или просто tensor (logits)
|
||||||
outputs = self.model(input_ids)
|
outputs = self.model(input_ids)
|
||||||
if isinstance(outputs, tuple):
|
if isinstance(outputs, tuple):
|
||||||
logits = outputs[0]
|
logits = outputs[0]
|
||||||
else:
|
else:
|
||||||
logits = outputs
|
logits = outputs
|
||||||
|
|
||||||
# Trainer вычисляет loss
|
# Вычисляем loss автогрессивной LM-задачи
|
||||||
loss = self.compute_lm_loss(logits, labels)
|
loss = self.compute_lm_loss(logits, labels)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
@@ -90,12 +181,19 @@ class Trainer:
|
|||||||
progress_bar.set_postfix(loss=loss.item())
|
progress_bar.set_postfix(loss=loss.item())
|
||||||
|
|
||||||
avg_loss = total_loss / len(self.train_loader)
|
avg_loss = total_loss / len(self.train_loader)
|
||||||
|
self.loss_history.append(avg_loss) # добавлено: запоминаем loss
|
||||||
print(f"Epoch {epoch+1} finished — avg loss: {avg_loss:.4f}")
|
print(f"Epoch {epoch+1} finished — avg loss: {avg_loss:.4f}")
|
||||||
|
|
||||||
if self.val_loader:
|
if self.val_loader:
|
||||||
self.evaluate()
|
self.evaluate()
|
||||||
|
|
||||||
def evaluate(self):
|
def evaluate(self):
|
||||||
|
"""
|
||||||
|
Оценивает модель на валидационном датасете (если задан).
|
||||||
|
|
||||||
|
В режиме eval() модели отключается dropout и все стохастические элементы.
|
||||||
|
Возвращает среднее значение функции потерь (loss) по всему validation set.
|
||||||
|
"""
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
|
|
||||||
|
|||||||
65
llm/tests/core/test_cached_decoder.py
Normal file
65
llm/tests/core/test_cached_decoder.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.core.cached_decoder import CachedDecoder
|
||||||
|
from llm.core.feed_forward import FeedForward
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def decoder_config():
|
||||||
|
return dict(
|
||||||
|
num_heads=4,
|
||||||
|
emb_size=32,
|
||||||
|
head_size=8,
|
||||||
|
feed_forward_layer=FeedForward(emb_size=32, dropout=0.1, activation="gelu"),
|
||||||
|
max_seq_len=64,
|
||||||
|
dropout=0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_cached_decoder_init(decoder_config):
|
||||||
|
model = CachedDecoder(**decoder_config)
|
||||||
|
assert model is not None
|
||||||
|
# Main attention block is usually stored as _heads or _attention (which itself includes _q _k _v)
|
||||||
|
assert hasattr(model, '_heads') or hasattr(model, '_attention')
|
||||||
|
assert hasattr(model, '_ff') or hasattr(model, 'feed_forward_layer')
|
||||||
|
|
||||||
|
def test_cached_decoder_forward_shape(decoder_config):
|
||||||
|
model = CachedDecoder(**decoder_config)
|
||||||
|
batch, seq_len, emb_size = 3, 10, decoder_config['emb_size']
|
||||||
|
x = torch.randn(batch, seq_len, emb_size)
|
||||||
|
output, cache = model(x, use_cache=True)
|
||||||
|
assert output.shape == (batch, seq_len, emb_size)
|
||||||
|
assert cache is not None
|
||||||
|
|
||||||
|
def test_cached_decoder_forward_no_cache(decoder_config):
|
||||||
|
model = CachedDecoder(**decoder_config)
|
||||||
|
batch, seq_len, emb_size = 2, 12, decoder_config['emb_size']
|
||||||
|
x = torch.randn(batch, seq_len, emb_size)
|
||||||
|
output, cache = model(x, use_cache=False)
|
||||||
|
assert output.shape == (batch, seq_len, emb_size)
|
||||||
|
assert cache is None
|
||||||
|
|
||||||
|
def test_cached_decoder_error_on_long_seq(decoder_config):
|
||||||
|
model = CachedDecoder(**decoder_config)
|
||||||
|
batch, seq_len, emb_size = 1, decoder_config['max_seq_len'] + 1, decoder_config['emb_size']
|
||||||
|
x = torch.randn(batch, seq_len, emb_size)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
model(x)
|
||||||
|
|
||||||
|
def test_cached_decoder_backward(decoder_config):
|
||||||
|
model = CachedDecoder(**decoder_config)
|
||||||
|
batch, seq_len, emb_size = 2, 7, decoder_config['emb_size']
|
||||||
|
x = torch.randn(batch, seq_len, emb_size, requires_grad=True)
|
||||||
|
output, cache = model(x)
|
||||||
|
loss = output.sum()
|
||||||
|
loss.backward()
|
||||||
|
assert x.grad is not None
|
||||||
|
|
||||||
|
def test_cached_decoder_kv_cache_chain(decoder_config):
|
||||||
|
model = CachedDecoder(**decoder_config)
|
||||||
|
batch, seq_len, emb_size = 1, 4, decoder_config['emb_size']
|
||||||
|
x = torch.randn(batch, seq_len, emb_size)
|
||||||
|
# Первый проход — кэша нет
|
||||||
|
_, cache = model(x, use_cache=True)
|
||||||
|
# Второй проход — передаём кэш, добавляем еще токен:
|
||||||
|
next_x = torch.randn(batch, 1, emb_size)
|
||||||
|
_, cache2 = model(next_x, use_cache=True, cache=cache)
|
||||||
|
assert cache2 is not None
|
||||||
46
llm/tests/core/test_gelu.py
Normal file
46
llm/tests/core/test_gelu.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.core.gelu import GELU
|
||||||
|
|
||||||
|
def test_gelu_shapes_and_dtype():
|
||||||
|
gelu = GELU()
|
||||||
|
x = torch.randn(4, 16, 8)
|
||||||
|
y = gelu(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
|
assert y.dtype == x.dtype
|
||||||
|
|
||||||
|
def test_gelu_known_values():
|
||||||
|
gelu = GELU()
|
||||||
|
x = torch.tensor([-3.0, 0.0, 3.0])
|
||||||
|
y = gelu(x)
|
||||||
|
# Сравнение с PyTorch F.gelu (которая использует точный алгоритм)
|
||||||
|
y_ref = torch.nn.functional.gelu(x)
|
||||||
|
diff = (y - y_ref).abs().max().item()
|
||||||
|
assert diff < 5e-3, f"Max difference {diff} exceeds threshold"
|
||||||
|
|
||||||
|
def test_gelu_is_smooth_and_monotonic():
|
||||||
|
gelu = GELU()
|
||||||
|
x = torch.linspace(-5, 5, 100)
|
||||||
|
y = gelu(x)
|
||||||
|
dy = y[1:] - y[:-1]
|
||||||
|
# Проверяем, что функция GELU хотя бы локально монотонна на большинстве промежутков
|
||||||
|
assert (dy.mean() > 0 or dy.mean() < 0)
|
||||||
|
|
||||||
|
def test_gelu_gradients():
|
||||||
|
gelu = GELU()
|
||||||
|
x = torch.randn(3, 5, requires_grad=True)
|
||||||
|
y = gelu(x)
|
||||||
|
loss = y.sum()
|
||||||
|
loss.backward()
|
||||||
|
assert x.grad is not None
|
||||||
|
assert x.grad.shape == x.shape
|
||||||
|
|
||||||
|
def test_gelu_large_vs_small():
|
||||||
|
gelu = GELU()
|
||||||
|
x_pos = torch.tensor([100.0])
|
||||||
|
x_neg = torch.tensor([-100.0])
|
||||||
|
y_pos = gelu(x_pos)
|
||||||
|
y_neg = gelu(x_neg)
|
||||||
|
# Для больших положительных GELU(x) ~ x, для больших отрицательных ~0
|
||||||
|
assert torch.allclose(y_pos, x_pos, rtol=1e-4, atol=1e-4)
|
||||||
|
assert torch.allclose(y_neg, torch.zeros_like(x_neg), rtol=1e-4, atol=1e-4)
|
||||||
85
llm/tests/core/test_group_query_attention.py
Normal file
85
llm/tests/core/test_group_query_attention.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
# llm/tests/core/test_group_query_attention.py
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.core.group_query_attention import GroupedQueryAttention
|
||||||
|
from llm.core.rope import RoPE
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def params():
|
||||||
|
return {
|
||||||
|
'num_q_heads': 4,
|
||||||
|
'num_kv_heads': 2,
|
||||||
|
'emb_size': 16,
|
||||||
|
'head_size': 4,
|
||||||
|
'max_seq_len': 32,
|
||||||
|
'window_size': 8,
|
||||||
|
'dropout': 0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_initialization(params):
|
||||||
|
attn = GroupedQueryAttention(**params)
|
||||||
|
assert isinstance(attn, GroupedQueryAttention)
|
||||||
|
|
||||||
|
def test_forward_shape(params):
|
||||||
|
batch, seq = 2, 10
|
||||||
|
x = torch.randn(batch, seq, params['emb_size'])
|
||||||
|
attn = GroupedQueryAttention(**params)
|
||||||
|
y, cache = attn(x)
|
||||||
|
assert y.shape == (batch, seq, params['emb_size'])
|
||||||
|
assert cache is not None
|
||||||
|
assert isinstance(y, torch.Tensor)
|
||||||
|
|
||||||
|
def test_forward_shape_with_mask(params):
|
||||||
|
batch, seq = 2, 10
|
||||||
|
x = torch.randn(batch, seq, params['emb_size'])
|
||||||
|
mask = torch.tril(torch.ones(seq, seq)).bool()
|
||||||
|
attn = GroupedQueryAttention(**params)
|
||||||
|
y, _ = attn(x, mask=mask)
|
||||||
|
assert y.shape == (batch, seq, params['emb_size'])
|
||||||
|
|
||||||
|
def test_kv_repetition(params):
|
||||||
|
batch, seq = 1, 3
|
||||||
|
attn = GroupedQueryAttention(**params)
|
||||||
|
kv = torch.randn(batch, params['num_kv_heads'], seq, params['head_size'])
|
||||||
|
rep = attn._repeat_kv_heads(kv, params['num_q_heads'], params['num_kv_heads'])
|
||||||
|
assert rep.shape == (batch, params['num_q_heads'], seq, params['head_size'])
|
||||||
|
|
||||||
|
def test_window_mask(params):
|
||||||
|
attn = GroupedQueryAttention(**params)
|
||||||
|
mask = attn._create_sliding_window_mask(8, 3)
|
||||||
|
assert mask.shape == (8, 8)
|
||||||
|
# Проверим булеву маску окна в позиции 4
|
||||||
|
expected = torch.tensor([True, True, True, True, False, False])
|
||||||
|
assert torch.equal(mask[4, 1:7], expected)
|
||||||
|
|
||||||
|
def test_forward_with_rope(params):
|
||||||
|
batch, seq = 2, 12
|
||||||
|
x = torch.randn(batch, seq, params['emb_size'])
|
||||||
|
rope = RoPE(head_size=params['head_size'], max_seq_len=params['max_seq_len'])
|
||||||
|
params2 = params.copy()
|
||||||
|
params2['rope'] = rope
|
||||||
|
attn = GroupedQueryAttention(**params2)
|
||||||
|
y, _ = attn(x)
|
||||||
|
assert y.shape == (batch, seq, params['emb_size'])
|
||||||
|
|
||||||
|
def test_cache_usage(params):
|
||||||
|
batch, seq = 1, 5
|
||||||
|
x = torch.randn(batch, seq, params['emb_size'])
|
||||||
|
attn = GroupedQueryAttention(**params)
|
||||||
|
# Первый проход - получаем кэш
|
||||||
|
_, cache = attn(x)
|
||||||
|
# Второй проход с кэшем (имитируем автокомплит seq_len=1)
|
||||||
|
x2 = torch.randn(batch, 1, params['emb_size'])
|
||||||
|
y2, cache2 = attn(x2, cache=cache)
|
||||||
|
assert cache2 is not None
|
||||||
|
assert y2.shape == (batch, 1, params['emb_size'])
|
||||||
|
|
||||||
|
def test_gradient_backward(params):
|
||||||
|
batch, seq = 2, 6
|
||||||
|
x = torch.randn(batch, seq, params['emb_size'], requires_grad=True)
|
||||||
|
attn = GroupedQueryAttention(**params)
|
||||||
|
y, _ = attn(x)
|
||||||
|
y.sum().backward()
|
||||||
|
for param in attn.parameters():
|
||||||
|
assert param.grad is not None
|
||||||
66
llm/tests/core/test_mistral_decoder.py
Normal file
66
llm/tests/core/test_mistral_decoder.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.core.mistral_decoder import MistralDecoder
|
||||||
|
from llm.core.rope import RoPE
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def decoder_config():
|
||||||
|
# Current MistralDecoder is a single block (not a stack).
|
||||||
|
return dict(
|
||||||
|
num_q_heads=4,
|
||||||
|
num_kv_heads=2,
|
||||||
|
emb_size=32,
|
||||||
|
head_size=8,
|
||||||
|
max_seq_len=128,
|
||||||
|
window_size=16,
|
||||||
|
rope=RoPE(head_size=8, max_seq_len=128),
|
||||||
|
dropout=0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_mistral_decoder_init(decoder_config):
|
||||||
|
model = MistralDecoder(**decoder_config)
|
||||||
|
assert model is not None
|
||||||
|
|
||||||
|
def test_mistral_decoder_forward_shapes(decoder_config):
|
||||||
|
model = MistralDecoder(**decoder_config)
|
||||||
|
batch, seq_len, emb_size = 2, 10, decoder_config['emb_size']
|
||||||
|
x = torch.randn(batch, seq_len, emb_size)
|
||||||
|
output, cache = model(x, use_cache=True)
|
||||||
|
assert output.shape == (batch, seq_len, emb_size)
|
||||||
|
assert cache is not None
|
||||||
|
|
||||||
|
def test_mistral_decoder_forward_no_cache(decoder_config):
|
||||||
|
model = MistralDecoder(**decoder_config)
|
||||||
|
batch, seq_len, emb_size = 2, 10, decoder_config['emb_size']
|
||||||
|
x = torch.randn(batch, seq_len, emb_size)
|
||||||
|
output, cache = model(x, use_cache=False)
|
||||||
|
assert output.shape == (batch, seq_len, emb_size)
|
||||||
|
assert cache is None
|
||||||
|
|
||||||
|
def test_mistral_decoder_cache_shapes(decoder_config):
|
||||||
|
model = MistralDecoder(**decoder_config)
|
||||||
|
batch, seq_len, emb_size = 2, 8, decoder_config['emb_size']
|
||||||
|
x = torch.randn(batch, seq_len, emb_size)
|
||||||
|
# Первый проход — без кэша
|
||||||
|
_, cache = model(x, use_cache=True)
|
||||||
|
# Второй проход — заполняем кэш
|
||||||
|
x_next = torch.randn(batch, 1, emb_size)
|
||||||
|
_, cache2 = model(x_next, use_cache=True, cache=cache)
|
||||||
|
# Можно проверить, что кэш не None и корректной структуры:
|
||||||
|
assert cache2 is not None
|
||||||
|
|
||||||
|
def test_mistral_decoder_shape_error(decoder_config):
|
||||||
|
model = MistralDecoder(**decoder_config)
|
||||||
|
batch, seq_len, emb_size = 2, decoder_config['max_seq_len'] + 1, decoder_config['emb_size']
|
||||||
|
x = torch.randn(batch, seq_len, emb_size)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
model(x)
|
||||||
|
|
||||||
|
def test_mistral_decoder_backward(decoder_config):
|
||||||
|
model = MistralDecoder(**decoder_config)
|
||||||
|
batch, seq_len, emb_size = 2, 10, decoder_config['emb_size']
|
||||||
|
x = torch.randn(batch, seq_len, emb_size, requires_grad=True)
|
||||||
|
output, _ = model(x, use_cache=False)
|
||||||
|
loss = output.sum()
|
||||||
|
loss.backward()
|
||||||
|
assert x.grad is not None
|
||||||
@@ -19,7 +19,7 @@ class TestMultiHeadAttention:
|
|||||||
assert attention is not None
|
assert attention is not None
|
||||||
|
|
||||||
# Check internal attributes
|
# Check internal attributes
|
||||||
assert len(attention._heads) == num_heads
|
assert attention._num_heads == num_heads
|
||||||
assert attention._layer.in_features == embed_dim
|
assert attention._layer.in_features == embed_dim
|
||||||
assert attention._layer.out_features == embed_dim
|
assert attention._layer.out_features == embed_dim
|
||||||
|
|
||||||
@@ -102,8 +102,10 @@ class TestMultiHeadAttention:
|
|||||||
|
|
||||||
# Check that gradients are computed for learnable parameters
|
# Check that gradients are computed for learnable parameters
|
||||||
assert attention._layer.weight.grad is not None
|
assert attention._layer.weight.grad is not None
|
||||||
if len(attention._heads) > 0:
|
# Проверяем, что также у градиентов весов q/k/v есть значения
|
||||||
assert attention._heads[0]._q.weight.grad is not None
|
assert attention._q.weight.grad is not None
|
||||||
|
assert attention._k.weight.grad is not None
|
||||||
|
assert attention._v.weight.grad is not None
|
||||||
|
|
||||||
def test_device_consistency(self, embed_dim, num_heads, random_embeddings, device):
|
def test_device_consistency(self, embed_dim, num_heads, random_embeddings, device):
|
||||||
"""Test that MultiHeadAttention works on correct device."""
|
"""Test that MultiHeadAttention works on correct device."""
|
||||||
|
|||||||
47
llm/tests/core/test_rms_norm.py
Normal file
47
llm/tests/core/test_rms_norm.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.core.rms_norm import RMSNorm
|
||||||
|
|
||||||
|
def test_rmsnorm_shape_preservation():
|
||||||
|
norm = RMSNorm(64)
|
||||||
|
x = torch.randn(3, 5, 64)
|
||||||
|
y = norm(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
|
|
||||||
|
def test_rmsnorm_dtype_and_device():
|
||||||
|
norm = RMSNorm(32)
|
||||||
|
x = torch.randn(8, 32, device='cpu', dtype=torch.float64)
|
||||||
|
y = norm(x)
|
||||||
|
assert y.dtype == torch.float64
|
||||||
|
assert y.device == x.device
|
||||||
|
|
||||||
|
def test_rmsnorm_mean_no_shift():
|
||||||
|
norm = RMSNorm(32)
|
||||||
|
x = torch.randn(3, 128, 32)
|
||||||
|
y = norm(x)
|
||||||
|
rms = torch.sqrt((y ** 2).mean(dim=-1))
|
||||||
|
w_mean = norm._w.mean().item()
|
||||||
|
assert torch.allclose(rms.mean(), torch.tensor(w_mean), rtol=0.2, atol=0.2)
|
||||||
|
|
||||||
|
def test_rmsnorm_backward():
|
||||||
|
norm = RMSNorm(16)
|
||||||
|
x = torch.randn(2, 15, 16, requires_grad=True)
|
||||||
|
y = norm(x)
|
||||||
|
loss = y.sum()
|
||||||
|
loss.backward()
|
||||||
|
assert x.grad is not None
|
||||||
|
assert norm._w.grad is not None
|
||||||
|
|
||||||
|
def test_rmsnorm_fp16():
|
||||||
|
norm = RMSNorm(8).half()
|
||||||
|
x = torch.randn(2, 6, 8).half()
|
||||||
|
y = norm(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
|
assert y.dtype == torch.float16
|
||||||
|
|
||||||
|
def test_rmsnorm_large_eps_stability():
|
||||||
|
norm = RMSNorm(16, eps=1)
|
||||||
|
x = torch.zeros(2, 5, 16)
|
||||||
|
y = norm(x)
|
||||||
|
assert not torch.isnan(y).any()
|
||||||
|
assert not torch.isinf(y).any()
|
||||||
55
llm/tests/core/test_rope.py
Normal file
55
llm/tests/core/test_rope.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.core.rope import RoPE
|
||||||
|
|
||||||
|
def test_rope_shapes_and_dtype():
|
||||||
|
rope = RoPE(head_size=8, max_seq_len=32)
|
||||||
|
x = torch.randn(2, 4, 16, 8) # [batch, num_heads, seq_len, head_size]
|
||||||
|
y = rope(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
|
assert y.dtype == x.dtype
|
||||||
|
|
||||||
|
def test_rope_raises_on_bad_ndim():
|
||||||
|
rope = RoPE(head_size=8, max_seq_len=16)
|
||||||
|
x = torch.randn(2, 16, 8) # [batch, seq_len, head_size] (3D)
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
_ = rope(x)
|
||||||
|
|
||||||
|
def test_rope_preserves_norm():
|
||||||
|
rope = RoPE(head_size=8, max_seq_len=16)
|
||||||
|
x = torch.randn(2, 3, 7, 8)
|
||||||
|
x_norm = x.norm(dim=-1)
|
||||||
|
y = rope(x)
|
||||||
|
y_norm = y.norm(dim=-1)
|
||||||
|
# Нормы могут немного отличаться из-за float, сравниваем с допуском
|
||||||
|
assert torch.allclose(x_norm, y_norm, rtol=1e-5, atol=1e-7)
|
||||||
|
|
||||||
|
def test_rope_backward_pass():
|
||||||
|
rope = RoPE(head_size=8, max_seq_len=16)
|
||||||
|
x = torch.randn(2, 2, 8, 8, requires_grad=True)
|
||||||
|
out = rope(x)
|
||||||
|
loss = out.sum()
|
||||||
|
loss.backward()
|
||||||
|
assert x.grad is not None
|
||||||
|
assert x.grad.shape == x.shape
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batch,num_heads,seq_len,head_size", [
|
||||||
|
(1, 1, 4, 8),
|
||||||
|
(2, 4, 16, 8),
|
||||||
|
(3, 2, 7, 8),
|
||||||
|
])
|
||||||
|
def test_rope_various_shapes(batch, num_heads, seq_len, head_size):
|
||||||
|
rope = RoPE(head_size=head_size, max_seq_len=32)
|
||||||
|
x = torch.randn(batch, num_heads, seq_len, head_size)
|
||||||
|
y = rope(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
|
|
||||||
|
def test_rope_start_pos():
|
||||||
|
rope = RoPE(head_size=8, max_seq_len=32)
|
||||||
|
x_full = torch.randn(1, 2, 8, 8)
|
||||||
|
# Сравниваем участок результата для разных start_pos
|
||||||
|
out1 = rope(x_full)
|
||||||
|
out2 = rope(x_full, start_pos=2)
|
||||||
|
assert not torch.allclose(out1, out2)
|
||||||
|
# Для одинакового start_pos и x должны совпадать
|
||||||
|
assert torch.allclose(rope(x_full, start_pos=1), rope(x_full, start_pos=1))
|
||||||
42
llm/tests/core/test_silu.py
Normal file
42
llm/tests/core/test_silu.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.core.silu import SiLU
|
||||||
|
|
||||||
|
def test_silu_shape_and_dtype():
|
||||||
|
silu = SiLU()
|
||||||
|
x = torch.randn(3, 10, 8)
|
||||||
|
y = silu(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
|
assert y.dtype == x.dtype
|
||||||
|
|
||||||
|
def test_silu_known_values():
|
||||||
|
silu = SiLU()
|
||||||
|
x = torch.tensor([-2.0, 0.0, 2.0])
|
||||||
|
y = silu(x)
|
||||||
|
# PyTorch эталон
|
||||||
|
y_ref = torch.nn.functional.silu(x)
|
||||||
|
assert torch.allclose(y, y_ref, atol=1e-6)
|
||||||
|
|
||||||
|
def test_silu_large_vs_small():
|
||||||
|
silu = SiLU()
|
||||||
|
x_pos = torch.tensor([100.0])
|
||||||
|
x_neg = torch.tensor([-100.0])
|
||||||
|
y_pos = silu(x_pos)
|
||||||
|
y_neg = silu(x_neg)
|
||||||
|
assert torch.allclose(y_pos, x_pos, rtol=1e-4, atol=1e-4) # SiLU(x) ~ x для больших x>0
|
||||||
|
assert torch.allclose(y_neg, torch.zeros_like(x_neg), rtol=1e-4, atol=1e-4) # SiLU(x) ~ 0 для x<0
|
||||||
|
|
||||||
|
def test_silu_gradients():
|
||||||
|
silu = SiLU()
|
||||||
|
x = torch.randn(4, 4, requires_grad=True)
|
||||||
|
y = silu(x)
|
||||||
|
loss = y.sum()
|
||||||
|
loss.backward()
|
||||||
|
assert x.grad is not None
|
||||||
|
assert x.grad.shape == x.shape
|
||||||
|
|
||||||
|
def test_silu_broadcast():
|
||||||
|
silu = SiLU()
|
||||||
|
x = torch.randn(3, 1, 16)
|
||||||
|
y = silu(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
39
llm/tests/core/test_swi_glu.py
Normal file
39
llm/tests/core/test_swi_glu.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.core.swi_glu import SwiGLU
|
||||||
|
|
||||||
|
def test_swiglu_shape_and_dtype():
|
||||||
|
swiglu = SwiGLU(emb_size=32, dropout=0.1)
|
||||||
|
x = torch.randn(4, 10, 32)
|
||||||
|
y = swiglu(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
|
assert y.dtype == x.dtype
|
||||||
|
|
||||||
|
def test_swiglu_forward_range():
|
||||||
|
swiglu = SwiGLU(emb_size=16, dropout=0.0)
|
||||||
|
x = torch.randn(3, 7, 16)
|
||||||
|
y = swiglu(x)
|
||||||
|
assert y.abs().max() < 20
|
||||||
|
|
||||||
|
def test_swiglu_gradients():
|
||||||
|
swiglu = SwiGLU(emb_size=8, dropout=0.0)
|
||||||
|
x = torch.randn(2, 5, 8, requires_grad=True)
|
||||||
|
out = swiglu(x)
|
||||||
|
loss = out.pow(2).sum()
|
||||||
|
loss.backward()
|
||||||
|
assert x.grad is not None
|
||||||
|
assert x.grad.shape == x.shape
|
||||||
|
|
||||||
|
def test_swiglu_fp16():
|
||||||
|
swiglu = SwiGLU(emb_size=16, dropout=0.0).half()
|
||||||
|
x = torch.randn(1, 8, 16).half()
|
||||||
|
y = swiglu(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
|
assert y.dtype == torch.float16
|
||||||
|
|
||||||
|
def test_swiglu_reproducibility():
|
||||||
|
swiglu = SwiGLU(emb_size=8, dropout=0.0)
|
||||||
|
x = torch.ones(2, 4, 8)
|
||||||
|
y1 = swiglu(x)
|
||||||
|
y2 = swiglu(x)
|
||||||
|
assert torch.allclose(y1, y2)
|
||||||
49
llm/tests/datasets/test_streaming_text_dataset.py
Normal file
49
llm/tests/datasets/test_streaming_text_dataset.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.datasets.streaming_text_dataset import StreamingTextDataset
|
||||||
|
|
||||||
|
class DummyTokenizer:
|
||||||
|
def __init__(self, vocab_size=100):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
def encode(self, text, **kwargs):
|
||||||
|
return [len(w) % self.vocab_size for w in text.strip().split()]
|
||||||
|
|
||||||
|
def test_streaming_textdataset_basic_shape():
|
||||||
|
texts = ["hello world", "big transformers are fun", "LLM test string"]
|
||||||
|
tokenizer = DummyTokenizer(50)
|
||||||
|
block_size = 7
|
||||||
|
ds = StreamingTextDataset(texts, tokenizer, block_size)
|
||||||
|
assert len(ds) == 3
|
||||||
|
for i in range(len(ds)):
|
||||||
|
item = ds[i]
|
||||||
|
assert isinstance(item, dict)
|
||||||
|
assert "input_ids" in item
|
||||||
|
assert item["input_ids"].shape == (block_size,)
|
||||||
|
assert "labels" in item
|
||||||
|
assert item["labels"].shape == (block_size,)
|
||||||
|
|
||||||
|
def test_streaming_textdataset_padding_and_truncation():
|
||||||
|
texts = ["short", "one two three four five six seven eight nine ten"]
|
||||||
|
tokenizer = DummyTokenizer(40)
|
||||||
|
block_size = 4
|
||||||
|
ds = StreamingTextDataset(texts, tokenizer, block_size)
|
||||||
|
# короткое предложение padded
|
||||||
|
assert (ds[0]["input_ids"].shape[0] == block_size)
|
||||||
|
# длинное предложение truncated
|
||||||
|
assert (ds[1]["input_ids"].shape[0] == block_size)
|
||||||
|
|
||||||
|
def test_streaming_textdataset_index_error():
|
||||||
|
texts = ["sample"]
|
||||||
|
tokenizer = DummyTokenizer(10)
|
||||||
|
ds = StreamingTextDataset(texts, tokenizer, block_size=5)
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
_ = ds[1]
|
||||||
|
|
||||||
|
def test_streaming_textdataset_content_matching():
|
||||||
|
texts = ["foo bar baz", "abc def"]
|
||||||
|
tokenizer = DummyTokenizer(99)
|
||||||
|
block_size = 5
|
||||||
|
ds = StreamingTextDataset(texts, tokenizer, block_size)
|
||||||
|
# Проверка, что input_ids и labels совпадают точно
|
||||||
|
for i in range(len(ds)):
|
||||||
|
assert torch.equal(ds[i]["input_ids"], ds[i]["labels"])
|
||||||
49
llm/tests/datasets/test_text_dataset.py
Normal file
49
llm/tests/datasets/test_text_dataset.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.datasets.text_dataset import TextDataset
|
||||||
|
|
||||||
|
class DummyTokenizer:
|
||||||
|
def __init__(self, vocab_size=100):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
def encode(self, text, **kwargs):
|
||||||
|
return [len(w) % self.vocab_size for w in text.strip().split()]
|
||||||
|
|
||||||
|
def test_textdataset_shape_and_basic():
|
||||||
|
texts = ["hello world", "this is a test", "Transformer model"]
|
||||||
|
tokenizer = DummyTokenizer(50)
|
||||||
|
block_size = 6
|
||||||
|
dataset = TextDataset(texts, tokenizer, block_size=block_size)
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
x = dataset[i]
|
||||||
|
assert isinstance(x, dict)
|
||||||
|
assert "input_ids" in x
|
||||||
|
assert isinstance(x["input_ids"], torch.Tensor)
|
||||||
|
assert x["input_ids"].shape == (block_size,)
|
||||||
|
|
||||||
|
def test_textdataset_truncation_and_padding():
|
||||||
|
texts = ["one two three four five six seven", "short"]
|
||||||
|
tokenizer = DummyTokenizer(100)
|
||||||
|
block_size = 5
|
||||||
|
dataset = TextDataset(texts, tokenizer, block_size=block_size)
|
||||||
|
assert isinstance(dataset[0], dict)
|
||||||
|
assert dataset[0]["input_ids"].shape[0] == 5
|
||||||
|
assert dataset[1]["input_ids"].shape[0] == 5
|
||||||
|
|
||||||
|
def test_textdataset_index_error():
|
||||||
|
texts = ["a", "b"]
|
||||||
|
tokenizer = DummyTokenizer(10)
|
||||||
|
dataset = TextDataset(texts, tokenizer, block_size=3)
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
_ = dataset[2]
|
||||||
|
|
||||||
|
def test_textdataset_encoding():
|
||||||
|
texts = ["привет", "мир"]
|
||||||
|
tokenizer = DummyTokenizer(20)
|
||||||
|
block_size = 4
|
||||||
|
dataset = TextDataset(texts, tokenizer, block_size=block_size)
|
||||||
|
assert len(dataset) == 2
|
||||||
|
x = dataset[0]
|
||||||
|
assert isinstance(x, dict)
|
||||||
|
assert "input_ids" in x
|
||||||
|
assert isinstance(x["input_ids"], torch.Tensor)
|
||||||
|
assert x["input_ids"].shape == (block_size,)
|
||||||
62
llm/tests/datasets/test_text_with_special_tokens_dataset.py
Normal file
62
llm/tests/datasets/test_text_with_special_tokens_dataset.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.datasets.text_with_special_tokens_dataset import TextWithSpecialTokensDataset
|
||||||
|
|
||||||
|
class DummyTokenizer:
|
||||||
|
def __init__(self):
|
||||||
|
self.bos_token_id = 101
|
||||||
|
self.eos_token_id = 102
|
||||||
|
self.pad_token_id = 0
|
||||||
|
def encode(self, text, add_special_tokens=False, add_bos_token=False, add_eos_token=False):
|
||||||
|
ids = [ord(c) % 50 for c in text.strip()]
|
||||||
|
if add_bos_token:
|
||||||
|
ids = [self.bos_token_id] + ids
|
||||||
|
if add_eos_token:
|
||||||
|
ids = ids + [self.eos_token_id]
|
||||||
|
return ids
|
||||||
|
|
||||||
|
def test_specialtokens_basic_bos_eos():
|
||||||
|
texts = ["abc", "d"]
|
||||||
|
tokenizer = DummyTokenizer()
|
||||||
|
block_size = 6
|
||||||
|
ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=block_size, add_bos=True, add_eos=True)
|
||||||
|
for i in range(len(ds)):
|
||||||
|
item = ds[i]
|
||||||
|
assert isinstance(item, dict)
|
||||||
|
assert "input_ids" in item
|
||||||
|
assert item["input_ids"].shape == (block_size,)
|
||||||
|
assert item["input_ids"][0] == tokenizer.bos_token_id
|
||||||
|
assert item["input_ids"][item["input_ids"].ne(tokenizer.pad_token_id).sum() - 1] == tokenizer.eos_token_id
|
||||||
|
|
||||||
|
def test_specialtokens_padding_and_truncation():
|
||||||
|
texts = ["qwertyuiop", "z"]
|
||||||
|
tokenizer = DummyTokenizer()
|
||||||
|
block_size = 5
|
||||||
|
ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=block_size, add_bos=True)
|
||||||
|
assert ds[0]["input_ids"].shape[0] == block_size
|
||||||
|
assert ds[1]["input_ids"][-1] == tokenizer.pad_token_id
|
||||||
|
|
||||||
|
def test_specialtokens_no_bos_eos():
|
||||||
|
texts = ["xyz"]
|
||||||
|
tokenizer = DummyTokenizer()
|
||||||
|
block_size = 6
|
||||||
|
ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=block_size, add_bos=False, add_eos=False)
|
||||||
|
item = ds[0]["input_ids"]
|
||||||
|
assert tokenizer.bos_token_id not in item
|
||||||
|
assert tokenizer.eos_token_id not in item
|
||||||
|
assert item.shape == (block_size,)
|
||||||
|
|
||||||
|
def test_specialtokens_index_error():
|
||||||
|
texts = ["sample"]
|
||||||
|
tokenizer = DummyTokenizer()
|
||||||
|
ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=8)
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
_ = ds[1]
|
||||||
|
|
||||||
|
def test_specialtokens_labels():
|
||||||
|
texts = ["abcd"]
|
||||||
|
tokenizer = DummyTokenizer()
|
||||||
|
block_size = 7
|
||||||
|
ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=block_size, add_bos=True, add_eos=True)
|
||||||
|
item = ds[0]
|
||||||
|
assert torch.equal(item["input_ids"], item["labels"])
|
||||||
@@ -145,7 +145,11 @@ class TestGPT:
|
|||||||
assert model._token_embeddings._embedding.weight.grad is not None
|
assert model._token_embeddings._embedding.weight.grad is not None
|
||||||
assert model._linear.weight.grad is not None
|
assert model._linear.weight.grad is not None
|
||||||
if len(model._decoders) > 0:
|
if len(model._decoders) > 0:
|
||||||
assert model._decoders[0]._heads._heads[0]._q.weight.grad is not None
|
# Проверяем через новый интерфейс attention оптимизации:
|
||||||
|
attn = model._decoders[0]._heads
|
||||||
|
assert attn._q.weight.grad is not None
|
||||||
|
assert attn._k.weight.grad is not None
|
||||||
|
assert attn._v.weight.grad is not None
|
||||||
|
|
||||||
def test_device_consistency(self, gpt_config, random_inputs, device):
|
def test_device_consistency(self, gpt_config, random_inputs, device):
|
||||||
"""Test that GPT works on correct device."""
|
"""Test that GPT works on correct device."""
|
||||||
|
|||||||
53
llm/tests/models/test_gpt2.py
Normal file
53
llm/tests/models/test_gpt2.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.models.gpt.gpt2 import GPT2
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def config():
|
||||||
|
return {
|
||||||
|
"vocab_size": 100,
|
||||||
|
"embed_dim": 32,
|
||||||
|
"num_heads": 4,
|
||||||
|
"num_layers": 2,
|
||||||
|
"max_position_embeddings": 16,
|
||||||
|
"dropout": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def model(config):
|
||||||
|
return GPT2(config)
|
||||||
|
|
||||||
|
def test_forward_basic(model):
|
||||||
|
x = torch.randint(0, 100, (2, 8))
|
||||||
|
logits, cache = model(x)
|
||||||
|
assert logits.shape == (2, 8, 100)
|
||||||
|
assert isinstance(cache, list)
|
||||||
|
assert len(cache) == model._decoders.__len__()
|
||||||
|
|
||||||
|
def test_forward_with_cache(model):
|
||||||
|
x = torch.randint(0, 100, (2, 4))
|
||||||
|
logits, cache = model(x, use_cache=True)
|
||||||
|
x2 = torch.randint(0, 100, (2, 1))
|
||||||
|
logits2, cache2 = model(x2, use_cache=True, cache=cache)
|
||||||
|
assert logits2.shape == (2, 1, 100)
|
||||||
|
assert isinstance(cache2, list)
|
||||||
|
|
||||||
|
def test_generate_and_shape(model):
|
||||||
|
x = torch.randint(0, 100, (1, 5))
|
||||||
|
result = model.generate(x, max_new_tokens=3, do_sample=False)
|
||||||
|
assert result.shape == (1, 8)
|
||||||
|
|
||||||
|
def test_forward_sequence_too_long(model, config):
|
||||||
|
x = torch.randint(0, 100, (1, config["max_position_embeddings"] + 1))
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
model(x)
|
||||||
|
|
||||||
|
def test_generate_with_sampling_topk(model):
|
||||||
|
x = torch.randint(0, 100, (1, 3))
|
||||||
|
out = model.generate(x, max_new_tokens=2, do_sample=True, top_k=5)
|
||||||
|
assert out.shape == (1, 5)
|
||||||
|
|
||||||
|
def test_generate_with_sampling_topp(model):
|
||||||
|
x = torch.randint(0, 100, (1, 3))
|
||||||
|
out = model.generate(x, max_new_tokens=2, do_sample=True, top_p=0.8)
|
||||||
|
assert out.shape == (1, 5)
|
||||||
53
llm/tests/models/test_llama.py
Normal file
53
llm/tests/models/test_llama.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.models.llama.llama import Llama
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def config():
|
||||||
|
return {
|
||||||
|
"vocab_size": 100,
|
||||||
|
"embed_dim": 32,
|
||||||
|
"num_heads": 4,
|
||||||
|
"num_layers": 2,
|
||||||
|
"max_position_embeddings": 16,
|
||||||
|
"dropout": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def model(config):
|
||||||
|
return Llama(config)
|
||||||
|
|
||||||
|
def test_forward_basic(model):
|
||||||
|
x = torch.randint(0, 100, (2, 8))
|
||||||
|
logits, cache = model(x)
|
||||||
|
assert logits.shape == (2, 8, 100)
|
||||||
|
assert isinstance(cache, list)
|
||||||
|
assert len(cache) == model._decoders.__len__()
|
||||||
|
|
||||||
|
def test_forward_with_cache(model):
|
||||||
|
x = torch.randint(0, 100, (2, 4))
|
||||||
|
logits, cache = model(x, use_cache=True)
|
||||||
|
x2 = torch.randint(0, 100, (2, 1))
|
||||||
|
logits2, cache2 = model(x2, use_cache=True, cache=cache)
|
||||||
|
assert logits2.shape == (2, 1, 100)
|
||||||
|
assert isinstance(cache2, list)
|
||||||
|
|
||||||
|
def test_generate_and_shape(model):
|
||||||
|
x = torch.randint(0, 100, (1, 5))
|
||||||
|
result = model.generate(x, max_new_tokens=3, do_sample=False)
|
||||||
|
assert result.shape == (1, 8)
|
||||||
|
|
||||||
|
def test_forward_sequence_too_long(model, config):
|
||||||
|
x = torch.randint(0, 100, (1, config["max_position_embeddings"] + 1))
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
model(x)
|
||||||
|
|
||||||
|
def test_generate_with_sampling_topk(model):
|
||||||
|
x = torch.randint(0, 100, (1, 3))
|
||||||
|
out = model.generate(x, max_new_tokens=2, do_sample=True, top_k=5)
|
||||||
|
assert out.shape == (1, 5)
|
||||||
|
|
||||||
|
def test_generate_with_sampling_topp(model):
|
||||||
|
x = torch.randint(0, 100, (1, 3))
|
||||||
|
out = model.generate(x, max_new_tokens=2, do_sample=True, top_p=0.8)
|
||||||
|
assert out.shape == (1, 5)
|
||||||
59
llm/tests/models/test_mistral.py
Normal file
59
llm/tests/models/test_mistral.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
# llm/src/llm/tests/models/test_mistral.py
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.models.mistral.mistral import Mistral
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def config():
|
||||||
|
return {
|
||||||
|
"vocab_size": 100,
|
||||||
|
"embed_dim": 32,
|
||||||
|
"num_q_heads": 4,
|
||||||
|
"num_kv_heads": 2,
|
||||||
|
"num_layers": 2,
|
||||||
|
"max_position_embeddings": 16,
|
||||||
|
"window_size": 8,
|
||||||
|
"dropout": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def model(config):
|
||||||
|
return Mistral(config)
|
||||||
|
|
||||||
|
def test_forward_basic(model):
|
||||||
|
x = torch.randint(0, 100, (2, 8))
|
||||||
|
logits, cache = model(x)
|
||||||
|
assert logits.shape == (2, 8, 100)
|
||||||
|
assert isinstance(cache, list)
|
||||||
|
assert len(cache) == model._decoders.__len__()
|
||||||
|
|
||||||
|
def test_forward_with_cache(model):
|
||||||
|
x = torch.randint(0, 100, (2, 4))
|
||||||
|
logits, cache = model(x, use_cache=True)
|
||||||
|
# Прогоняем еще раз с кэшем и 1 новым токеном
|
||||||
|
x2 = torch.randint(0, 100, (2, 1))
|
||||||
|
logits2, cache2 = model(x2, use_cache=True, cache=cache)
|
||||||
|
assert logits2.shape == (2, 1, 100)
|
||||||
|
assert isinstance(cache2, list)
|
||||||
|
|
||||||
|
def test_generate_and_shape(model):
|
||||||
|
x = torch.randint(0, 100, (1, 5))
|
||||||
|
result = model.generate(x, max_new_tokens=3, do_sample=False)
|
||||||
|
# Вход=5, добавить 3 → итоговая длина 8
|
||||||
|
assert result.shape == (1, 8)
|
||||||
|
|
||||||
|
def test_forward_sequence_too_long(model, config):
|
||||||
|
x = torch.randint(0, 100, (1, config["max_position_embeddings"] + 1))
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
model(x)
|
||||||
|
|
||||||
|
def test_generate_with_sampling_topk(model):
|
||||||
|
x = torch.randint(0, 100, (1, 3))
|
||||||
|
out = model.generate(x, max_new_tokens=2, do_sample=True, top_k=5)
|
||||||
|
assert out.shape == (1, 5)
|
||||||
|
|
||||||
|
def test_generate_with_sampling_topp(model):
|
||||||
|
x = torch.randint(0, 100, (1, 3))
|
||||||
|
out = model.generate(x, max_new_tokens=2, do_sample=True, top_p=0.8)
|
||||||
|
assert out.shape == (1, 5)
|
||||||
35
llm/tests/training/test_optimizer.py
Normal file
35
llm/tests/training/test_optimizer.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
import pytest
|
||||||
|
import torch.nn as nn
|
||||||
|
from llm.training.optimizer import get_optimizer
|
||||||
|
|
||||||
|
class DummyModel(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = nn.Linear(10, 1)
|
||||||
|
|
||||||
|
def test_get_optimizer_adamw():
|
||||||
|
model = DummyModel()
|
||||||
|
optimizer = get_optimizer(model, lr=1e-3, weight_decay=0.02, optimizer_type="adamw")
|
||||||
|
assert optimizer.__class__.__name__ == 'AdamW'
|
||||||
|
assert optimizer.defaults['lr'] == 1e-3
|
||||||
|
assert optimizer.defaults['weight_decay'] == 0.02
|
||||||
|
|
||||||
|
def test_get_optimizer_adam():
|
||||||
|
model = DummyModel()
|
||||||
|
optimizer = get_optimizer(model, lr=1e-4, weight_decay=0.01, optimizer_type="adam")
|
||||||
|
assert optimizer.__class__.__name__ == 'Adam'
|
||||||
|
assert optimizer.defaults['lr'] == 1e-4
|
||||||
|
assert optimizer.defaults['weight_decay'] == 0.01
|
||||||
|
|
||||||
|
def test_get_optimizer_sgd():
|
||||||
|
model = DummyModel()
|
||||||
|
optimizer = get_optimizer(model, lr=0.1, optimizer_type="sgd")
|
||||||
|
assert optimizer.__class__.__name__ == 'SGD'
|
||||||
|
assert optimizer.defaults['lr'] == 0.1
|
||||||
|
# SGD: weight_decay по умолчанию 0 для этого вызова
|
||||||
|
assert optimizer.defaults['momentum'] == 0.9
|
||||||
|
|
||||||
|
def test_get_optimizer_invalid():
|
||||||
|
model = DummyModel()
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
get_optimizer(model, optimizer_type="nonexistent")
|
||||||
62
llm/tests/training/test_scheduler.py
Normal file
62
llm/tests/training/test_scheduler.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from llm.training.scheduler import get_linear_schedule_with_warmup
|
||||||
|
from llm.training.optimizer import get_optimizer
|
||||||
|
|
||||||
|
class DummyModel(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = nn.Linear(2, 2)
|
||||||
|
|
||||||
|
def test_scheduler_warmup_and_decay():
|
||||||
|
model = DummyModel()
|
||||||
|
base_lr = 0.1
|
||||||
|
warmup_steps = 5
|
||||||
|
total_steps = 20
|
||||||
|
optimizer = get_optimizer(model, lr=base_lr, optimizer_type="sgd")
|
||||||
|
scheduler = get_linear_schedule_with_warmup(
|
||||||
|
optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
|
||||||
|
|
||||||
|
lrs = [optimizer.param_groups[0]['lr']] # lr до первого .step()
|
||||||
|
for _ in range(total_steps):
|
||||||
|
optimizer.step()
|
||||||
|
scheduler.step()
|
||||||
|
lrs.append(optimizer.param_groups[0]['lr'])
|
||||||
|
|
||||||
|
# Проверяем warmup: lr должен расти линейно в первых warmup_steps (начиная с шага 1)
|
||||||
|
for i in range(warmup_steps + 1):
|
||||||
|
expected = base_lr * min(i, warmup_steps) / max(1, warmup_steps)
|
||||||
|
assert abs(lrs[i] - expected) < 1e-6, f"Warmup step {i}: lr={lrs[i]}, expected={expected}"
|
||||||
|
# Проверяем decay: после warmup lr затухает
|
||||||
|
for i in range(warmup_steps + 1, total_steps + 1):
|
||||||
|
expected = base_lr * max(0.0, (total_steps - (i - 0)) / max(1, total_steps - warmup_steps))
|
||||||
|
assert abs(lrs[i] - expected) < 1e-6, f"Decay step {i}: lr={lrs[i]}, expected={expected}"
|
||||||
|
assert lrs[-1] == 0.0
|
||||||
|
|
||||||
|
def test_scheduler_no_warmup():
|
||||||
|
model = DummyModel()
|
||||||
|
base_lr = 0.1
|
||||||
|
warmup_steps = 0
|
||||||
|
total_steps = 10
|
||||||
|
optimizer = get_optimizer(model, lr=base_lr, optimizer_type="adam")
|
||||||
|
scheduler = get_linear_schedule_with_warmup(
|
||||||
|
optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
|
||||||
|
lrs = [optimizer.param_groups[0]['lr']]
|
||||||
|
for _ in range(total_steps):
|
||||||
|
optimizer.step()
|
||||||
|
scheduler.step()
|
||||||
|
lrs.append(optimizer.param_groups[0]['lr'])
|
||||||
|
|
||||||
|
for i in range(total_steps + 1):
|
||||||
|
expected = base_lr * max(0.0, (total_steps - i) / max(1, total_steps - warmup_steps))
|
||||||
|
assert abs(lrs[i] - expected) < 1e-6, f"Step {i}: lr={lrs[i]}, expected={expected}"
|
||||||
|
assert lrs[-1] == 0.0
|
||||||
|
|
||||||
|
def test_scheduler_full_decay_to_zero():
|
||||||
|
model = DummyModel()
|
||||||
|
optimizer = get_optimizer(model, lr=1.0, optimizer_type="adamw")
|
||||||
|
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=2, num_training_steps=2)
|
||||||
|
scheduler.step()
|
||||||
|
scheduler.step()
|
||||||
|
for param_group in optimizer.param_groups:
|
||||||
|
assert param_group['lr'] == 0.0
|
||||||
62
llm/tests/training/test_trainer.py
Normal file
62
llm/tests/training/test_trainer.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from llm.training.trainer import Trainer
|
||||||
|
|
||||||
|
# Синтетический небольшой датасет для автогрессивной LM задачи
|
||||||
|
class ToyLMDataset(Dataset):
|
||||||
|
def __init__(self, num_samples=16, seq_len=8, vocab_size=16):
|
||||||
|
self.data = torch.randint(1, vocab_size, (num_samples, seq_len))
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data)
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
# labels == input_ids (identity task)
|
||||||
|
return {"input_ids": self.data[idx], "labels": self.data[idx]}
|
||||||
|
|
||||||
|
# Простая dummy-модель — 1 слой linear over vocab
|
||||||
|
class TinyModel(nn.Module):
|
||||||
|
def __init__(self, vocab_size=16, seq_len=8):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = nn.Linear(seq_len, vocab_size)
|
||||||
|
def forward(self, x):
|
||||||
|
# logits: (batch, seq_len, vocab_size)
|
||||||
|
# Для простоты делаем транспонирование
|
||||||
|
return self.linear(x.float()).unsqueeze(1).expand(-1, x.shape[1], -1)
|
||||||
|
|
||||||
|
def test_train_runs_without_errors():
|
||||||
|
train_data = ToyLMDataset(num_samples=16, seq_len=8, vocab_size=16)
|
||||||
|
model = TinyModel(vocab_size=16, seq_len=8)
|
||||||
|
trainer = Trainer(model, train_data, lr=1e-3, batch_size=4, num_epochs=1, warmup_steps=2)
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
def test_trainer_evaluate_runs():
|
||||||
|
train_data = ToyLMDataset(num_samples=8)
|
||||||
|
val_data = ToyLMDataset(num_samples=8)
|
||||||
|
model = TinyModel()
|
||||||
|
trainer = Trainer(model, train_data, val_data, lr=1e-3, batch_size=4, num_epochs=1, warmup_steps=2)
|
||||||
|
trainer.train()
|
||||||
|
trainer.evaluate()
|
||||||
|
|
||||||
|
def test_trainer_tuple_output():
|
||||||
|
# Модель, возвращающая кортеж (logits, extra)
|
||||||
|
class TupleModel(nn.Module):
|
||||||
|
def __init__(self, vocab_size=16, seq_len=8):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = nn.Linear(seq_len, vocab_size)
|
||||||
|
def forward(self, x):
|
||||||
|
logits = self.linear(x.float()).unsqueeze(1).expand(-1, x.shape[1], -1)
|
||||||
|
extra = torch.zeros(1)
|
||||||
|
return logits, extra
|
||||||
|
|
||||||
|
train_data = ToyLMDataset(num_samples=8)
|
||||||
|
model = TupleModel()
|
||||||
|
trainer = Trainer(model, train_data, lr=1e-3, batch_size=2, num_epochs=1, warmup_steps=1)
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
def test_trainer_loss_decreases():
|
||||||
|
train_data = ToyLMDataset(num_samples=32, seq_len=8, vocab_size=8)
|
||||||
|
model = TinyModel(vocab_size=8, seq_len=8)
|
||||||
|
trainer = Trainer(model, train_data, lr=0.05, batch_size=8, num_epochs=2, warmup_steps=1)
|
||||||
|
trainer.train()
|
||||||
|
avg_losses = trainer.loss_history
|
||||||
|
assert avg_losses[-1] <= avg_losses[0] or abs(avg_losses[-1] - avg_losses[0]) < 1e-3
|
||||||
3267
notebooks/mistral.ipynb
Normal file
3267
notebooks/mistral.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user