mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
Compare commits
44 Commits
feature/ll
...
7744658716
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7744658716 | ||
|
|
21cfd79c19 | ||
|
|
9e2796e6be | ||
|
|
25caf69ced | ||
|
|
ddc4924a37 | ||
|
|
92a34551b8 | ||
|
|
ea932a36f3 | ||
|
|
cfb4b6dfb1 | ||
|
|
58c4a00b48 | ||
|
|
c9da4c841b | ||
|
|
b1737bbce2 | ||
|
|
1aba02cab9 | ||
|
|
9794db3e18 | ||
|
|
d947b7beb3 | ||
|
|
613d784565 | ||
|
|
38c271ca3c | ||
|
|
aec3c8adb6 | ||
|
|
90eb2f4467 | ||
|
|
a3415d404a | ||
|
|
9837ea3c3d | ||
|
|
baafca0546 | ||
|
|
516f9580fb | ||
|
|
64d33783e0 | ||
|
|
6efc946027 | ||
|
|
8018efae2a | ||
|
|
0832d78acf | ||
|
|
c338556cfe | ||
|
|
3a356f5d79 | ||
|
|
923aa51e2a | ||
|
|
ba3b04cec2 | ||
|
|
e6ca8dee6f | ||
|
|
2e72dbaf07 | ||
|
|
dc440a3938 | ||
|
|
50d7593023 | ||
|
|
38682e8c9d | ||
|
|
e791f7cd93 | ||
|
|
d10044e4a7 | ||
|
|
ec0d2bd8d0 | ||
|
|
e5706a690d | ||
|
|
3e4815fcc6 | ||
|
|
0cc7850848 | ||
|
|
237b86421e | ||
|
|
712278e33c | ||
|
|
332cad6159 |
78
README.md
78
README.md
@@ -1,15 +1,16 @@
|
||||
# LLM Architecture Research
|
||||
|
||||
Исследовательский проект для разработки и обучения архитектур больших языковых моделей (LLM).
|
||||
Исследовательский проект по разработке, обучению и сравнительному анализу современных архитектур больших языковых моделей (LLM): **GPT, GPT-2, LLaMA, Mistral**. Прямая поддержка интеграции с HuggingFace (через модуль `hf-proxy`).
|
||||
|
||||
|
||||
## 🏗️ Архитектура проекта
|
||||
|
||||
Проект организован как монорепозиторий с использованием **uv** workspace:
|
||||
|
||||
- **`llm`** — основная библиотека с реализацией архитектур LLM (GPT, GPT-2)
|
||||
- **`hf-proxy`** — адаптер для интеграции с HuggingFace
|
||||
- **`experiments`** — скрипты обучения и экспериментов
|
||||
- **`notebooks`** — исследовательские ноутбуки
|
||||
- **`llm`** — основная библиотека с реализацией архитектур LLM (**GPT, GPT-2, LLaMA, Mistral**)
|
||||
- **`hf-proxy`** — экспериментальный адаптер для интеграции с HuggingFace (загрузка, токенизация, экспериментальные скрипты). Функционал может изменяться и не гарантирует полной совместимости с будущими версиями HuggingFace Transformers.
|
||||
- **`experiments`** — скрипты обучения и генерации (включая HF и собственные модели)
|
||||
- **`notebooks`** — исследовательские ноутбуки, анализ архитектур
|
||||
|
||||
## 📁 Структура проекта
|
||||
|
||||
@@ -41,8 +42,11 @@ llm-arch-research/
|
||||
│ │ │ ├── gpt.py
|
||||
│ │ │ ├── gpt2.py
|
||||
│ │ │ └── __init__.py
|
||||
│ │ └── llama/ # LLaMA архитектура
|
||||
│ │ ├── llama.py
|
||||
│ │ ├── llama/ # LLaMA архитектура
|
||||
│ │ │ ├── llama.py
|
||||
│ │ │ └── __init__.py
|
||||
│ │ └── mistral/ # Mistral архитектура
|
||||
│ │ ├── mistral.py
|
||||
│ │ └── __init__.py
|
||||
│ ├── training/ # утилиты обучения
|
||||
│ │ ├── dataset.py
|
||||
@@ -81,6 +85,18 @@ llm-arch-research/
|
||||
|
||||
## 🚀 Быстрый старт
|
||||
|
||||
**Пример запуска обучения и генерации для любых архитектур:**
|
||||
|
||||
```bash
|
||||
python experiments/llm_only/run_llm_experiment.py --model mistral --action generate --config experiments/llm_only/configs/mistral_generate.json
|
||||
```
|
||||
|
||||
**Использование собственных моделей с HuggingFace-интерфейсом:**
|
||||
```python
|
||||
from hf_proxy.hf_adapter import HFAdapter
|
||||
hf_model = HFAdapter("mistralai/Mistral-7B-v0.1")
|
||||
```
|
||||
|
||||
### Установка зависимостей
|
||||
|
||||
```bash
|
||||
@@ -91,15 +107,17 @@ uv sync
|
||||
uv sync --extra dev
|
||||
```
|
||||
|
||||
### Запуск обучения GPT
|
||||
## ⚡ Работа с экспериментами (experiments/llm_only, experiments/hf_integration)
|
||||
|
||||
```bash
|
||||
# Обучение базовой GPT модели
|
||||
uv run python experiments/llm_only/train_gpt_bpe.py
|
||||
- В `experiments/llm_only`: универсальный скрипт для обучения и генерации LLM (включая LLaMA и Mistral) без HuggingFace — всё через собственную реализацию.
|
||||
- В `experiments/hf_integration`: скрипты и примеры для генерации, обучения и тестирования моделей с помощью HuggingFace API (через hf-proxy). Позволяет использовать свои модели и токенизаторы как стандартные HF-объекты.
|
||||
|
||||
# Обучение с интеграцией HuggingFace
|
||||
uv run python experiments/hf_integration/simple_hf_training.py
|
||||
```
|
||||
**Для моделей Mistral/Llama доступны оба сценария: прямая работа или через HuggingFace-прокси.**
|
||||
|
||||
*Конфиги и примеры см. в соответствующих папках.*
|
||||
|
||||
|
||||
---
|
||||
|
||||
### Тестирование hf-proxy
|
||||
|
||||
@@ -212,33 +230,23 @@ dependencies = [
|
||||
|
||||
## 🎯 Реализованные возможности
|
||||
|
||||
### Архитектуры GPT и GPT-2
|
||||
- ✅ Токенные и позиционные эмбеддинги
|
||||
- ✅ Многоголовое внимание с causal mask
|
||||
- ✅ Декодерные блоки с residual connections
|
||||
- ✅ Layer normalization
|
||||
- ✅ Dropout регуляризация
|
||||
- ✅ Отдельные реализации GPT и GPT-2 (различия в масштабе и деталях архитектуры)
|
||||
### Архитектуры
|
||||
- ✅ GPT, GPT-2: Полностью воспроизводимые реализации, токенные и позиционные эмбеддинги, causal multi-head attention, LayerNorm
|
||||
- ✅ LLaMA: Rotary Positional Embeddings (RoPE), RMSNorm, SwiGLU, оптимизированная память
|
||||
- ✅ Mistral: Sliding Window Attention (оконное внимание), Grouped Query Attention (GQA), совместимость с HF
|
||||
- ✅ Все архитектуры поддерживают обучение и генерацию текста
|
||||
|
||||
### Генерация текста
|
||||
- ✅ Жадный поиск (greedy decoding)
|
||||
- ✅ Вероятностное сэмплирование
|
||||
- ✅ Top-k сэмплирование
|
||||
- ✅ Nucleus sampling (top-p)
|
||||
- ✅ Контроль температуры
|
||||
- ✅ Greedy, sampling (Top-k, Top-p), контроль температуры, efficient caching
|
||||
|
||||
### Обучение
|
||||
- ✅ Датасет для языкового моделирования
|
||||
- ✅ Базовый тренировочный цикл
|
||||
- ✅ Оптимизатор AdamW
|
||||
- ✅ Сохранение чекпоинтов
|
||||
- ✅ Языковое моделирование с кастомными и HF-токенизаторами
|
||||
- ✅ AdamW, кастомные датасеты, сохранение чекпоинтов
|
||||
|
||||
### Интеграция с HuggingFace (hf-proxy)
|
||||
- ✅ Адаптер моделей для совместимости с HF интерфейсами
|
||||
- ✅ Адаптер токенизаторов с поддержкой всех методов HF
|
||||
- ✅ Сохранение и загрузка в HF формате
|
||||
- ✅ Совместимость с HF Trainer и pipelines
|
||||
- ✅ Генерация через стандартные HF интерфейсы
|
||||
- ✅ Экспорт/импорт моделей и токенизаторов в HF совместимый формат
|
||||
- ✅ Генерация и обучение через HF Trainer, pipelines и т.д.
|
||||
- ✅ Двусторонняя поддержка: собственные модели становятся HF-совместимыми и наоборот
|
||||
|
||||
## 🔬 Эксперименты с hf-proxy
|
||||
|
||||
|
||||
148
assets/drawio/gpt1-architecture.drawio
Normal file
148
assets/drawio/gpt1-architecture.drawio
Normal file
@@ -0,0 +1,148 @@
|
||||
<mxfile host="65bd71144e">
|
||||
<diagram name="GPT Architecture" id="DEYydPS-O6mnllJWumln">
|
||||
<mxGraphModel dx="1216" dy="316" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
|
||||
<root>
|
||||
<mxCell id="0"/>
|
||||
<mxCell id="1" parent="0"/>
|
||||
<mxCell id="92" value="" style="group" vertex="1" connectable="0" parent="1">
|
||||
<mxGeometry x="40" y="320" width="1286" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="3" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#fff2cc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="230" width="440" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="4" value="<div>Masked</div>Multi+Head<br>Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
|
||||
<mxGeometry x="51.42776556776556" y="50" width="78.97435897435898" height="60" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="22" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="5" edge="1">
|
||||
<mxGeometry relative="1" as="geometry">
|
||||
<mxPoint x="350" y="80" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="5" value="Feed<div>Forward</div><div>Network</div>" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
|
||||
<mxGeometry x="260.9564102564102" y="50" width="71.9230769230769" height="60" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="7" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
|
||||
<mxGeometry x="379.997619047619" y="60" width="37.87142857142857" height="40" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="21" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="12" target="5" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="12" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
|
||||
<mxGeometry x="177.14285714285714" y="60" width="41.904761904761905" height="40" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="13" value="" style="endArrow=classic;html=1;exitX=0;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=elbowEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="3" target="4" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="20" y="80.00000000000011" as="sourcePoint"/>
|
||||
<mxPoint x="229.52380952380952" y="-50" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="14" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="18" target="12" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="155.71428571427464" y="79.99999999999989" as="sourcePoint"/>
|
||||
<mxPoint x="229.52380952380952" y="-50" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="18" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
|
||||
<mxGeometry x="150.00428571428571" y="75" width="10" height="10" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="19" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="4" target="18" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="213.80952380952382" y="410" as="sourcePoint"/>
|
||||
<mxPoint x="145.71428571428578" y="80.00000000000011" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="23" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" target="24" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="236.85714285714286" y="80" as="sourcePoint"/>
|
||||
<mxPoint x="349.7619047619048" y="85" as="targetPoint"/>
|
||||
<Array as="points">
|
||||
<mxPoint x="236.67904761904765" y="125"/>
|
||||
<mxPoint x="292.38095238095235" y="125"/>
|
||||
<mxPoint x="355" y="125"/>
|
||||
</Array>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="28" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="24" target="7" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="24" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
|
||||
<mxGeometry x="350.00190476190477" y="75" width="10" height="10" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="25" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" target="18" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="34.325581395348834" y="80" as="sourcePoint"/>
|
||||
<mxPoint x="150.71428571428578" y="85" as="targetPoint"/>
|
||||
<Array as="points">
|
||||
<mxPoint x="34.25858250276859" y="130"/>
|
||||
<mxPoint x="89.96048726467328" y="130"/>
|
||||
<mxPoint x="155" y="130"/>
|
||||
</Array>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="36" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="32" target="3" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="32" value="+" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="90" width="110" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="33" value="Token Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="95" y="17.5" width="100" height="42.5" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="34" value="Position Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="95" y="100" width="100" height="42.5" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="46" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="37" target="40" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="37" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="690" width="70" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="38" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="7" target="37" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="47" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="40" target="44" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="40" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="790" width="70" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="49" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="41" target="42" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="41" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="950" width="70" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="52" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="42" target="50" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="42" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="1050" width="70" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="48" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="44" target="41" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="44" value=".<div>.</div><div>.</div>" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="890" y="40" width="30" height="80" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="53" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="50" target="51" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="50" value="Linear" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="1150" y="5" width="50" height="150" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="51" value="Softmax" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#e1d5e7;strokeColor=#9673a6;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="1236" y="5" width="50" height="150" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="54" value="Tokens" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry y="40" width="60" height="90" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="55" style="edgeStyle=none;html=1;entryX=-0.025;entryY=0.538;entryDx=0;entryDy=0;entryPerimeter=0;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="54" edge="1">
|
||||
<mxGeometry relative="1" as="geometry">
|
||||
<mxPoint x="42.75" y="84.66941747572821" as="sourcePoint"/>
|
||||
<mxPoint x="90" y="85.33000000000004" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
</root>
|
||||
</mxGraphModel>
|
||||
</diagram>
|
||||
</mxfile>
|
||||
413
assets/drawio/gpt1-attention.drawio
Normal file
413
assets/drawio/gpt1-attention.drawio
Normal file
@@ -0,0 +1,413 @@
|
||||
<mxfile host="65bd71144e">
|
||||
<diagram name="GPT Architecture" id="DEYydPS-O6mnllJWumln">
|
||||
<mxGraphModel dx="2176" dy="702" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
|
||||
<root>
|
||||
<mxCell id="0"/>
|
||||
<mxCell id="1" parent="0"/>
|
||||
<mxCell id="92" value="" style="group" parent="1" vertex="1" connectable="0">
|
||||
<mxGeometry x="40" y="320" width="1286" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="3" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#fff2cc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;container=0;" parent="92" vertex="1">
|
||||
<mxGeometry x="230" width="440" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="36" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="32" target="3" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="32" value="+" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="90" width="110" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="33" value="Token Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="95" y="17.5" width="100" height="42.5" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="34" value="Position Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="95" y="100" width="100" height="42.5" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="46" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="37" target="40" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="37" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="690" width="70" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="38" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="7" target="37" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="47" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="40" target="44" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="40" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="790" width="70" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="49" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="41" target="42" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="41" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="950" width="70" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="52" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="42" target="50" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="42" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="1050" width="70" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="48" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="44" target="41" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="44" value=".<div>.</div><div>.</div>" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="890" y="40" width="30" height="80" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="53" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="50" target="51" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="50" value="Linear" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="1150" y="5" width="50" height="150" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="51" value="Softmax" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#e1d5e7;strokeColor=#9673a6;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="1236" y="5" width="50" height="150" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="54" value="Tokens" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry y="40" width="60" height="90" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="55" style="edgeStyle=none;html=1;entryX=-0.025;entryY=0.538;entryDx=0;entryDy=0;entryPerimeter=0;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="54" edge="1">
|
||||
<mxGeometry relative="1" as="geometry">
|
||||
<mxPoint x="42.75" y="84.66941747572821" as="sourcePoint"/>
|
||||
<mxPoint x="90" y="85.33000000000004" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="4" value="<div>Masked</div>Multi+Head<br>Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#FF3333;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="281.42776556776556" y="50" width="78.97435897435898" height="60" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="22" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="5" edge="1">
|
||||
<mxGeometry relative="1" as="geometry">
|
||||
<mxPoint x="580" y="80" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="5" value="Feed<div>Forward</div><div>Network</div>" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="490.9564102564102" y="50" width="71.9230769230769" height="60" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="7" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="609.9976190476191" y="60" width="37.87142857142857" height="40" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="21" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="12" target="5" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="12" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="407.1428571428571" y="60" width="41.904761904761905" height="40" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="13" value="" style="endArrow=classic;html=1;exitX=0;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=elbowEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="3" target="4" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="250" y="80.00000000000011" as="sourcePoint"/>
|
||||
<mxPoint x="459.5238095238095" y="-50" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="14" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="18" target="12" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="385.71428571427464" y="79.99999999999989" as="sourcePoint"/>
|
||||
<mxPoint x="459.5238095238095" y="-50" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="18" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="380.00428571428574" y="75" width="10" height="10" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="19" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="4" target="18" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="443.80952380952385" y="410" as="sourcePoint"/>
|
||||
<mxPoint x="375.7142857142858" y="80.00000000000011" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="23" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" target="24" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="466.8571428571429" y="80" as="sourcePoint"/>
|
||||
<mxPoint x="579.7619047619048" y="85" as="targetPoint"/>
|
||||
<Array as="points">
|
||||
<mxPoint x="466.67904761904765" y="125"/>
|
||||
<mxPoint x="522.3809523809523" y="125"/>
|
||||
<mxPoint x="585" y="125"/>
|
||||
</Array>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="28" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="24" target="7" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="24" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="580.0019047619048" y="75" width="10" height="10" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="25" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" target="18" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="264.3255813953488" y="80" as="sourcePoint"/>
|
||||
<mxPoint x="380.7142857142858" y="85" as="targetPoint"/>
|
||||
<Array as="points">
|
||||
<mxPoint x="264.2585825027686" y="130"/>
|
||||
<mxPoint x="319.96048726467325" y="130"/>
|
||||
<mxPoint x="385" y="130"/>
|
||||
</Array>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="141" value="" style="endArrow=none;dashed=1;html=1;" edge="1" parent="92">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="180" y="250" as="sourcePoint"/>
|
||||
<mxPoint x="281" y="110" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="142" value="" style="endArrow=none;dashed=1;html=1;entryX=1;entryY=1;entryDx=0;entryDy=0;" edge="1" parent="1" target="4">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="620" y="560" as="sourcePoint"/>
|
||||
<mxPoint x="660" y="520" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="218" value="" style="group;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" connectable="0" parent="1">
|
||||
<mxGeometry x="130" y="660" width="680" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="195" style="edgeStyle=orthogonalEdgeStyle;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="143" target="147">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="196" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="143" target="148">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="197" style="edgeStyle=orthogonalEdgeStyle;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="143" target="149">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="143" value="X" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#fff2cc;strokeColor=#d6b656;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="218">
|
||||
<mxGeometry y="60" width="40" height="40" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="147" value="W<sub>k</sub>" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="218">
|
||||
<mxGeometry x="80" width="40" height="40" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="199" style="edgeStyle=none;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="148" target="151">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="148" value="W<sub>q</sub>" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="218">
|
||||
<mxGeometry x="80" y="60" width="40" height="40" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="149" value="W<sub>v</sub>" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="218">
|
||||
<mxGeometry x="80" y="120" width="40" height="40" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="207" style="edgeStyle=orthogonalEdgeStyle;html=1;entryX=0;entryY=1;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="150" target="158">
|
||||
<mxGeometry relative="1" as="geometry">
|
||||
<Array as="points">
|
||||
<mxPoint x="236" y="20"/>
|
||||
<mxPoint x="236" y="50"/>
|
||||
</Array>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="150" value="K" style="rounded=1;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="218">
|
||||
<mxGeometry x="158.97000000000003" width="41.03" height="40" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="208" style="edgeStyle=orthogonalEdgeStyle;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="151" target="190">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="151" value="Q" style="rounded=1;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="218">
|
||||
<mxGeometry x="158.97000000000003" y="60" width="41.03" height="40" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="214" style="edgeStyle=orthogonalEdgeStyle;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="152" target="187">
|
||||
<mxGeometry relative="1" as="geometry">
|
||||
<Array as="points">
|
||||
<mxPoint x="600" y="140"/>
|
||||
<mxPoint x="600" y="80"/>
|
||||
</Array>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="152" value="V" style="rounded=1;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="218">
|
||||
<mxGeometry x="158.97000000000003" y="120" width="40" height="40" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="215" style="edgeStyle=none;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="187">
|
||||
<mxGeometry relative="1" as="geometry">
|
||||
<mxPoint x="680" y="80" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="187" value="O" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#fff2cc;strokeColor=#d6b656;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="218">
|
||||
<mxGeometry x="620" y="60" width="40" height="40" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="211" style="edgeStyle=none;html=1;entryX=0;entryY=0;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="188" target="179">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="188" value="Scale" style="rounded=1;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;direction=east;rotation=90;fillColor=#f8cecc;strokeColor=#b85450;" vertex="1" parent="218">
|
||||
<mxGeometry x="370" y="37.5" width="50" height="25" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="213" style="edgeStyle=orthogonalEdgeStyle;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="189" target="187">
|
||||
<mxGeometry relative="1" as="geometry">
|
||||
<Array as="points">
|
||||
<mxPoint x="600" y="50"/>
|
||||
<mxPoint x="600" y="80"/>
|
||||
</Array>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="189" value="Softmax" style="rounded=1;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;direction=east;rotation=90;fillColor=#e1d5e7;strokeColor=#9673a6;" vertex="1" parent="218">
|
||||
<mxGeometry x="530" y="37.5" width="50" height="25" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="209" style="edgeStyle=none;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="190" target="188">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="190" value="" style="group;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" connectable="0" parent="218">
|
||||
<mxGeometry x="272.5" y="10" width="80" height="80" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="153" value="" style="whiteSpace=wrap;html=1;aspect=fixed;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
|
||||
<mxGeometry width="80" height="80" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="154" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
|
||||
<mxGeometry width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="155" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
|
||||
<mxGeometry x="20" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="156" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
|
||||
<mxGeometry x="40" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="157" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
|
||||
<mxGeometry x="60" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="158" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
|
||||
<mxGeometry y="20" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="159" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
|
||||
<mxGeometry x="20" y="20" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="160" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
|
||||
<mxGeometry x="40" y="20" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="161" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
|
||||
<mxGeometry x="60" y="20" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="162" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
|
||||
<mxGeometry y="40" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="163" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
|
||||
<mxGeometry x="20" y="40" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="164" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
|
||||
<mxGeometry x="40" y="40" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="165" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
|
||||
<mxGeometry x="60" y="40" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="166" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
|
||||
<mxGeometry y="60" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="167" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
|
||||
<mxGeometry x="20" y="60" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="168" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
|
||||
<mxGeometry x="40" y="60" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="169" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
|
||||
<mxGeometry x="60" y="60" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="191" value="" style="group;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" connectable="0" parent="218">
|
||||
<mxGeometry x="440" y="10" width="80" height="80" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="170" value="" style="whiteSpace=wrap;html=1;aspect=fixed;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
|
||||
<mxGeometry width="80" height="80" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="171" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
|
||||
<mxGeometry width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="172" value="" style="rounded=0;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
|
||||
<mxGeometry x="20" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="173" value="" style="rounded=0;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
|
||||
<mxGeometry x="40" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="174" value="" style="rounded=0;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
|
||||
<mxGeometry x="60" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="175" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
|
||||
<mxGeometry y="20" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="176" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
|
||||
<mxGeometry x="20" y="20" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="177" value="" style="rounded=0;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
|
||||
<mxGeometry x="40" y="20" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="178" value="" style="rounded=0;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
|
||||
<mxGeometry x="60" y="20" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="179" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
|
||||
<mxGeometry y="40" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="180" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
|
||||
<mxGeometry x="20" y="40" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="181" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
|
||||
<mxGeometry x="40" y="40" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="182" value="" style="rounded=0;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
|
||||
<mxGeometry x="60" y="40" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="183" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
|
||||
<mxGeometry y="60" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="184" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
|
||||
<mxGeometry x="20" y="60" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="185" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
|
||||
<mxGeometry x="40" y="60" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="186" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
|
||||
<mxGeometry x="60" y="60" width="20" height="20" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="198" style="edgeStyle=none;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="147" target="150">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="200" style="edgeStyle=none;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" target="152">
|
||||
<mxGeometry relative="1" as="geometry">
|
||||
<mxPoint x="120" y="140" as="sourcePoint"/>
|
||||
<mxPoint x="148.97000000000008" y="139.8599999999998" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="212" value="" style="endArrow=classic;html=1;exitX=1;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="182" target="189">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="610" y="50" as="sourcePoint"/>
|
||||
<mxPoint x="660" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="219" value="" style="group;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" connectable="0" parent="1">
|
||||
<mxGeometry x="289.99776556776555" y="520" width="250.00223443223445" height="90" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="145" style="edgeStyle=none;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" edge="1" parent="219" source="133" target="144">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="133" value="Concat" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;direction=east;rotation=90;" vertex="1" parent="219">
|
||||
<mxGeometry x="132.50223443223445" y="32.5" width="50" height="25" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="136" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" edge="1" parent="219" target="133">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="108.97435897435912" y="45" as="sourcePoint"/>
|
||||
<mxPoint x="250.00223443223445" y="25" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="129" value="<div>Masked</div>Multi+Head<br>Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#FF3333;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" vertex="1" parent="219">
|
||||
<mxGeometry x="30" width="78.97435897435898" height="60" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="130" value="<div>Masked</div>Multi+Head<br>Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#FF3333;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" vertex="1" parent="219">
|
||||
<mxGeometry x="20" y="10" width="78.97435897435898" height="60" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="131" value="<div>Masked</div>Multi+Head<br>Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#FF3333;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" vertex="1" parent="219">
|
||||
<mxGeometry x="10" y="20" width="78.97435897435898" height="60" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="132" value="<div>Masked</div>Multi+Head<br>Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#FF3333;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" vertex="1" parent="219">
|
||||
<mxGeometry y="30" width="78.97435897435898" height="60" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="146" style="edgeStyle=none;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" edge="1" parent="219" source="144">
|
||||
<mxGeometry relative="1" as="geometry">
|
||||
<mxPoint x="250.00223443223445" y="44.969696969697" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="144" value="Linear" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;direction=east;rotation=90;" vertex="1" parent="219">
|
||||
<mxGeometry x="182.50223443223445" y="32.5" width="50" height="25" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="220" value="" style="endArrow=none;dashed=1;html=1;" edge="1" parent="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="90" y="690" as="sourcePoint"/>
|
||||
<mxPoint x="290" y="610" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="221" value="" style="endArrow=none;dashed=1;html=1;" edge="1" parent="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="370" y="610" as="sourcePoint"/>
|
||||
<mxPoint x="830" y="700" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
</root>
|
||||
</mxGraphModel>
|
||||
</diagram>
|
||||
</mxfile>
|
||||
148
assets/drawio/gpt1-decoder.drawio
Normal file
148
assets/drawio/gpt1-decoder.drawio
Normal file
@@ -0,0 +1,148 @@
|
||||
<mxfile host="65bd71144e">
|
||||
<diagram name="GPT Architecture" id="DEYydPS-O6mnllJWumln">
|
||||
<mxGraphModel dx="979" dy="301" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
|
||||
<root>
|
||||
<mxCell id="0"/>
|
||||
<mxCell id="1" parent="0"/>
|
||||
<mxCell id="92" value="" style="group" parent="1" vertex="1" connectable="0">
|
||||
<mxGeometry x="40" y="320" width="1286" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="3" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#fff2cc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;strokeColor=#FF3333;" parent="92" vertex="1">
|
||||
<mxGeometry x="230" width="440" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="4" value="<div>Masked</div>Multi+Head<br>Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
|
||||
<mxGeometry x="51.42776556776556" y="50" width="78.97435897435898" height="60" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="22" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="5" edge="1">
|
||||
<mxGeometry relative="1" as="geometry">
|
||||
<mxPoint x="350" y="80" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="5" value="Feed<div>Forward</div><div>Network</div>" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
|
||||
<mxGeometry x="260.9564102564102" y="50" width="71.9230769230769" height="60" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="7" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
|
||||
<mxGeometry x="379.997619047619" y="60" width="37.87142857142857" height="40" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="21" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="12" target="5" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="12" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
|
||||
<mxGeometry x="177.14285714285714" y="60" width="41.904761904761905" height="40" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="13" value="" style="endArrow=classic;html=1;exitX=0;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=elbowEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="3" target="4" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="20" y="80.00000000000011" as="sourcePoint"/>
|
||||
<mxPoint x="229.52380952380952" y="-50" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="14" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="18" target="12" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="155.71428571427464" y="79.99999999999989" as="sourcePoint"/>
|
||||
<mxPoint x="229.52380952380952" y="-50" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="18" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
|
||||
<mxGeometry x="150.00428571428571" y="75" width="10" height="10" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="19" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="4" target="18" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="213.80952380952382" y="410" as="sourcePoint"/>
|
||||
<mxPoint x="145.71428571428578" y="80.00000000000011" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="23" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" target="24" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="236.85714285714286" y="80" as="sourcePoint"/>
|
||||
<mxPoint x="349.7619047619048" y="85" as="targetPoint"/>
|
||||
<Array as="points">
|
||||
<mxPoint x="236.67904761904765" y="125"/>
|
||||
<mxPoint x="292.38095238095235" y="125"/>
|
||||
<mxPoint x="355" y="125"/>
|
||||
</Array>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="28" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="24" target="7" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="24" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
|
||||
<mxGeometry x="350.00190476190477" y="75" width="10" height="10" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="25" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" target="18" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="34.325581395348834" y="80" as="sourcePoint"/>
|
||||
<mxPoint x="150.71428571428578" y="85" as="targetPoint"/>
|
||||
<Array as="points">
|
||||
<mxPoint x="34.25858250276859" y="130"/>
|
||||
<mxPoint x="89.96048726467328" y="130"/>
|
||||
<mxPoint x="155" y="130"/>
|
||||
</Array>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="36" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="32" target="3" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="32" value="+" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="90" width="110" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="33" value="Token Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="95" y="17.5" width="100" height="42.5" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="34" value="Position Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="95" y="100" width="100" height="42.5" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="46" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="37" target="40" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="37" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="690" width="70" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="38" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="7" target="37" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="47" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="40" target="44" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="40" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="790" width="70" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="49" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="41" target="42" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="41" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="950" width="70" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="52" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="42" target="50" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="42" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="1050" width="70" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="48" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="44" target="41" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="44" value=".<div>.</div><div>.</div>" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="890" y="40" width="30" height="80" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="53" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="50" target="51" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="50" value="Linear" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="1150" y="5" width="50" height="150" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="51" value="Softmax" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#e1d5e7;strokeColor=#9673a6;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="1236" y="5" width="50" height="150" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="54" value="Tokens" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry y="40" width="60" height="90" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="55" style="edgeStyle=none;html=1;entryX=-0.025;entryY=0.538;entryDx=0;entryDy=0;entryPerimeter=0;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="54" edge="1">
|
||||
<mxGeometry relative="1" as="geometry">
|
||||
<mxPoint x="42.75" y="84.66941747572821" as="sourcePoint"/>
|
||||
<mxPoint x="90" y="85.33000000000004" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
</root>
|
||||
</mxGraphModel>
|
||||
</diagram>
|
||||
</mxfile>
|
||||
148
assets/drawio/gpt1-embeddings.drawio
Normal file
148
assets/drawio/gpt1-embeddings.drawio
Normal file
@@ -0,0 +1,148 @@
|
||||
<mxfile host="65bd71144e">
|
||||
<diagram name="GPT Architecture" id="DEYydPS-O6mnllJWumln">
|
||||
<mxGraphModel dx="1216" dy="316" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
|
||||
<root>
|
||||
<mxCell id="0"/>
|
||||
<mxCell id="1" parent="0"/>
|
||||
<mxCell id="91" value="" style="group;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="1" vertex="1" connectable="0">
|
||||
<mxGeometry x="40" y="360" width="1286" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="56" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#fff2cc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
|
||||
<mxGeometry x="230" width="440" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="57" value="<div>Masked</div>Multi+Head<br>Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" vertex="1">
|
||||
<mxGeometry x="51.42776556776556" y="50" width="78.97435897435898" height="60" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="58" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" source="59" edge="1">
|
||||
<mxGeometry relative="1" as="geometry">
|
||||
<mxPoint x="350" y="80" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="59" value="Feed<div>Forward</div><div>Network</div>" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" vertex="1">
|
||||
<mxGeometry x="260.9564102564102" y="50" width="71.9230769230769" height="60" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="60" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" vertex="1">
|
||||
<mxGeometry x="379.997619047619" y="60" width="37.87142857142857" height="40" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="61" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" source="62" target="59" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="62" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" vertex="1">
|
||||
<mxGeometry x="177.14285714285714" y="60" width="41.904761904761905" height="40" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="63" value="" style="endArrow=classic;html=1;exitX=0;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=elbowEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" source="56" target="57" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="20" y="80.00000000000011" as="sourcePoint"/>
|
||||
<mxPoint x="229.52380952380952" y="-50" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="64" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" source="65" target="62" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="155.71428571427464" y="79.99999999999989" as="sourcePoint"/>
|
||||
<mxPoint x="229.52380952380952" y="-50" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="65" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" vertex="1">
|
||||
<mxGeometry x="150.00428571428571" y="75" width="10" height="10" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="66" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" source="57" target="65" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="213.80952380952382" y="410" as="sourcePoint"/>
|
||||
<mxPoint x="145.71428571428578" y="80.00000000000011" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="67" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" target="69" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="236.85714285714286" y="80" as="sourcePoint"/>
|
||||
<mxPoint x="349.7619047619048" y="85" as="targetPoint"/>
|
||||
<Array as="points">
|
||||
<mxPoint x="236.67904761904765" y="125"/>
|
||||
<mxPoint x="292.38095238095235" y="125"/>
|
||||
<mxPoint x="355" y="125"/>
|
||||
</Array>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="68" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" source="69" target="60" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="69" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" vertex="1">
|
||||
<mxGeometry x="350.00190476190477" y="75" width="10" height="10" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="70" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" target="65" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="34.325581395348834" y="80" as="sourcePoint"/>
|
||||
<mxPoint x="150.71428571428578" y="85" as="targetPoint"/>
|
||||
<Array as="points">
|
||||
<mxPoint x="34.25858250276859" y="130"/>
|
||||
<mxPoint x="89.96048726467328" y="130"/>
|
||||
<mxPoint x="155" y="130"/>
|
||||
</Array>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="71" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="72" target="56" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="72" value="+" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#FF3333;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
|
||||
<mxGeometry x="90" width="110" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="73" value="Token Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
|
||||
<mxGeometry x="95" y="17.5" width="100" height="42.5" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="74" value="Position Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
|
||||
<mxGeometry x="95" y="100" width="100" height="42.5" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="75" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="76" target="79" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="76" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
|
||||
<mxGeometry x="690" width="70" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="77" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="60" target="76" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="78" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="79" target="85" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="79" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
|
||||
<mxGeometry x="790" width="70" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="80" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="81" target="83" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="81" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
|
||||
<mxGeometry x="950" width="70" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="82" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="83" target="87" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="83" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
|
||||
<mxGeometry x="1050" width="70" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="84" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="85" target="81" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="85" value=".<div>.</div><div>.</div>" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
|
||||
<mxGeometry x="890" y="40" width="30" height="80" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="86" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="87" target="88" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="87" value="Linear" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
|
||||
<mxGeometry x="1150" y="5" width="50" height="150" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="88" value="Softmax" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#e1d5e7;strokeColor=#9673a6;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
|
||||
<mxGeometry x="1236" y="5" width="50" height="150" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="89" value="Tokens" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
|
||||
<mxGeometry y="40" width="60" height="90" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="90" style="edgeStyle=none;html=1;entryX=-0.025;entryY=0.538;entryDx=0;entryDy=0;entryPerimeter=0;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="89" edge="1">
|
||||
<mxGeometry relative="1" as="geometry">
|
||||
<mxPoint x="42.75" y="84.66941747572821" as="sourcePoint"/>
|
||||
<mxPoint x="90" y="85.33000000000004" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
</root>
|
||||
</mxGraphModel>
|
||||
</diagram>
|
||||
</mxfile>
|
||||
192
assets/drawio/gpt1-forward.drawio
Normal file
192
assets/drawio/gpt1-forward.drawio
Normal file
@@ -0,0 +1,192 @@
|
||||
<mxfile host="65bd71144e">
|
||||
<diagram name="GPT Architecture" id="DEYydPS-O6mnllJWumln">
|
||||
<mxGraphModel dx="2176" dy="1029" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
|
||||
<root>
|
||||
<mxCell id="0"/>
|
||||
<mxCell id="1" parent="0"/>
|
||||
<mxCell id="107" value="" style="group" vertex="1" connectable="0" parent="1">
|
||||
<mxGeometry x="120" y="170" width="1286" height="265" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="92" value="" style="group" parent="107" vertex="1" connectable="0">
|
||||
<mxGeometry width="1286" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="3" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#fff2cc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="230" width="440" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="4" value="<div>Masked</div>Multi+Head<br>Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
|
||||
<mxGeometry x="51.42776556776556" y="50" width="78.97435897435898" height="60" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="22" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="5" edge="1">
|
||||
<mxGeometry relative="1" as="geometry">
|
||||
<mxPoint x="350" y="80" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="5" value="Feed<div>Forward</div><div>Network</div>" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#FF3333;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
|
||||
<mxGeometry x="260.9564102564102" y="50" width="71.9230769230769" height="60" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="7" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
|
||||
<mxGeometry x="379.997619047619" y="60" width="37.87142857142857" height="40" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="21" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="12" target="5" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="12" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
|
||||
<mxGeometry x="177.14285714285714" y="60" width="41.904761904761905" height="40" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="13" value="" style="endArrow=classic;html=1;exitX=0;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=elbowEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="3" target="4" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="20" y="80.00000000000011" as="sourcePoint"/>
|
||||
<mxPoint x="229.52380952380952" y="-50" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="14" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="18" target="12" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="155.71428571427464" y="79.99999999999989" as="sourcePoint"/>
|
||||
<mxPoint x="229.52380952380952" y="-50" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="18" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
|
||||
<mxGeometry x="150.00428571428571" y="75" width="10" height="10" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="19" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="4" target="18" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="213.80952380952382" y="410" as="sourcePoint"/>
|
||||
<mxPoint x="145.71428571428578" y="80.00000000000011" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="23" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" target="24" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="236.85714285714286" y="80" as="sourcePoint"/>
|
||||
<mxPoint x="349.7619047619048" y="85" as="targetPoint"/>
|
||||
<Array as="points">
|
||||
<mxPoint x="236.67904761904765" y="125"/>
|
||||
<mxPoint x="292.38095238095235" y="125"/>
|
||||
<mxPoint x="355" y="125"/>
|
||||
</Array>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="28" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="24" target="7" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="24" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
|
||||
<mxGeometry x="350.00190476190477" y="75" width="10" height="10" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="25" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" target="18" edge="1">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="34.325581395348834" y="80" as="sourcePoint"/>
|
||||
<mxPoint x="150.71428571428578" y="85" as="targetPoint"/>
|
||||
<Array as="points">
|
||||
<mxPoint x="34.25858250276859" y="130"/>
|
||||
<mxPoint x="89.96048726467328" y="130"/>
|
||||
<mxPoint x="155" y="130"/>
|
||||
</Array>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="104" value="" style="endArrow=none;dashed=1;html=1;" edge="1" parent="3">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="200" y="200" as="sourcePoint"/>
|
||||
<mxPoint x="260.96000000000004" y="110" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="36" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="32" target="3" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="32" value="+" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="90" width="110" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="33" value="Token Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="95" y="17.5" width="100" height="42.5" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="34" value="Position Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="95" y="100" width="100" height="42.5" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="46" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="37" target="40" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="37" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="690" width="70" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="38" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="7" target="37" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="47" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="40" target="44" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="40" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="790" width="70" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="49" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="41" target="42" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="41" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="950" width="70" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="52" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="42" target="50" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="42" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="1050" width="70" height="160" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="48" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="44" target="41" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="44" value=".<div>.</div><div>.</div>" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="890" y="40" width="30" height="80" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="53" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="50" target="51" edge="1">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="50" value="Linear" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="1150" y="5" width="50" height="150" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="51" value="Softmax" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#e1d5e7;strokeColor=#9673a6;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry x="1236" y="5" width="50" height="150" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="54" value="Tokens" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
|
||||
<mxGeometry y="40" width="60" height="90" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="55" style="edgeStyle=none;html=1;entryX=-0.025;entryY=0.538;entryDx=0;entryDy=0;entryPerimeter=0;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="54" edge="1">
|
||||
<mxGeometry relative="1" as="geometry">
|
||||
<mxPoint x="42.75" y="84.66941747572821" as="sourcePoint"/>
|
||||
<mxPoint x="90" y="85.33000000000004" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="105" value="" style="endArrow=none;dashed=1;html=1;exitX=1;exitY=1;exitDx=0;exitDy=0;" edge="1" parent="107" source="5">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint x="660" y="140" as="sourcePoint"/>
|
||||
<mxPoint x="620" y="190" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="106" value="" style="group" vertex="1" connectable="0" parent="107">
|
||||
<mxGeometry x="450" y="195" width="170" height="70" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="100" value="" style="edgeStyle=none;html=1;" edge="1" parent="106" source="93" target="99">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="93" value="Linear" style="rounded=1;whiteSpace=wrap;html=1;rotation=90;container=0;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;" vertex="1" parent="106">
|
||||
<mxGeometry x="-5" y="20" width="70" height="30" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="96" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" edge="1" parent="106" target="93">
|
||||
<mxGeometry width="50" height="50" relative="1" as="geometry">
|
||||
<mxPoint y="35" as="sourcePoint"/>
|
||||
<mxPoint y="35.00999999999999" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="102" style="edgeStyle=none;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;" edge="1" parent="106" source="98">
|
||||
<mxGeometry relative="1" as="geometry">
|
||||
<mxPoint x="170" y="35.09433962264154" as="targetPoint"/>
|
||||
</mxGeometry>
|
||||
</mxCell>
|
||||
<mxCell id="98" value="Linear" style="rounded=1;whiteSpace=wrap;html=1;rotation=90;container=0;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;" vertex="1" parent="106">
|
||||
<mxGeometry x="100" y="20" width="70" height="30" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="101" value="" style="edgeStyle=none;html=1;" edge="1" parent="106" source="99" target="98">
|
||||
<mxGeometry relative="1" as="geometry"/>
|
||||
</mxCell>
|
||||
<mxCell id="99" value="ReLU" style="rounded=1;whiteSpace=wrap;html=1;rotation=90;container=0;fillColor=#e1d5e7;strokeColor=#9673a6;" vertex="1" parent="106">
|
||||
<mxGeometry x="50" y="20" width="70" height="30" as="geometry"/>
|
||||
</mxCell>
|
||||
</root>
|
||||
</mxGraphModel>
|
||||
</diagram>
|
||||
</mxfile>
|
||||
BIN
assets/models/gpt1-architecture.png
Normal file
BIN
assets/models/gpt1-architecture.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 46 KiB |
BIN
assets/models/gpt1-attention.png
Normal file
BIN
assets/models/gpt1-attention.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 124 KiB |
BIN
assets/models/gpt1-decoder.png
Normal file
BIN
assets/models/gpt1-decoder.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 50 KiB |
BIN
assets/models/gpt1-embeddings.png
Normal file
BIN
assets/models/gpt1-embeddings.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 50 KiB |
BIN
assets/models/gpt1-forward.png
Normal file
BIN
assets/models/gpt1-forward.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 62 KiB |
@@ -14,54 +14,50 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from hf_proxy import HFAdapter, HFTokenizerAdapter, create_hf_pipeline
|
||||
|
||||
from shared.configs import (
|
||||
TEST_PROMPTS, GENERATION_CONFIG, PATHS
|
||||
)
|
||||
from shared.data import (
|
||||
print_experiment_info, ensure_directories, ExperimentLogger
|
||||
)
|
||||
from shared.configs import TEST_PROMPTS, GENERATION_CONFIG, PATHS
|
||||
from shared.data import print_experiment_info, ensure_directories, ExperimentLogger
|
||||
|
||||
|
||||
def load_hf_model_and_tokenizer() -> tuple:
|
||||
"""
|
||||
Загружает модель и токенизатор в формате HuggingFace.
|
||||
|
||||
|
||||
Returns:
|
||||
tuple: (hf_model, hf_tokenizer, model_config)
|
||||
"""
|
||||
# Используем упрощенную версию модели
|
||||
model_path = "checkpoints/hf_simple_trained"
|
||||
tokenizer_path = "checkpoints/hf_simple_tokenizer"
|
||||
|
||||
|
||||
# Проверяем существование файлов
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(
|
||||
f"Модель не найдена: {model_path}\n"
|
||||
f"Сначала обучите модель: uv run python experiments/hf_integration/simple_hf_training.py"
|
||||
)
|
||||
|
||||
|
||||
if not os.path.exists(tokenizer_path):
|
||||
raise FileNotFoundError(
|
||||
f"Токенизатор не найден: {tokenizer_path}"
|
||||
)
|
||||
|
||||
raise FileNotFoundError(f"Токенизатор не найден: {tokenizer_path}")
|
||||
|
||||
# Загружаем адаптированный токенизатор
|
||||
print("🔧 Загрузка адаптированного токенизатора...")
|
||||
hf_tokenizer = HFTokenizerAdapter.from_pretrained(tokenizer_path)
|
||||
print(f"✅ Токенизатор загружен (vocab_size={hf_tokenizer.vocab_size})")
|
||||
|
||||
|
||||
# Загружаем конфигурацию модели
|
||||
import json
|
||||
|
||||
config_path = os.path.join(model_path, "config.json")
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
model_config = json.load(f)
|
||||
|
||||
|
||||
# Загружаем модель через HFAdapter с правильной конфигурацией
|
||||
print("🔧 Загрузка адаптированной модели...")
|
||||
model_bin_path = os.path.join(model_path, "pytorch_model.bin")
|
||||
|
||||
|
||||
# Создаем конфигурацию из сохраненного config.json
|
||||
from hf_proxy import HFAdapterConfig
|
||||
|
||||
hf_config = HFAdapterConfig(
|
||||
vocab_size=model_config["vocab_size"],
|
||||
hidden_size=model_config["hidden_size"],
|
||||
@@ -69,26 +65,28 @@ def load_hf_model_and_tokenizer() -> tuple:
|
||||
num_attention_heads=model_config["num_attention_heads"],
|
||||
max_position_embeddings=model_config["max_position_embeddings"],
|
||||
hidden_dropout_prob=model_config.get("hidden_dropout_prob", 0.1),
|
||||
attention_probs_dropout_prob=model_config.get("attention_probs_dropout_prob", 0.1),
|
||||
attention_probs_dropout_prob=model_config.get(
|
||||
"attention_probs_dropout_prob", 0.1
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
hf_model = HFAdapter.from_pretrained(model_bin_path, hf_config=hf_config)
|
||||
hf_model.eval()
|
||||
print("✅ Модель загружена")
|
||||
|
||||
|
||||
return hf_model, hf_tokenizer, model_config
|
||||
|
||||
|
||||
def test_hf_pipeline(hf_model, hf_tokenizer):
|
||||
"""
|
||||
Тестирует создание HuggingFace pipeline.
|
||||
|
||||
|
||||
Args:
|
||||
hf_model: Адаптированная модель
|
||||
hf_tokenizer: Адаптированный токенизатор
|
||||
"""
|
||||
print("\n🧪 Тестирование HuggingFace pipeline...")
|
||||
|
||||
|
||||
try:
|
||||
# Создаем pipeline
|
||||
pipe = create_hf_pipeline(
|
||||
@@ -97,23 +95,23 @@ def test_hf_pipeline(hf_model, hf_tokenizer):
|
||||
device="cpu",
|
||||
max_length=50,
|
||||
do_sample=True,
|
||||
temperature=0.7
|
||||
temperature=0.7,
|
||||
)
|
||||
|
||||
|
||||
print("✅ HuggingFace pipeline создан")
|
||||
|
||||
|
||||
# Тестируем pipeline
|
||||
test_prompts = TEST_PROMPTS[:3]
|
||||
|
||||
|
||||
for prompt in test_prompts:
|
||||
print(f"\n🔤 Промпт: '{prompt}'")
|
||||
|
||||
|
||||
try:
|
||||
result = pipe(prompt, max_new_tokens=20)
|
||||
print(f"🎯 Результат: {result[0]['generated_text']}")
|
||||
except Exception as e:
|
||||
print(f"❌ Ошибка в pipeline: {e}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Ошибка создания pipeline: {e}")
|
||||
|
||||
@@ -121,47 +119,49 @@ def test_hf_pipeline(hf_model, hf_tokenizer):
|
||||
def generate_with_hf_model(hf_model, hf_tokenizer, prompt: str, config: dict) -> str:
|
||||
"""
|
||||
Генерирует текст через адаптированную модель HF.
|
||||
|
||||
|
||||
Args:
|
||||
hf_model: Адаптированная модель
|
||||
hf_tokenizer: Адаптированный токенизатор
|
||||
prompt: Входной текст
|
||||
config: Конфигурация генерации
|
||||
|
||||
|
||||
Returns:
|
||||
str: Сгенерированный текст
|
||||
"""
|
||||
print(f"🔤 Промпт: '{prompt}'")
|
||||
print(f"📊 Параметры: max_tokens={config['max_new_tokens']}, "
|
||||
f"temp={config['temperature']}, sample={config['do_sample']}")
|
||||
|
||||
print(
|
||||
f"📊 Параметры: max_tokens={config['max_new_tokens']}, "
|
||||
f"temp={config['temperature']}, sample={config['do_sample']}"
|
||||
)
|
||||
|
||||
# Кодируем через адаптированный токенизатор
|
||||
inputs = hf_tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
|
||||
print(f"🎯 Токены промпта: {inputs['input_ids'].tolist()[0]}")
|
||||
print("🔄 Генерация через HF адаптер...")
|
||||
|
||||
|
||||
# Генерируем через адаптированную модель
|
||||
with torch.no_grad():
|
||||
generated_ids = hf_model.generate(
|
||||
input_ids=inputs['input_ids'],
|
||||
input_ids=inputs["input_ids"],
|
||||
max_new_tokens=config["max_new_tokens"],
|
||||
do_sample=config["do_sample"],
|
||||
temperature=config["temperature"],
|
||||
top_k=config["top_k"],
|
||||
top_p=config["top_p"]
|
||||
top_p=config["top_p"],
|
||||
)
|
||||
|
||||
|
||||
# Декодируем через адаптированный токенизатор
|
||||
generated_text = hf_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
||||
|
||||
|
||||
return generated_text
|
||||
|
||||
|
||||
def test_different_hf_strategies(hf_model, hf_tokenizer, prompt: str):
|
||||
"""
|
||||
Тестирует разные стратегии генерации через HF интерфейс.
|
||||
|
||||
|
||||
Args:
|
||||
hf_model: Адаптированная модель
|
||||
hf_tokenizer: Адаптированный токенизатор
|
||||
@@ -169,32 +169,38 @@ def test_different_hf_strategies(hf_model, hf_tokenizer, prompt: str):
|
||||
"""
|
||||
print(f"\n🎭 Сравнение стратегий генерации через HF для промпта: '{prompt}'")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
strategies = [
|
||||
{"name": "🎯 Жадный поиск", "do_sample": False, "temperature": 1.0},
|
||||
{"name": "🎲 Вероятностная (temp=0.7)", "do_sample": True, "temperature": 0.7},
|
||||
{"name": "🔥 Случайная (temp=1.2)", "do_sample": True, "temperature": 1.2},
|
||||
{"name": "❄️ Детерминированная (temp=0.3)", "do_sample": True, "temperature": 0.3},
|
||||
{
|
||||
"name": "❄️ Детерминированная (temp=0.3)",
|
||||
"do_sample": True,
|
||||
"temperature": 0.3,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
for strategy in strategies:
|
||||
print(f"\n{strategy['name']}:")
|
||||
try:
|
||||
config = GENERATION_CONFIG.copy()
|
||||
config.update({
|
||||
"do_sample": strategy["do_sample"],
|
||||
"temperature": strategy["temperature"],
|
||||
"max_new_tokens": 20
|
||||
})
|
||||
|
||||
config.update(
|
||||
{
|
||||
"do_sample": strategy["do_sample"],
|
||||
"temperature": strategy["temperature"],
|
||||
"max_new_tokens": 20,
|
||||
}
|
||||
)
|
||||
|
||||
generated = generate_with_hf_model(hf_model, hf_tokenizer, prompt, config)
|
||||
|
||||
|
||||
# Выделяем сгенерированную часть
|
||||
generated_part = generated[len(prompt):]
|
||||
generated_part = generated[len(prompt) :]
|
||||
print(f" 📤 Промпт: '{prompt}'")
|
||||
print(f" 🎯 Сгенерировано: '{generated_part}'")
|
||||
print(f" 📄 Полный текст: '{generated}'")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Ошибка: {e}")
|
||||
|
||||
@@ -202,30 +208,30 @@ def test_different_hf_strategies(hf_model, hf_tokenizer, prompt: str):
|
||||
def analyze_hf_tokenization(hf_tokenizer, texts: list):
|
||||
"""
|
||||
Анализирует токенизацию через адаптированный токенизатор.
|
||||
|
||||
|
||||
Args:
|
||||
hf_tokenizer: Адаптированный токенизатор
|
||||
texts: Список текстов для анализа
|
||||
"""
|
||||
print(f"\n🔍 Анализ токенизации через HF адаптер:")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
print(f"\nТекст {i+1}: '{text}'")
|
||||
|
||||
|
||||
# Токенизация через адаптер
|
||||
inputs = hf_tokenizer(text, return_tensors="pt")
|
||||
tokens = inputs['input_ids'].tolist()[0]
|
||||
tokens = inputs["input_ids"].tolist()[0]
|
||||
token_strings = hf_tokenizer.tokenize(text)
|
||||
|
||||
|
||||
print(f" Токены (ID): {tokens}")
|
||||
print(f" Токены (текст): {token_strings}")
|
||||
print(f" Количество токенов: {len(tokens)}")
|
||||
|
||||
|
||||
# Декодирование обратно
|
||||
decoded = hf_tokenizer.decode(tokens)
|
||||
print(f" Декодированный: '{decoded}'")
|
||||
|
||||
|
||||
if text == decoded:
|
||||
print(f" ✅ Декодирование корректно")
|
||||
else:
|
||||
@@ -235,51 +241,55 @@ def analyze_hf_tokenization(hf_tokenizer, texts: list):
|
||||
def interactive_hf_generation(hf_model, hf_tokenizer):
|
||||
"""
|
||||
Режим интерактивной генерации через HF интерфейс.
|
||||
|
||||
|
||||
Args:
|
||||
hf_model: Адаптированная модель
|
||||
hf_tokenizer: Адаптированный токенизатор
|
||||
"""
|
||||
print(f"\n💬 Интерактивная генерация через HF (для выхода введите 'exit')")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
while True:
|
||||
try:
|
||||
user_input = input("\n🔤 Введите промпт: ").strip()
|
||||
|
||||
if user_input.lower() in ['exit', 'quit', 'выход']:
|
||||
|
||||
if user_input.lower() in ["exit", "quit", "выход"]:
|
||||
break
|
||||
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
|
||||
# Запрашиваем параметры
|
||||
try:
|
||||
max_tokens = int(input("📏 Макс. токенов [50]: ") or "50")
|
||||
temperature = float(input("🌡️ Температура [0.7]: ") or "0.7")
|
||||
do_sample_input = input("🎲 Сэмплирование (y/n) [y]: ").lower()
|
||||
do_sample = do_sample_input != 'n'
|
||||
do_sample = do_sample_input != "n"
|
||||
except:
|
||||
max_tokens = 50
|
||||
temperature = 0.7
|
||||
do_sample = True
|
||||
print("⚠️ Использую параметры по умолчанию")
|
||||
|
||||
|
||||
config = GENERATION_CONFIG.copy()
|
||||
config.update({
|
||||
"max_new_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"do_sample": do_sample
|
||||
})
|
||||
|
||||
generated = generate_with_hf_model(hf_model, hf_tokenizer, user_input, config)
|
||||
|
||||
generated_part = generated[len(user_input):]
|
||||
config.update(
|
||||
{
|
||||
"max_new_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"do_sample": do_sample,
|
||||
}
|
||||
)
|
||||
|
||||
generated = generate_with_hf_model(
|
||||
hf_model, hf_tokenizer, user_input, config
|
||||
)
|
||||
|
||||
generated_part = generated[len(user_input) :]
|
||||
print(f"\n🎯 Результат:")
|
||||
print(f" 📤 Промпт: '{user_input}'")
|
||||
print(f" 🎯 Сгенерировано: '{generated_part}'")
|
||||
print(f" 📄 Полный текст: '{generated}'")
|
||||
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 Завершение работы...")
|
||||
break
|
||||
@@ -295,76 +305,79 @@ def main():
|
||||
"model": "GPT через HFAdapter",
|
||||
"tokenizer": "BPE через HFTokenizerAdapter",
|
||||
"инструменты": "HuggingFace pipeline & генерация",
|
||||
"стратегия": "интеграция с HF экосистемой"
|
||||
"стратегия": "интеграция с HF экосистемой",
|
||||
}
|
||||
|
||||
|
||||
print_experiment_info(experiment_name, experiment_config)
|
||||
ensure_directories()
|
||||
logger = ExperimentLogger(experiment_name)
|
||||
|
||||
|
||||
try:
|
||||
# Загружаем модель и токенизатор в HF формате
|
||||
hf_model, hf_tokenizer, model_config = load_hf_model_and_tokenizer()
|
||||
|
||||
|
||||
# === Анализ токенизации ===
|
||||
analysis_texts = [
|
||||
"Искусственный интеллект",
|
||||
"Нейронные сети",
|
||||
"Машинное обучение"
|
||||
"Нейронные сети",
|
||||
"Машинное обучение",
|
||||
]
|
||||
analyze_hf_tokenization(hf_tokenizer, analysis_texts)
|
||||
|
||||
|
||||
# === Тестирование HF pipeline ===
|
||||
test_hf_pipeline(hf_model, hf_tokenizer)
|
||||
|
||||
|
||||
# === Генерация с разными промптами ===
|
||||
print(f"\n🎯 Генерация текста через HF адаптер")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
for i, prompt in enumerate(TEST_PROMPTS):
|
||||
print(f"\n📝 Пример {i+1}/{len(TEST_PROMPTS)}")
|
||||
print("-" * 40)
|
||||
|
||||
|
||||
try:
|
||||
generated = generate_with_hf_model(hf_model, hf_tokenizer, prompt, GENERATION_CONFIG)
|
||||
|
||||
generated = generate_with_hf_model(
|
||||
hf_model, hf_tokenizer, prompt, GENERATION_CONFIG
|
||||
)
|
||||
|
||||
# Выделяем сгенерированную часть
|
||||
generated_part = generated[len(prompt):]
|
||||
|
||||
generated_part = generated[len(prompt) :]
|
||||
|
||||
print(f"📤 Промпт: '{prompt}'")
|
||||
print(f"🎯 Сгенерировано: '{generated_part}'")
|
||||
print(f"📄 Полный текст: '{generated}'")
|
||||
print(f"📏 Длина: {len(generated)} символов")
|
||||
|
||||
|
||||
# Логируем успешную генерацию
|
||||
logger.log_metric(f"hf_generation_length_{i}", len(generated))
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Ошибка при генерации: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# === Сравнение стратегий генерации ===
|
||||
test_prompt = "Искусственный"
|
||||
test_different_hf_strategies(hf_model, hf_tokenizer, test_prompt)
|
||||
|
||||
|
||||
# === Интерактивная генерация ===
|
||||
interactive_hf_generation(hf_model, hf_tokenizer)
|
||||
|
||||
|
||||
# === Сохранение результатов ===
|
||||
logger.save_logs("checkpoints/hf_integration_generation_logs.json")
|
||||
|
||||
|
||||
print(f"\n🎉 Эксперимент с HF интеграцией завершен успешно!")
|
||||
print(f"\n📚 Достигнутая интеграция:")
|
||||
print(f" ✅ Загрузка модели и токенизатора в HF формате")
|
||||
print(f" ✅ Использование HF pipeline")
|
||||
print(f" ✅ Генерация через стандартные HF интерфейсы")
|
||||
print(f" ✅ Совместимость с HF экосистемой")
|
||||
|
||||
|
||||
except FileNotFoundError as e:
|
||||
print(f"❌ {e}")
|
||||
except Exception as e:
|
||||
print(f"❌ Ошибка в эксперименте: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
|
||||
@@ -19,141 +19,139 @@ from llm.tokenizers import BPETokenizer
|
||||
from hf_proxy import HFAdapter, HFTokenizerAdapter
|
||||
|
||||
from shared.configs import (
|
||||
TRAIN_TEXTS, BASE_GPT_CONFIG, BPE_CONFIG,
|
||||
TRAINING_CONFIG, PATHS, TEST_PROMPTS
|
||||
TRAIN_TEXTS,
|
||||
BASE_GPT_CONFIG,
|
||||
BPE_CONFIG,
|
||||
TRAINING_CONFIG,
|
||||
PATHS,
|
||||
TEST_PROMPTS,
|
||||
)
|
||||
|
||||
|
||||
def create_dataset(hf_tokenizer, texts, max_length=128):
|
||||
"""
|
||||
Создает простой датасет для обучения.
|
||||
|
||||
|
||||
Args:
|
||||
hf_tokenizer: Адаптированный токенизатор
|
||||
texts: Список текстов
|
||||
max_length: Максимальная длина последовательности
|
||||
|
||||
|
||||
Returns:
|
||||
list: Список тензоров input_ids
|
||||
"""
|
||||
dataset = []
|
||||
|
||||
|
||||
for text in texts:
|
||||
# Токенизируем текст
|
||||
inputs = hf_tokenizer(
|
||||
text,
|
||||
max_length=max_length,
|
||||
text,
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
padding=False,
|
||||
return_tensors="pt"
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
input_ids = inputs['input_ids'][0]
|
||||
|
||||
|
||||
input_ids = inputs["input_ids"][0]
|
||||
|
||||
# Создаем метки для языкового моделирования
|
||||
labels = input_ids.clone()
|
||||
|
||||
dataset.append({
|
||||
'input_ids': input_ids,
|
||||
'labels': labels
|
||||
})
|
||||
|
||||
|
||||
dataset.append({"input_ids": input_ids, "labels": labels})
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def manual_training_loop(hf_model, hf_tokenizer, train_texts, val_texts, config):
|
||||
"""
|
||||
Ручной цикл обучения без использования Trainer.
|
||||
|
||||
|
||||
Args:
|
||||
hf_model: Адаптированная модель
|
||||
hf_tokenizer: Адаптированный токенизатор
|
||||
train_texts: Тексты для обучения
|
||||
val_texts: Тексты для валидации
|
||||
config: Конфигурация обучения
|
||||
|
||||
|
||||
Returns:
|
||||
dict: Результаты обучения
|
||||
"""
|
||||
print("🎯 Запуск ручного обучения...")
|
||||
|
||||
|
||||
# Создаем датасеты
|
||||
train_dataset = create_dataset(hf_tokenizer, train_texts)
|
||||
val_dataset = create_dataset(hf_tokenizer, val_texts)
|
||||
|
||||
|
||||
print(f"📊 Данные: {len(train_dataset)} train, {len(val_dataset)} validation")
|
||||
|
||||
|
||||
# Оптимизатор
|
||||
optimizer = torch.optim.AdamW(
|
||||
hf_model.parameters(),
|
||||
lr=config["learning_rate"]
|
||||
)
|
||||
|
||||
optimizer = torch.optim.AdamW(hf_model.parameters(), lr=config["learning_rate"])
|
||||
|
||||
# Функция потерь
|
||||
loss_fn = nn.CrossEntropyLoss()
|
||||
|
||||
|
||||
# Обучение
|
||||
hf_model.train()
|
||||
train_losses = []
|
||||
val_losses = []
|
||||
|
||||
|
||||
for epoch in range(config["num_epochs"]):
|
||||
print(f"\n📅 Эпоха {epoch + 1}/{config['num_epochs']}")
|
||||
|
||||
|
||||
# Обучение
|
||||
epoch_train_loss = 0
|
||||
for i, batch in enumerate(train_dataset):
|
||||
optimizer.zero_grad()
|
||||
|
||||
input_ids = batch['input_ids'].unsqueeze(0) # [1, seq_len]
|
||||
labels = batch['labels'].unsqueeze(0) # [1, seq_len]
|
||||
|
||||
|
||||
input_ids = batch["input_ids"].unsqueeze(0) # [1, seq_len]
|
||||
labels = batch["labels"].unsqueeze(0) # [1, seq_len]
|
||||
|
||||
# Forward pass
|
||||
outputs = hf_model(input_ids=input_ids, labels=labels)
|
||||
loss = outputs.loss
|
||||
|
||||
|
||||
# Backward pass
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
|
||||
epoch_train_loss += loss.item()
|
||||
|
||||
|
||||
if i % 5 == 0:
|
||||
print(f" Batch {i}/{len(train_dataset)}: loss = {loss.item():.4f}")
|
||||
|
||||
|
||||
avg_train_loss = epoch_train_loss / len(train_dataset)
|
||||
train_losses.append(avg_train_loss)
|
||||
print(f" 📊 Средняя train loss: {avg_train_loss:.4f}")
|
||||
|
||||
|
||||
# Валидация
|
||||
hf_model.eval()
|
||||
epoch_val_loss = 0
|
||||
with torch.no_grad():
|
||||
for batch in val_dataset:
|
||||
input_ids = batch['input_ids'].unsqueeze(0)
|
||||
labels = batch['labels'].unsqueeze(0)
|
||||
|
||||
input_ids = batch["input_ids"].unsqueeze(0)
|
||||
labels = batch["labels"].unsqueeze(0)
|
||||
|
||||
outputs = hf_model(input_ids=input_ids, labels=labels)
|
||||
epoch_val_loss += outputs.loss.item()
|
||||
|
||||
|
||||
avg_val_loss = epoch_val_loss / len(val_dataset)
|
||||
val_losses.append(avg_val_loss)
|
||||
print(f" 📊 Средняя val loss: {avg_val_loss:.4f}")
|
||||
|
||||
|
||||
hf_model.train()
|
||||
|
||||
|
||||
return {
|
||||
'train_losses': train_losses,
|
||||
'val_losses': val_losses,
|
||||
'final_train_loss': train_losses[-1],
|
||||
'final_val_loss': val_losses[-1]
|
||||
"train_losses": train_losses,
|
||||
"val_losses": val_losses,
|
||||
"final_train_loss": train_losses[-1],
|
||||
"final_val_loss": val_losses[-1],
|
||||
}
|
||||
|
||||
|
||||
def test_generation_after_training(hf_model, hf_tokenizer, test_prompts):
|
||||
"""
|
||||
Тестирует генерацию после обучения.
|
||||
|
||||
|
||||
Args:
|
||||
hf_model: Обученная модель
|
||||
hf_tokenizer: Токенизатор
|
||||
@@ -161,24 +159,24 @@ def test_generation_after_training(hf_model, hf_tokenizer, test_prompts):
|
||||
"""
|
||||
print("\n🧪 Тестирование генерации после обучения...")
|
||||
hf_model.eval()
|
||||
|
||||
|
||||
for prompt in test_prompts[:3]:
|
||||
print(f"\n🔤 Промпт: '{prompt}'")
|
||||
|
||||
|
||||
try:
|
||||
inputs = hf_tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
generated = hf_model.generate(
|
||||
input_ids=inputs['input_ids'],
|
||||
input_ids=inputs["input_ids"],
|
||||
max_new_tokens=20,
|
||||
do_sample=True,
|
||||
temperature=0.8
|
||||
temperature=0.8,
|
||||
)
|
||||
|
||||
|
||||
generated_text = hf_tokenizer.decode(generated[0], skip_special_tokens=True)
|
||||
print(f"🎯 Результат: '{generated_text}'")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Ошибка генерации: {e}")
|
||||
|
||||
@@ -188,96 +186,102 @@ def main():
|
||||
print("=" * 60)
|
||||
print("🚀 УПРОЩЕННОЕ ОБУЧЕНИЕ GPT С HF-PROXY")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
try:
|
||||
# === Подготовка данных ===
|
||||
print("🔧 Подготовка данных...")
|
||||
train_texts = TRAIN_TEXTS[:10] # Используем меньше данных для быстрого тестирования
|
||||
train_texts = TRAIN_TEXTS[
|
||||
:10
|
||||
] # Используем меньше данных для быстрого тестирования
|
||||
val_texts = TRAIN_TEXTS[10:12]
|
||||
|
||||
|
||||
print(f"📊 Данные: {len(train_texts)} train, {len(val_texts)} validation")
|
||||
|
||||
|
||||
# === Подготовка токенизатора ===
|
||||
print("🔧 Подготовка токенизатора...")
|
||||
llm_tokenizer = BPETokenizer()
|
||||
llm_tokenizer.train(
|
||||
texts=train_texts,
|
||||
vocab_size=BPE_CONFIG["vocab_size"],
|
||||
special_tokens=BPE_CONFIG["special_tokens"]
|
||||
special_tokens=BPE_CONFIG["special_tokens"],
|
||||
)
|
||||
|
||||
|
||||
hf_tokenizer = HFTokenizerAdapter(llm_tokenizer)
|
||||
print(f"✅ Токенизатор создан (vocab_size={hf_tokenizer.vocab_size})")
|
||||
|
||||
|
||||
# === Подготовка модели ===
|
||||
print("🔧 Подготовка модели...")
|
||||
model_config = BASE_GPT_CONFIG.copy()
|
||||
model_config["vocab_size"] = hf_tokenizer.vocab_size
|
||||
|
||||
|
||||
llm_model = GPT(model_config)
|
||||
hf_model = HFAdapter.from_llm_model(llm_model)
|
||||
print(f"✅ Модель создана")
|
||||
|
||||
|
||||
# === Тестирование до обучения ===
|
||||
print("\n🧪 Тестирование до обучения...")
|
||||
test_generation_after_training(hf_model, hf_tokenizer, TEST_PROMPTS)
|
||||
|
||||
|
||||
# === Обучение ===
|
||||
print(f"\n🎯 Обучение модели...")
|
||||
training_config = {
|
||||
"learning_rate": TRAINING_CONFIG["learning_rate"],
|
||||
"num_epochs": 2, # Меньше эпох для быстрого тестирования
|
||||
"batch_size": TRAINING_CONFIG["batch_size"]
|
||||
"batch_size": TRAINING_CONFIG["batch_size"],
|
||||
}
|
||||
|
||||
|
||||
results = manual_training_loop(
|
||||
hf_model, hf_tokenizer, train_texts, val_texts, training_config
|
||||
)
|
||||
|
||||
|
||||
print(f"\n📊 Результаты обучения:")
|
||||
print(f" Final train loss: {results['final_train_loss']:.4f}")
|
||||
print(f" Final val loss: {results['final_val_loss']:.4f}")
|
||||
|
||||
|
||||
# === Тестирование после обучения ===
|
||||
print("\n🧪 Тестирование после обучения...")
|
||||
test_generation_after_training(hf_model, hf_tokenizer, TEST_PROMPTS)
|
||||
|
||||
|
||||
# === Сохранение модели ===
|
||||
print(f"\n💾 Сохранение модели...")
|
||||
|
||||
|
||||
# Создаем директории
|
||||
os.makedirs("checkpoints/hf_simple_trained", exist_ok=True)
|
||||
os.makedirs("checkpoints/hf_simple_tokenizer", exist_ok=True)
|
||||
|
||||
|
||||
# Сохраняем токенизатор
|
||||
hf_tokenizer.save_pretrained("checkpoints/hf_simple_tokenizer")
|
||||
print("✅ Токенизатор сохранен")
|
||||
|
||||
|
||||
# Сохраняем модель
|
||||
HFAdapter.save_pretrained(
|
||||
hf_model,
|
||||
"checkpoints/hf_simple_trained",
|
||||
tokenizer=hf_tokenizer
|
||||
hf_model, "checkpoints/hf_simple_trained", tokenizer=hf_tokenizer
|
||||
)
|
||||
print("✅ Модель сохранена")
|
||||
|
||||
|
||||
# Сохраняем результаты
|
||||
results_path = "checkpoints/simple_training_results.json"
|
||||
with open(results_path, 'w', encoding='utf-8') as f:
|
||||
json.dump({
|
||||
'training_config': training_config,
|
||||
'model_config': model_config,
|
||||
'results': results
|
||||
}, f, indent=2, ensure_ascii=False)
|
||||
with open(results_path, "w", encoding="utf-8") as f:
|
||||
json.dump(
|
||||
{
|
||||
"training_config": training_config,
|
||||
"model_config": model_config,
|
||||
"results": results,
|
||||
},
|
||||
f,
|
||||
indent=2,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
print(f"✅ Результаты сохранены в {results_path}")
|
||||
|
||||
|
||||
print(f"\n🎉 Упрощенное обучение завершено успешно!")
|
||||
print(f"\n💡 Для использования обученной модели:")
|
||||
print(f" uv run python experiments/hf_integration/generate_with_hf_tools.py")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Ошибка в эксперименте: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
|
||||
@@ -16,158 +16,163 @@ from llm.tokenizers import BPETokenizer
|
||||
from hf_proxy import HFAdapter, HFTokenizerAdapter
|
||||
|
||||
from shared.configs import (
|
||||
TRAIN_TEXTS, BASE_GPT_CONFIG, BPE_CONFIG,
|
||||
TEST_PROMPTS, GENERATION_CONFIG
|
||||
TRAIN_TEXTS,
|
||||
BASE_GPT_CONFIG,
|
||||
BPE_CONFIG,
|
||||
TEST_PROMPTS,
|
||||
GENERATION_CONFIG,
|
||||
)
|
||||
|
||||
|
||||
def test_basic_hf_integration():
|
||||
"""Тестирует базовую интеграцию hf-proxy."""
|
||||
print("🧪 Тестирование базовой интеграции hf-proxy...")
|
||||
|
||||
|
||||
# === Подготовка токенизатора ===
|
||||
print("1. Подготовка токенизатора...")
|
||||
llm_tokenizer = BPETokenizer()
|
||||
llm_tokenizer.train(
|
||||
texts=TRAIN_TEXTS,
|
||||
vocab_size=BPE_CONFIG["vocab_size"],
|
||||
special_tokens=BPE_CONFIG["special_tokens"]
|
||||
special_tokens=BPE_CONFIG["special_tokens"],
|
||||
)
|
||||
|
||||
|
||||
hf_tokenizer = HFTokenizerAdapter(llm_tokenizer)
|
||||
print(f" ✅ Токенизатор создан (vocab_size={hf_tokenizer.vocab_size})")
|
||||
|
||||
|
||||
# === Подготовка модели ===
|
||||
print("2. Подготовка модели...")
|
||||
model_config = BASE_GPT_CONFIG.copy()
|
||||
model_config["vocab_size"] = hf_tokenizer.vocab_size
|
||||
|
||||
|
||||
llm_model = GPT(model_config)
|
||||
hf_model = HFAdapter.from_llm_model(llm_model)
|
||||
print(f" ✅ Модель создана")
|
||||
|
||||
|
||||
# === Тестирование токенизации ===
|
||||
print("3. Тестирование токенизации...")
|
||||
test_texts = ["Искусственный интеллект", "Нейронные сети"]
|
||||
|
||||
|
||||
for text in test_texts:
|
||||
print(f" 📝 Текст: '{text}'")
|
||||
|
||||
|
||||
# Оригинальный токенизатор
|
||||
original_tokens = llm_tokenizer.encode(text)
|
||||
print(f" Оригинальный: {len(original_tokens)} токенов")
|
||||
|
||||
|
||||
# HF адаптер
|
||||
hf_inputs = hf_tokenizer(text, return_tensors="pt")
|
||||
print(f" HF адаптер: {hf_inputs['input_ids'].shape}")
|
||||
|
||||
|
||||
# Декодирование
|
||||
decoded = hf_tokenizer.decode(hf_inputs['input_ids'][0])
|
||||
decoded = hf_tokenizer.decode(hf_inputs["input_ids"][0])
|
||||
print(f" Декодированный: '{decoded}'")
|
||||
|
||||
|
||||
# === Тестирование forward pass ===
|
||||
print("4. Тестирование forward pass...")
|
||||
for text in test_texts:
|
||||
hf_inputs = hf_tokenizer(text, return_tensors="pt")
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = hf_model(**hf_inputs)
|
||||
|
||||
|
||||
print(f" 📝 '{text}' -> logits: {outputs.logits.shape}")
|
||||
|
||||
|
||||
# === Тестирование генерации ===
|
||||
print("5. Тестирование генерации...")
|
||||
hf_model.eval()
|
||||
|
||||
|
||||
for prompt in TEST_PROMPTS[:3]:
|
||||
print(f" 🔤 Промпт: '{prompt}'")
|
||||
|
||||
|
||||
try:
|
||||
inputs = hf_tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
generated = hf_model.generate(
|
||||
input_ids=inputs['input_ids'],
|
||||
input_ids=inputs["input_ids"],
|
||||
max_new_tokens=10,
|
||||
do_sample=True,
|
||||
temperature=0.8
|
||||
temperature=0.8,
|
||||
)
|
||||
|
||||
|
||||
generated_text = hf_tokenizer.decode(generated[0], skip_special_tokens=True)
|
||||
print(f" 🎯 Результат: '{generated_text}'")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Ошибка: {e}")
|
||||
|
||||
|
||||
# === Тестирование сохранения/загрузки ===
|
||||
print("6. Тестирование сохранения/загрузки...")
|
||||
try:
|
||||
# Сохраняем токенизатор
|
||||
hf_tokenizer.save_pretrained("test_save/tokenizer")
|
||||
print(" ✅ Токенизатор сохранен")
|
||||
|
||||
|
||||
# Сохраняем модель
|
||||
HFAdapter.save_pretrained(hf_model, "test_save/model", tokenizer=hf_tokenizer)
|
||||
print(" ✅ Модель сохранена")
|
||||
|
||||
|
||||
# Загружаем токенизатор
|
||||
loaded_tokenizer = HFTokenizerAdapter.from_pretrained("test_save/tokenizer")
|
||||
print(f" ✅ Токенизатор загружен (vocab_size={loaded_tokenizer.vocab_size})")
|
||||
|
||||
|
||||
# Загружаем модель
|
||||
model_path = os.path.join("test_save/model", "pytorch_model.bin")
|
||||
loaded_model = HFAdapter.from_pretrained(model_path)
|
||||
print(" ✅ Модель загружена")
|
||||
|
||||
|
||||
# Проверяем работоспособность загруженной модели
|
||||
test_input = hf_tokenizer("Тест", return_tensors="pt")
|
||||
with torch.no_grad():
|
||||
loaded_outputs = loaded_model(**test_input)
|
||||
print(f" ✅ Загруженная модель работает (logits: {loaded_outputs.logits.shape})")
|
||||
|
||||
print(
|
||||
f" ✅ Загруженная модель работает (logits: {loaded_outputs.logits.shape})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Ошибка сохранения/загрузки: {e}")
|
||||
|
||||
|
||||
print("\n🎉 Базовое тестирование hf-proxy завершено!")
|
||||
|
||||
|
||||
def test_hf_tokenizer_methods():
|
||||
"""Тестирует различные методы HF токенизатора."""
|
||||
print("\n🧪 Тестирование методов HF токенизатора...")
|
||||
|
||||
|
||||
# Создаем токенизатор
|
||||
llm_tokenizer = BPETokenizer()
|
||||
llm_tokenizer.train(
|
||||
texts=TRAIN_TEXTS[:5],
|
||||
vocab_size=500,
|
||||
special_tokens=BPE_CONFIG["special_tokens"]
|
||||
special_tokens=BPE_CONFIG["special_tokens"],
|
||||
)
|
||||
|
||||
|
||||
hf_tokenizer = HFTokenizerAdapter(llm_tokenizer)
|
||||
|
||||
|
||||
test_text = "Искусственный интеллект и машинное обучение"
|
||||
|
||||
|
||||
# Тестируем разные методы
|
||||
print("1. Метод __call__:")
|
||||
result = hf_tokenizer(test_text, return_tensors="pt")
|
||||
print(f" Результат: {result}")
|
||||
|
||||
|
||||
print("2. Метод encode:")
|
||||
encoded = hf_tokenizer.encode(test_text)
|
||||
print(f" Закодировано: {encoded}")
|
||||
|
||||
|
||||
print("3. Метод decode:")
|
||||
decoded = hf_tokenizer.decode(encoded)
|
||||
print(f" Декодировано: '{decoded}'")
|
||||
|
||||
|
||||
print("4. Метод tokenize:")
|
||||
tokens = hf_tokenizer.tokenize(test_text)
|
||||
print(f" Токены: {tokens}")
|
||||
|
||||
|
||||
print("5. Метод get_vocab:")
|
||||
vocab = hf_tokenizer.get_vocab()
|
||||
print(f" Размер словаря: {len(vocab)}")
|
||||
|
||||
|
||||
print("✅ Все методы токенизатора работают!")
|
||||
|
||||
|
||||
@@ -176,14 +181,14 @@ def main():
|
||||
print("=" * 60)
|
||||
print("🧪 ТЕСТИРОВАНИЕ HF-PROXY")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
try:
|
||||
# Тестируем базовую интеграцию
|
||||
test_basic_hf_integration()
|
||||
|
||||
|
||||
# Тестируем методы токенизатора
|
||||
test_hf_tokenizer_methods()
|
||||
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("🎉 ВСЕ ТЕСТЫ ПРОЙДЕНЫ УСПЕШНО!")
|
||||
print("=" * 60)
|
||||
@@ -195,10 +200,11 @@ def main():
|
||||
print(" ✅ Генерация текста")
|
||||
print(" ✅ Сохранение и загрузка моделей")
|
||||
print(" ✅ Все методы HF токенизатора")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ Ошибка в тестировании: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
|
||||
@@ -17,28 +17,34 @@ from llm.tokenizers import BPETokenizer
|
||||
from hf_proxy import HFAdapter, HFTokenizerAdapter
|
||||
|
||||
from shared.configs import (
|
||||
TRAIN_TEXTS, BASE_GPT_CONFIG, BPE_CONFIG,
|
||||
TRAINING_CONFIG, PATHS, TEST_PROMPTS
|
||||
TRAIN_TEXTS,
|
||||
BASE_GPT_CONFIG,
|
||||
BPE_CONFIG,
|
||||
TRAINING_CONFIG,
|
||||
PATHS,
|
||||
TEST_PROMPTS,
|
||||
)
|
||||
from shared.data import (
|
||||
load_training_data, ensure_directories,
|
||||
print_experiment_info, ExperimentLogger
|
||||
load_training_data,
|
||||
ensure_directories,
|
||||
print_experiment_info,
|
||||
ExperimentLogger,
|
||||
)
|
||||
|
||||
|
||||
def setup_hf_training():
|
||||
"""
|
||||
Настраивает окружение для обучения через HuggingFace Trainer.
|
||||
|
||||
|
||||
Returns:
|
||||
tuple: (hf_model, hf_tokenizer, llm_tokenizer, model_config)
|
||||
"""
|
||||
print("🔧 Настройка HuggingFace обучения...")
|
||||
|
||||
|
||||
# === Подготовка данных ===
|
||||
train_texts, val_texts = load_training_data()
|
||||
print(f"📊 Данные: {len(train_texts)} train, {len(val_texts)} validation")
|
||||
|
||||
|
||||
# === Обучение/загрузка токенизатора ===
|
||||
if os.path.exists(PATHS["bpe_tokenizer"]):
|
||||
print("📝 Загрузка BPE токенизатора...")
|
||||
@@ -50,55 +56,55 @@ def setup_hf_training():
|
||||
llm_tokenizer.train(
|
||||
texts=TRAIN_TEXTS,
|
||||
vocab_size=BPE_CONFIG["vocab_size"],
|
||||
special_tokens=BPE_CONFIG["special_tokens"]
|
||||
special_tokens=BPE_CONFIG["special_tokens"],
|
||||
)
|
||||
llm_tokenizer.save(PATHS["bpe_tokenizer"])
|
||||
print(f"✅ Токенизатор обучен и сохранен")
|
||||
|
||||
|
||||
# === Создание адаптера токенизатора ===
|
||||
print("🔧 Создание адаптера HuggingFace для токенизатора...")
|
||||
hf_tokenizer = HFTokenizerAdapter(llm_tokenizer)
|
||||
print(f"✅ Адаптер токенизатора создан")
|
||||
|
||||
|
||||
# === Инициализация модели ===
|
||||
model_config = BASE_GPT_CONFIG.copy()
|
||||
model_config["vocab_size"] = llm_tokenizer.get_vocab_size()
|
||||
|
||||
|
||||
print("🔧 Создание GPT модели...")
|
||||
llm_model = GPT(model_config)
|
||||
|
||||
|
||||
# === Создание адаптера модели ===
|
||||
print("🔧 Создание адаптера HuggingFace для модели...")
|
||||
hf_model = HFAdapter.from_llm_model(llm_model)
|
||||
print(f"✅ Адаптер модели создан")
|
||||
|
||||
|
||||
return hf_model, hf_tokenizer, llm_tokenizer, model_config, train_texts, val_texts
|
||||
|
||||
|
||||
def test_hf_integration(hf_model, hf_tokenizer, llm_tokenizer):
|
||||
"""
|
||||
Тестирует интеграцию с HuggingFace инструментами.
|
||||
|
||||
|
||||
Args:
|
||||
hf_model: Адаптированная модель
|
||||
hf_tokenizer: Адаптированный токенизатор
|
||||
llm_tokenizer: Оригинальный токенизатор
|
||||
"""
|
||||
print("\n🧪 Тестирование интеграции с HuggingFace...")
|
||||
|
||||
|
||||
test_texts = ["Искусственный интеллект", "Нейронные сети"]
|
||||
|
||||
|
||||
for text in test_texts:
|
||||
print(f"\n🔤 Текст: '{text}'")
|
||||
|
||||
|
||||
# Тестируем адаптированный токенизатор
|
||||
hf_inputs = hf_tokenizer(text, return_tensors="pt")
|
||||
print(f" HF токенизатор: {hf_inputs['input_ids'].shape}")
|
||||
|
||||
|
||||
# Тестируем оригинальный токенизатор для сравнения
|
||||
original_tokens = llm_tokenizer.encode(text)
|
||||
print(f" Оригинальный токенизатор: {len(original_tokens)} токенов")
|
||||
|
||||
|
||||
# Тестируем forward pass через адаптированную модель
|
||||
try:
|
||||
with torch.no_grad():
|
||||
@@ -114,28 +120,35 @@ def main():
|
||||
experiment_name = "Обучение GPT через HF Trainer (с hf-proxy)"
|
||||
experiment_config = {
|
||||
"model": "GPT через HFAdapter",
|
||||
"tokenizer": "BPE через HFTokenizerAdapter",
|
||||
"tokenizer": "BPE через HFTokenizerAdapter",
|
||||
"trainer": "HuggingFace Trainer",
|
||||
"vocab_size": BPE_CONFIG["vocab_size"],
|
||||
"training_epochs": TRAINING_CONFIG["num_epochs"]
|
||||
"training_epochs": TRAINING_CONFIG["num_epochs"],
|
||||
}
|
||||
|
||||
|
||||
print_experiment_info(experiment_name, experiment_config)
|
||||
ensure_directories()
|
||||
logger = ExperimentLogger(experiment_name)
|
||||
|
||||
|
||||
try:
|
||||
# Настраиваем обучение
|
||||
hf_model, hf_tokenizer, llm_tokenizer, model_config, train_texts, val_texts = setup_hf_training()
|
||||
|
||||
(
|
||||
hf_model,
|
||||
hf_tokenizer,
|
||||
llm_tokenizer,
|
||||
model_config,
|
||||
train_texts,
|
||||
val_texts,
|
||||
) = setup_hf_training()
|
||||
|
||||
# Тестируем интеграцию
|
||||
test_hf_integration(hf_model, hf_tokenizer, llm_tokenizer)
|
||||
|
||||
|
||||
# === Подготовка датасетов HuggingFace ===
|
||||
print(f"\n📊 Подготовка датасетов HuggingFace...")
|
||||
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
|
||||
def tokenize_function(examples):
|
||||
"""Функция токенизации для HF datasets."""
|
||||
# Используем адаптированный токенизатор
|
||||
@@ -147,11 +160,11 @@ def main():
|
||||
)
|
||||
tokenized["labels"] = tokenized["input_ids"].copy()
|
||||
return tokenized
|
||||
|
||||
|
||||
# Создаем датасеты
|
||||
train_dataset = Dataset.from_dict({"text": train_texts})
|
||||
val_dataset = Dataset.from_dict({"text": val_texts})
|
||||
|
||||
|
||||
# Токенизируем
|
||||
train_dataset = train_dataset.map(
|
||||
tokenize_function,
|
||||
@@ -163,26 +176,26 @@ def main():
|
||||
batched=True,
|
||||
remove_columns=val_dataset.column_names,
|
||||
)
|
||||
|
||||
|
||||
print(f" Train датасет: {len(train_dataset)} примеров")
|
||||
print(f" Validation датасет: {len(val_dataset)} примеров")
|
||||
|
||||
|
||||
# === Настройка HuggingFace Trainer ===
|
||||
print(f"\n🔧 Настройка HuggingFace Trainer...")
|
||||
|
||||
|
||||
from transformers import (
|
||||
Trainer,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
DataCollatorForLanguageModeling
|
||||
DataCollatorForLanguageModeling,
|
||||
)
|
||||
|
||||
|
||||
# Data collator для языкового моделирования
|
||||
data_collator = DataCollatorForLanguageModeling(
|
||||
tokenizer=hf_tokenizer,
|
||||
mlm=False,
|
||||
pad_to_multiple_of=8,
|
||||
)
|
||||
|
||||
|
||||
# Аргументы обучения
|
||||
training_args = TrainingArguments(
|
||||
output_dir=PATHS["hf_model"],
|
||||
@@ -204,7 +217,7 @@ def main():
|
||||
dataloader_pin_memory=False,
|
||||
report_to=None,
|
||||
)
|
||||
|
||||
|
||||
# Создаем Trainer
|
||||
trainer = Trainer(
|
||||
model=hf_model,
|
||||
@@ -213,84 +226,87 @@ def main():
|
||||
eval_dataset=val_dataset,
|
||||
data_collator=data_collator,
|
||||
)
|
||||
|
||||
|
||||
print("✅ HuggingFace Trainer настроен")
|
||||
|
||||
|
||||
# === Запуск обучения ===
|
||||
print(f"\n🎯 Запуск обучения через HuggingFace Trainer...")
|
||||
|
||||
|
||||
train_result = trainer.train()
|
||||
|
||||
|
||||
# Сохраняем лучшую модель
|
||||
trainer.save_model()
|
||||
hf_tokenizer.save_pretrained(PATHS["hf_model"])
|
||||
|
||||
|
||||
print("✅ Обучение завершено успешно!")
|
||||
print(f"📊 Final train loss: {train_result.metrics['train_loss']:.4f}")
|
||||
|
||||
|
||||
if "eval_loss" in train_result.metrics:
|
||||
print(f"📊 Final eval loss: {train_result.metrics['eval_loss']:.4f}")
|
||||
|
||||
|
||||
# === Сохранение через hf-proxy ===
|
||||
print(f"\n💾 Сохранение через hf-proxy...")
|
||||
|
||||
|
||||
from hf_proxy import convert_to_hf_format
|
||||
|
||||
|
||||
# Сохраняем токенизатор в HF формате
|
||||
hf_tokenizer_dir = PATHS["hf_tokenizer"]
|
||||
hf_tokenizer.save_pretrained(hf_tokenizer_dir)
|
||||
|
||||
|
||||
# Сохраняем модель через hf-proxy
|
||||
hf_proxy_dir = PATHS["hf_proxy_model"]
|
||||
HFAdapter.save_pretrained(hf_model, hf_proxy_dir, tokenizer=hf_tokenizer)
|
||||
|
||||
|
||||
print(f"✅ Модель сохранена в HF формате:")
|
||||
print(f" - {PATHS['hf_model']}: стандартный HF формат")
|
||||
print(f" - {hf_proxy_dir}: через hf-proxy")
|
||||
print(f" - {hf_tokenizer_dir}: токенизатор в HF формате")
|
||||
|
||||
|
||||
# === Тестирование генерации ===
|
||||
print(f"\n🧪 Тестирование генерации после обучения...")
|
||||
hf_model.eval()
|
||||
|
||||
|
||||
for prompt in TEST_PROMPTS[:3]:
|
||||
print(f"\n🔤 Промпт: '{prompt}'")
|
||||
|
||||
|
||||
try:
|
||||
inputs = hf_tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
generated = hf_model.generate(
|
||||
input_ids=inputs['input_ids'],
|
||||
input_ids=inputs["input_ids"],
|
||||
max_new_tokens=20,
|
||||
do_sample=True,
|
||||
temperature=0.8
|
||||
temperature=0.8,
|
||||
)
|
||||
|
||||
generated_text = hf_tokenizer.decode(generated[0], skip_special_tokens=True)
|
||||
|
||||
generated_text = hf_tokenizer.decode(
|
||||
generated[0], skip_special_tokens=True
|
||||
)
|
||||
print(f"🎯 Результат: '{generated_text}'")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Ошибка генерации: {e}")
|
||||
|
||||
|
||||
# === Сохранение результатов ===
|
||||
results = {
|
||||
"experiment": experiment_name,
|
||||
"model_config": model_config,
|
||||
"training_config": TRAINING_CONFIG,
|
||||
"final_loss": train_result.metrics.get('train_loss', 'N/A'),
|
||||
"eval_loss": train_result.metrics.get('eval_loss', 'N/A')
|
||||
"final_loss": train_result.metrics.get("train_loss", "N/A"),
|
||||
"eval_loss": train_result.metrics.get("eval_loss", "N/A"),
|
||||
}
|
||||
|
||||
|
||||
logger.save_logs("checkpoints/hf_integration_training_logs.json")
|
||||
|
||||
|
||||
print(f"\n🎉 Эксперимент с HF интеграцией завершен успешно!")
|
||||
print(f"\n💡 Для использования обученной модели:")
|
||||
print(f" uv run python experiments/hf_integration/generate_with_hf_tools.py")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Ошибка в эксперименте: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
|
||||
19
experiments/llm_only/configs/gemma_generate.json
Normal file
19
experiments/llm_only/configs/gemma_generate.json
Normal file
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
|
||||
"test_prompts": [
|
||||
"Open weights",
|
||||
"The Llama model is",
|
||||
"Efficient transformers"
|
||||
],
|
||||
"model_config_path": "checkpoints/gemma-bpe/config.json",
|
||||
"model_weights": "checkpoints/gemma-bpe/model.pt",
|
||||
"generation": {
|
||||
"max_new_tokens": 40,
|
||||
"temperature": 0.8,
|
||||
"do_sample": true,
|
||||
"top_k": null,
|
||||
"top_p": null
|
||||
},
|
||||
"log_path": "checkpoints/gemma_only_generation_logs.json"
|
||||
}
|
||||
|
||||
28
experiments/llm_only/configs/gemma_train.json
Normal file
28
experiments/llm_only/configs/gemma_train.json
Normal file
@@ -0,0 +1,28 @@
|
||||
{
|
||||
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
|
||||
"bpe_vocab_size": 1000,
|
||||
"bpe_special_tokens": ["<pad>", "<unk>", "<bos>", "<eos>"],
|
||||
"test_prompts": ["Open source AI", "What is Llama?"],
|
||||
"model_config": {
|
||||
"vocab_size": null,
|
||||
"embed_dim": 256,
|
||||
"num_q_heads": 4,
|
||||
"num_kv_heads": 2,
|
||||
"head_size": 64,
|
||||
"num_layers": 4,
|
||||
"max_position_embeddings": 512,
|
||||
"num_experts": 8,
|
||||
"top_k_experts": 2,
|
||||
"window_size": 16,
|
||||
"dropout": 0.1
|
||||
},
|
||||
"model_weights": "checkpoints/gemma-bpe/model.pt",
|
||||
"model_config_path": "checkpoints/gemma-bpe/config.json",
|
||||
"training": {
|
||||
"learning_rate": 0.0003,
|
||||
"batch_size": 2,
|
||||
"num_epochs": 3,
|
||||
"warmup_steps": 50
|
||||
},
|
||||
"log_path": "checkpoints/gemma_only_training_logs.json"
|
||||
}
|
||||
19
experiments/llm_only/configs/gpt2_generate.json
Normal file
19
experiments/llm_only/configs/gpt2_generate.json
Normal file
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
|
||||
"test_prompts": [
|
||||
"Нейронные сети",
|
||||
"Обработка естественного языка",
|
||||
"GPT-2 — это"
|
||||
],
|
||||
"model_config_path": "checkpoints/gpt2-bpe/config.json",
|
||||
"model_weights": "checkpoints/gpt2-bpe/model.pt",
|
||||
"generation": {
|
||||
"max_new_tokens": 40,
|
||||
"temperature": 0.8,
|
||||
"do_sample": true,
|
||||
"top_k": null,
|
||||
"top_p": null
|
||||
},
|
||||
"log_path": "checkpoints/llm_only_generation_logs.json"
|
||||
}
|
||||
|
||||
23
experiments/llm_only/configs/gpt2_train.json
Normal file
23
experiments/llm_only/configs/gpt2_train.json
Normal file
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
|
||||
"bpe_vocab_size": 1000,
|
||||
"bpe_special_tokens": ["<pad>", "<unk>", "<bos>", "<eos>"],
|
||||
"test_prompts": ["Искусственный интеллект", "Python — это"],
|
||||
"model_config": {
|
||||
"vocab_size": null,
|
||||
"embed_dim": 256,
|
||||
"num_heads": 4,
|
||||
"num_layers": 4,
|
||||
"max_position_embeddings": 128,
|
||||
"dropout": 0.1
|
||||
},
|
||||
"model_weights": "checkpoints/gpt2-bpe/model.pt",
|
||||
"model_config_path": "checkpoints/gpt2-bpe/config.json",
|
||||
"training": {
|
||||
"learning_rate": 0.0003,
|
||||
"batch_size": 2,
|
||||
"num_epochs": 3,
|
||||
"warmup_steps": 50
|
||||
},
|
||||
"log_path": "checkpoints/gpt2_only_training_logs.json"
|
||||
}
|
||||
19
experiments/llm_only/configs/gpt_generate.json
Normal file
19
experiments/llm_only/configs/gpt_generate.json
Normal file
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
|
||||
"test_prompts": [
|
||||
"The neural network",
|
||||
"Transformer architecture",
|
||||
"GPT models are"
|
||||
],
|
||||
"model_config_path": "checkpoints/gpt-bpe/config.json",
|
||||
"model_weights": "checkpoints/gpt-bpe/model.pt",
|
||||
"generation": {
|
||||
"max_new_tokens": 40,
|
||||
"temperature": 0.8,
|
||||
"do_sample": true,
|
||||
"top_k": null,
|
||||
"top_p": null
|
||||
},
|
||||
"log_path": "checkpoints/llm_only_generation_logs.json"
|
||||
}
|
||||
|
||||
23
experiments/llm_only/configs/gpt_train.json
Normal file
23
experiments/llm_only/configs/gpt_train.json
Normal file
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
|
||||
"bpe_vocab_size": 1000,
|
||||
"bpe_special_tokens": ["<pad>", "<unk>", "<bos>", "<eos>"],
|
||||
"test_prompts": ["GPT language model", "Machine learning basics"],
|
||||
"model_config": {
|
||||
"vocab_size": null,
|
||||
"embed_dim": 256,
|
||||
"num_heads": 4,
|
||||
"num_layers": 4,
|
||||
"max_position_embeddings": 128,
|
||||
"dropout": 0.1
|
||||
},
|
||||
"model_weights": "checkpoints/gpt-bpe/model.pt",
|
||||
"model_config_path": "checkpoints/gpt-bpe/config.json",
|
||||
"training": {
|
||||
"learning_rate": 0.0003,
|
||||
"batch_size": 2,
|
||||
"num_epochs": 3,
|
||||
"warmup_steps": 50
|
||||
},
|
||||
"log_path": "checkpoints/gpt_only_training_logs.json"
|
||||
}
|
||||
19
experiments/llm_only/configs/llama_generate.json
Normal file
19
experiments/llm_only/configs/llama_generate.json
Normal file
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
|
||||
"test_prompts": [
|
||||
"Open weights",
|
||||
"The Llama model is",
|
||||
"Efficient transformers"
|
||||
],
|
||||
"model_config_path": "checkpoints/llama-bpe/config.json",
|
||||
"model_weights": "checkpoints/llama-bpe/model.pt",
|
||||
"generation": {
|
||||
"max_new_tokens": 40,
|
||||
"temperature": 0.8,
|
||||
"do_sample": true,
|
||||
"top_k": null,
|
||||
"top_p": null
|
||||
},
|
||||
"log_path": "checkpoints/llm_only_generation_logs.json"
|
||||
}
|
||||
|
||||
23
experiments/llm_only/configs/llama_train.json
Normal file
23
experiments/llm_only/configs/llama_train.json
Normal file
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
|
||||
"bpe_vocab_size": 1000,
|
||||
"bpe_special_tokens": ["<pad>", "<unk>", "<bos>", "<eos>"],
|
||||
"test_prompts": ["Open source AI", "What is Llama?"],
|
||||
"model_config": {
|
||||
"vocab_size": null,
|
||||
"embed_dim": 256,
|
||||
"num_heads": 4,
|
||||
"num_layers": 4,
|
||||
"max_position_embeddings": 128,
|
||||
"dropout": 0.1
|
||||
},
|
||||
"model_weights": "checkpoints/llama-bpe/model.pt",
|
||||
"model_config_path": "checkpoints/llama-bpe/config.json",
|
||||
"training": {
|
||||
"learning_rate": 0.0003,
|
||||
"batch_size": 2,
|
||||
"num_epochs": 3,
|
||||
"warmup_steps": 50
|
||||
},
|
||||
"log_path": "checkpoints/llama_only_training_logs.json"
|
||||
}
|
||||
19
experiments/llm_only/configs/mistral_generate.json
Normal file
19
experiments/llm_only/configs/mistral_generate.json
Normal file
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
|
||||
"test_prompts": [
|
||||
"Open weights",
|
||||
"The Llama model is",
|
||||
"Efficient transformers"
|
||||
],
|
||||
"model_config_path": "checkpoints/mistral-bpe/config.json",
|
||||
"model_weights": "checkpoints/mistral-bpe/model.pt",
|
||||
"generation": {
|
||||
"max_new_tokens": 40,
|
||||
"temperature": 0.8,
|
||||
"do_sample": true,
|
||||
"top_k": null,
|
||||
"top_p": null
|
||||
},
|
||||
"log_path": "checkpoints/mistral_only_generation_logs.json"
|
||||
}
|
||||
|
||||
26
experiments/llm_only/configs/mistral_train.json
Normal file
26
experiments/llm_only/configs/mistral_train.json
Normal file
@@ -0,0 +1,26 @@
|
||||
{
|
||||
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
|
||||
"bpe_vocab_size": 1000,
|
||||
"bpe_special_tokens": ["<pad>", "<unk>", "<bos>", "<eos>"],
|
||||
"test_prompts": ["Open source AI", "What is Llama?"],
|
||||
"model_config": {
|
||||
"vocab_size": null,
|
||||
"embed_dim": 256,
|
||||
"num_q_heads": 4,
|
||||
"num_kv_heads": 2,
|
||||
"head_size": 64,
|
||||
"num_layers": 4,
|
||||
"max_position_embeddings": 512,
|
||||
"window_size": 16,
|
||||
"dropout": 0.1
|
||||
},
|
||||
"model_weights": "checkpoints/mistral-bpe/model.pt",
|
||||
"model_config_path": "checkpoints/mistral-bpe/config.json",
|
||||
"training": {
|
||||
"learning_rate": 0.0003,
|
||||
"batch_size": 2,
|
||||
"num_epochs": 3,
|
||||
"warmup_steps": 50
|
||||
},
|
||||
"log_path": "checkpoints/mistral_only_training_logs.json"
|
||||
}
|
||||
19
experiments/llm_only/configs/mixtral_generate.json
Normal file
19
experiments/llm_only/configs/mixtral_generate.json
Normal file
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
|
||||
"test_prompts": [
|
||||
"Open weights",
|
||||
"The Llama model is",
|
||||
"Efficient transformers"
|
||||
],
|
||||
"model_config_path": "checkpoints/mixtral-bpe/config.json",
|
||||
"model_weights": "checkpoints/mixtral-bpe/model.pt",
|
||||
"generation": {
|
||||
"max_new_tokens": 40,
|
||||
"temperature": 0.8,
|
||||
"do_sample": true,
|
||||
"top_k": null,
|
||||
"top_p": null
|
||||
},
|
||||
"log_path": "checkpoints/mixtral_only_generation_logs.json"
|
||||
}
|
||||
|
||||
28
experiments/llm_only/configs/mixtral_train.json
Normal file
28
experiments/llm_only/configs/mixtral_train.json
Normal file
@@ -0,0 +1,28 @@
|
||||
{
|
||||
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
|
||||
"bpe_vocab_size": 1000,
|
||||
"bpe_special_tokens": ["<pad>", "<unk>", "<bos>", "<eos>"],
|
||||
"test_prompts": ["Open source AI", "What is Llama?"],
|
||||
"model_config": {
|
||||
"vocab_size": null,
|
||||
"embed_dim": 256,
|
||||
"num_q_heads": 4,
|
||||
"num_kv_heads": 2,
|
||||
"head_size": 64,
|
||||
"num_layers": 4,
|
||||
"max_position_embeddings": 512,
|
||||
"num_experts": 8,
|
||||
"top_k_experts": 2,
|
||||
"window_size": 16,
|
||||
"dropout": 0.1
|
||||
},
|
||||
"model_weights": "checkpoints/mixtral-bpe/model.pt",
|
||||
"model_config_path": "checkpoints/mixtral-bpe/config.json",
|
||||
"training": {
|
||||
"learning_rate": 0.0003,
|
||||
"batch_size": 2,
|
||||
"num_epochs": 3,
|
||||
"warmup_steps": 50
|
||||
},
|
||||
"log_path": "checkpoints/mixtral_only_training_logs.json"
|
||||
}
|
||||
@@ -1,313 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Experiment: generate_gpt_bpe.py
|
||||
Description: Генерация текста обученной GPT моделью с BPE токенизатором.
|
||||
Использует только библиотеку llm без зависимостей от HuggingFace.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Добавляем путь к shared модулям
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from llm.models.gpt import GPT2
|
||||
from llm.tokenizers import BPETokenizer
|
||||
|
||||
from shared.configs import (
|
||||
BASE_GPT_CONFIG, TEST_PROMPTS, GENERATION_CONFIG, PATHS
|
||||
)
|
||||
from shared.data import (
|
||||
print_experiment_info, ensure_directories, ExperimentLogger
|
||||
)
|
||||
|
||||
|
||||
def load_model_and_tokenizer() -> tuple:
|
||||
"""
|
||||
Загружает обученную модель и токенизатор.
|
||||
|
||||
Returns:
|
||||
tuple: (модель, токенизатор, конфигурация)
|
||||
"""
|
||||
# Проверяем существование файлов
|
||||
if not os.path.exists(PATHS["gpt_bpe_model"]):
|
||||
raise FileNotFoundError(
|
||||
f"Модель не найдена: {PATHS['gpt_bpe_model']}\n"
|
||||
f"Сначала обучите модель: uv run python experiments/llm_only/train_gpt_bpe.py"
|
||||
)
|
||||
|
||||
if not os.path.exists(PATHS["bpe_tokenizer"]):
|
||||
raise FileNotFoundError(
|
||||
f"Токенизатор не найден: {PATHS['bpe_tokenizer']}"
|
||||
)
|
||||
|
||||
# Загружаем конфигурацию модели
|
||||
import json
|
||||
with open(PATHS["gpt_bpe_config"], 'r', encoding='utf-8') as f:
|
||||
model_config = json.load(f)
|
||||
|
||||
# Загружаем токенизатор
|
||||
print("🔧 Загрузка BPE токенизатора...")
|
||||
tokenizer = BPETokenizer.load(PATHS["bpe_tokenizer"])
|
||||
print(f"✅ Токенизатор загружен (vocab_size={tokenizer.get_vocab_size()})")
|
||||
|
||||
# Загружаем модель
|
||||
print("🔧 Загрузка GPT2 модели...")
|
||||
model = GPT2(model_config)
|
||||
model.load_state_dict(torch.load(PATHS["gpt_bpe_model"], map_location='cpu'))
|
||||
model.eval()
|
||||
print("✅ Модель загружена")
|
||||
|
||||
return model, tokenizer, model_config
|
||||
|
||||
|
||||
def generate_text(
|
||||
model: GPT2,
|
||||
tokenizer: BPETokenizer,
|
||||
prompt: str,
|
||||
config: dict
|
||||
) -> str:
|
||||
"""
|
||||
Генерирует текст на основе промпта.
|
||||
|
||||
Args:
|
||||
model: Обученная GPT модель
|
||||
tokenizer: BPE токенизатор
|
||||
prompt: Входной текст
|
||||
config: Конфигурация генерации
|
||||
|
||||
Returns:
|
||||
str: Сгенерированный текст
|
||||
"""
|
||||
print(f"🔤 Промпт: '{prompt}'")
|
||||
print(f"📊 Параметры: max_tokens={config['max_new_tokens']}, "
|
||||
f"temp={config['temperature']}, sample={config['do_sample']}")
|
||||
|
||||
# Кодируем промпт
|
||||
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
input_tensor = torch.tensor([input_ids], dtype=torch.long)
|
||||
|
||||
print(f"🎯 Токены промпта: {input_ids}")
|
||||
print(f"🎯 Токены (текст): {tokenizer.tokenize(prompt)}")
|
||||
print("🔄 Генерация...")
|
||||
|
||||
# Генерируем текст
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(
|
||||
x=input_tensor,
|
||||
max_new_tokens=config["max_new_tokens"],
|
||||
do_sample=config["do_sample"],
|
||||
temperature=config["temperature"],
|
||||
top_k=config["top_k"],
|
||||
top_p=config["top_p"]
|
||||
)
|
||||
|
||||
# Декодируем результат
|
||||
generated_text = tokenizer.decode(generated_ids[0].tolist())
|
||||
|
||||
return generated_text
|
||||
|
||||
|
||||
def test_different_strategies(model: GPT2, tokenizer: BPETokenizer, prompt: str):
|
||||
"""
|
||||
Тестирует разные стратегии генерации на одном промпте.
|
||||
|
||||
Args:
|
||||
model: Обученная модель
|
||||
tokenizer: BPE токенизатор
|
||||
prompt: Тестовый промпт
|
||||
"""
|
||||
print(f"\n🎭 Сравнение стратегий генерации для промпта: '{prompt}'")
|
||||
print("=" * 60)
|
||||
|
||||
strategies = [
|
||||
{"name": "🎯 Жадный поиск", "do_sample": False, "temperature": 1.0},
|
||||
{"name": "🎲 Вероятностная (temp=0.7)", "do_sample": True, "temperature": 0.7},
|
||||
{"name": "🔥 Случайная (temp=1.2)", "do_sample": True, "temperature": 1.2},
|
||||
{"name": "❄️ Детерминированная (temp=0.3)", "do_sample": True, "temperature": 0.3},
|
||||
]
|
||||
|
||||
for strategy in strategies:
|
||||
print(f"\n{strategy['name']}:")
|
||||
try:
|
||||
config = GENERATION_CONFIG.copy()
|
||||
config.update({
|
||||
"do_sample": strategy["do_sample"],
|
||||
"temperature": strategy["temperature"],
|
||||
"max_new_tokens": 20
|
||||
})
|
||||
|
||||
generated = generate_text(model, tokenizer, prompt, config)
|
||||
|
||||
# Выделяем сгенерированную часть
|
||||
generated_part = generated[len(prompt):]
|
||||
print(f" 📤 Промпт: '{prompt}'")
|
||||
print(f" 🎯 Сгенерировано: '{generated_part}'")
|
||||
print(f" 📄 Полный текст: '{generated}'")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Ошибка: {e}")
|
||||
|
||||
|
||||
def analyze_tokenization(tokenizer: BPETokenizer, texts: list):
|
||||
"""
|
||||
Анализирует токенизацию различных текстов.
|
||||
|
||||
Args:
|
||||
tokenizer: BPE токенизатор
|
||||
texts: Список текстов для анализа
|
||||
"""
|
||||
print(f"\n🔍 Анализ токенизации BPE:")
|
||||
print("=" * 50)
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
print(f"\nТекст {i+1}: '{text}'")
|
||||
|
||||
# Токенизация
|
||||
tokens = tokenizer.encode(text, add_special_tokens=False)
|
||||
token_strings = tokenizer.tokenize(text)
|
||||
|
||||
print(f" Токены (ID): {tokens}")
|
||||
print(f" Токены (текст): {token_strings}")
|
||||
print(f" Количество токенов: {len(tokens)}")
|
||||
print(f" Эффективность: {len(text)} символов → {len(tokens)} токенов")
|
||||
|
||||
# Декодирование обратно
|
||||
decoded = tokenizer.decode(tokens)
|
||||
if text == decoded:
|
||||
print(f" ✅ Декодирование корректно")
|
||||
else:
|
||||
print(f" ⚠️ Расхождения: '{decoded}'")
|
||||
|
||||
|
||||
def interactive_generation(model: GPT2, tokenizer: BPETokenizer):
|
||||
"""
|
||||
Режим интерактивной генерации.
|
||||
|
||||
Args:
|
||||
model: Обученная модель
|
||||
tokenizer: BPE токенизатор
|
||||
"""
|
||||
print(f"\n💬 Интерактивная генерация (для выхода введите 'exit')")
|
||||
print("-" * 50)
|
||||
|
||||
while True:
|
||||
try:
|
||||
user_input = input("\n🔤 Введите промпт: ").strip()
|
||||
|
||||
if user_input.lower() in ['exit', 'quit', 'выход']:
|
||||
break
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
# Запрашиваем параметры
|
||||
try:
|
||||
max_tokens = int(input("📏 Макс. токенов [50]: ") or "50")
|
||||
temperature = float(input("🌡️ Температура [0.7]: ") or "0.7")
|
||||
do_sample_input = input("🎲 Сэмплирование (y/n) [y]: ").lower()
|
||||
do_sample = do_sample_input != 'n'
|
||||
except:
|
||||
max_tokens = 50
|
||||
temperature = 0.7
|
||||
do_sample = True
|
||||
print("⚠️ Использую параметры по умолчанию")
|
||||
|
||||
config = GENERATION_CONFIG.copy()
|
||||
config.update({
|
||||
"max_new_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"do_sample": do_sample
|
||||
})
|
||||
|
||||
generated = generate_text(model, tokenizer, user_input, config)
|
||||
|
||||
generated_part = generated[len(user_input):]
|
||||
print(f"\n🎯 Результат:")
|
||||
print(f" 📤 Промпт: '{user_input}'")
|
||||
print(f" 🎯 Сгенерировано: '{generated_part}'")
|
||||
print(f" 📄 Полный текст: '{generated}'")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 Завершение работы...")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"❌ Ошибка: {e}")
|
||||
|
||||
|
||||
def main():
|
||||
"""Основная функция эксперимента."""
|
||||
# === Настройка эксперимента ===
|
||||
experiment_name = "Генерация текста GPT2 + BPE (только llm)"
|
||||
experiment_config = {
|
||||
"model": "GPT2 с BPE токенизатором",
|
||||
"стратегия": "автономная генерация",
|
||||
"вход": "промпты",
|
||||
"выход": "сгенерированный текст"
|
||||
}
|
||||
|
||||
print_experiment_info(experiment_name, experiment_config)
|
||||
ensure_directories()
|
||||
logger = ExperimentLogger(experiment_name)
|
||||
|
||||
try:
|
||||
# Загружаем модель и токенизатор
|
||||
model, tokenizer, model_config = load_model_and_tokenizer()
|
||||
|
||||
# === Анализ токенизации ===
|
||||
analysis_texts = [
|
||||
"Искусственный интеллект",
|
||||
"Нейронные сети",
|
||||
"Машинное обучение",
|
||||
]
|
||||
analyze_tokenization(tokenizer, analysis_texts)
|
||||
|
||||
# === Генерация с разными промптами ===
|
||||
print(f"\n🎯 Генерация текста с разными промптами")
|
||||
print("=" * 60)
|
||||
|
||||
for i, prompt in enumerate(TEST_PROMPTS):
|
||||
print(f"\n📝 Пример {i+1}/{len(TEST_PROMPTS)}")
|
||||
print("-" * 40)
|
||||
|
||||
try:
|
||||
generated = generate_text(model, tokenizer, prompt, GENERATION_CONFIG)
|
||||
|
||||
# Выделяем сгенерированную часть
|
||||
generated_part = generated[len(prompt):]
|
||||
|
||||
print(f"📤 Промпт: '{prompt}'")
|
||||
print(f"🎯 Сгенерировано: '{generated_part}'")
|
||||
print(f"📄 Полный текст: '{generated}'")
|
||||
print(f"📏 Длина: {len(generated)} символов")
|
||||
|
||||
# Логируем успешную генерацию
|
||||
logger.log_metric(f"generation_length_{i}", len(generated))
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Ошибка при генерации: {e}")
|
||||
continue
|
||||
|
||||
# === Сравнение стратегий генерации ===
|
||||
test_prompt = "Искусственный"
|
||||
test_different_strategies(model, tokenizer, test_prompt)
|
||||
|
||||
# === Интерактивная генерация ===
|
||||
interactive_generation(model, tokenizer)
|
||||
|
||||
# === Сохранение результатов ===
|
||||
logger.save_logs("checkpoints/llm_only_generation_logs.json")
|
||||
|
||||
print(f"\n🎉 Эксперимент генерации завершен успешно!")
|
||||
|
||||
except FileNotFoundError as e:
|
||||
print(f"❌ {e}")
|
||||
except Exception as e:
|
||||
print(f"❌ Ошибка в эксперименте: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,313 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Experiment: generate_gpt_bpe.py
|
||||
Description: Генерация текста обученной GPT моделью с BPE токенизатором.
|
||||
Использует только библиотеку llm без зависимостей от HuggingFace.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Добавляем путь к shared модулям
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from llm.models.gpt import GPT
|
||||
from llm.tokenizers import BPETokenizer
|
||||
|
||||
from shared.configs import (
|
||||
BASE_GPT_CONFIG, TEST_PROMPTS, GENERATION_CONFIG, PATHS
|
||||
)
|
||||
from shared.data import (
|
||||
print_experiment_info, ensure_directories, ExperimentLogger
|
||||
)
|
||||
|
||||
|
||||
def load_model_and_tokenizer() -> tuple:
|
||||
"""
|
||||
Загружает обученную модель и токенизатор.
|
||||
|
||||
Returns:
|
||||
tuple: (модель, токенизатор, конфигурация)
|
||||
"""
|
||||
# Проверяем существование файлов
|
||||
if not os.path.exists(PATHS["gpt_bpe_model"]):
|
||||
raise FileNotFoundError(
|
||||
f"Модель не найдена: {PATHS['gpt_bpe_model']}\n"
|
||||
f"Сначала обучите модель: uv run python experiments/llm_only/train_gpt_bpe.py"
|
||||
)
|
||||
|
||||
if not os.path.exists(PATHS["bpe_tokenizer"]):
|
||||
raise FileNotFoundError(
|
||||
f"Токенизатор не найден: {PATHS['bpe_tokenizer']}"
|
||||
)
|
||||
|
||||
# Загружаем конфигурацию модели
|
||||
import json
|
||||
with open(PATHS["gpt_bpe_config"], 'r', encoding='utf-8') as f:
|
||||
model_config = json.load(f)
|
||||
|
||||
# Загружаем токенизатор
|
||||
print("🔧 Загрузка BPE токенизатора...")
|
||||
tokenizer = BPETokenizer.load(PATHS["bpe_tokenizer"])
|
||||
print(f"✅ Токенизатор загружен (vocab_size={tokenizer.get_vocab_size()})")
|
||||
|
||||
# Загружаем модель
|
||||
print("🔧 Загрузка GPT модели...")
|
||||
model = GPT(model_config)
|
||||
model.load_state_dict(torch.load(PATHS["gpt_bpe_model"], map_location='cpu'))
|
||||
model.eval()
|
||||
print("✅ Модель загружена")
|
||||
|
||||
return model, tokenizer, model_config
|
||||
|
||||
|
||||
def generate_text(
|
||||
model: GPT,
|
||||
tokenizer: BPETokenizer,
|
||||
prompt: str,
|
||||
config: dict
|
||||
) -> str:
|
||||
"""
|
||||
Генерирует текст на основе промпта.
|
||||
|
||||
Args:
|
||||
model: Обученная GPT модель
|
||||
tokenizer: BPE токенизатор
|
||||
prompt: Входной текст
|
||||
config: Конфигурация генерации
|
||||
|
||||
Returns:
|
||||
str: Сгенерированный текст
|
||||
"""
|
||||
print(f"🔤 Промпт: '{prompt}'")
|
||||
print(f"📊 Параметры: max_tokens={config['max_new_tokens']}, "
|
||||
f"temp={config['temperature']}, sample={config['do_sample']}")
|
||||
|
||||
# Кодируем промпт
|
||||
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
input_tensor = torch.tensor([input_ids], dtype=torch.long)
|
||||
|
||||
print(f"🎯 Токены промпта: {input_ids}")
|
||||
print(f"🎯 Токены (текст): {tokenizer.tokenize(prompt)}")
|
||||
print("🔄 Генерация...")
|
||||
|
||||
# Генерируем текст
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(
|
||||
x=input_tensor,
|
||||
max_new_tokens=config["max_new_tokens"],
|
||||
do_sample=config["do_sample"],
|
||||
temperature=config["temperature"],
|
||||
top_k=config["top_k"],
|
||||
top_p=config["top_p"]
|
||||
)
|
||||
|
||||
# Декодируем результат
|
||||
generated_text = tokenizer.decode(generated_ids[0].tolist())
|
||||
|
||||
return generated_text
|
||||
|
||||
|
||||
def test_different_strategies(model: GPT, tokenizer: BPETokenizer, prompt: str):
|
||||
"""
|
||||
Тестирует разные стратегии генерации на одном промпте.
|
||||
|
||||
Args:
|
||||
model: Обученная модель
|
||||
tokenizer: BPE токенизатор
|
||||
prompt: Тестовый промпт
|
||||
"""
|
||||
print(f"\n🎭 Сравнение стратегий генерации для промпта: '{prompt}'")
|
||||
print("=" * 60)
|
||||
|
||||
strategies = [
|
||||
{"name": "🎯 Жадный поиск", "do_sample": False, "temperature": 1.0},
|
||||
{"name": "🎲 Вероятностная (temp=0.7)", "do_sample": True, "temperature": 0.7},
|
||||
{"name": "🔥 Случайная (temp=1.2)", "do_sample": True, "temperature": 1.2},
|
||||
{"name": "❄️ Детерминированная (temp=0.3)", "do_sample": True, "temperature": 0.3},
|
||||
]
|
||||
|
||||
for strategy in strategies:
|
||||
print(f"\n{strategy['name']}:")
|
||||
try:
|
||||
config = GENERATION_CONFIG.copy()
|
||||
config.update({
|
||||
"do_sample": strategy["do_sample"],
|
||||
"temperature": strategy["temperature"],
|
||||
"max_new_tokens": 20
|
||||
})
|
||||
|
||||
generated = generate_text(model, tokenizer, prompt, config)
|
||||
|
||||
# Выделяем сгенерированную часть
|
||||
generated_part = generated[len(prompt):]
|
||||
print(f" 📤 Промпт: '{prompt}'")
|
||||
print(f" 🎯 Сгенерировано: '{generated_part}'")
|
||||
print(f" 📄 Полный текст: '{generated}'")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Ошибка: {e}")
|
||||
|
||||
|
||||
def analyze_tokenization(tokenizer: BPETokenizer, texts: list):
|
||||
"""
|
||||
Анализирует токенизацию различных текстов.
|
||||
|
||||
Args:
|
||||
tokenizer: BPE токенизатор
|
||||
texts: Список текстов для анализа
|
||||
"""
|
||||
print(f"\n🔍 Анализ токенизации BPE:")
|
||||
print("=" * 50)
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
print(f"\nТекст {i+1}: '{text}'")
|
||||
|
||||
# Токенизация
|
||||
tokens = tokenizer.encode(text, add_special_tokens=False)
|
||||
token_strings = tokenizer.tokenize(text)
|
||||
|
||||
print(f" Токены (ID): {tokens}")
|
||||
print(f" Токены (текст): {token_strings}")
|
||||
print(f" Количество токенов: {len(tokens)}")
|
||||
print(f" Эффективность: {len(text)} символов → {len(tokens)} токенов")
|
||||
|
||||
# Декодирование обратно
|
||||
decoded = tokenizer.decode(tokens)
|
||||
if text == decoded:
|
||||
print(f" ✅ Декодирование корректно")
|
||||
else:
|
||||
print(f" ⚠️ Расхождения: '{decoded}'")
|
||||
|
||||
|
||||
def interactive_generation(model: GPT, tokenizer: BPETokenizer):
|
||||
"""
|
||||
Режим интерактивной генерации.
|
||||
|
||||
Args:
|
||||
model: Обученная модель
|
||||
tokenizer: BPE токенизатор
|
||||
"""
|
||||
print(f"\n💬 Интерактивная генерация (для выхода введите 'exit')")
|
||||
print("-" * 50)
|
||||
|
||||
while True:
|
||||
try:
|
||||
user_input = input("\n🔤 Введите промпт: ").strip()
|
||||
|
||||
if user_input.lower() in ['exit', 'quit', 'выход']:
|
||||
break
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
# Запрашиваем параметры
|
||||
try:
|
||||
max_tokens = int(input("📏 Макс. токенов [50]: ") or "50")
|
||||
temperature = float(input("🌡️ Температура [0.7]: ") or "0.7")
|
||||
do_sample_input = input("🎲 Сэмплирование (y/n) [y]: ").lower()
|
||||
do_sample = do_sample_input != 'n'
|
||||
except:
|
||||
max_tokens = 50
|
||||
temperature = 0.7
|
||||
do_sample = True
|
||||
print("⚠️ Использую параметры по умолчанию")
|
||||
|
||||
config = GENERATION_CONFIG.copy()
|
||||
config.update({
|
||||
"max_new_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"do_sample": do_sample
|
||||
})
|
||||
|
||||
generated = generate_text(model, tokenizer, user_input, config)
|
||||
|
||||
generated_part = generated[len(user_input):]
|
||||
print(f"\n🎯 Результат:")
|
||||
print(f" 📤 Промпт: '{user_input}'")
|
||||
print(f" 🎯 Сгенерировано: '{generated_part}'")
|
||||
print(f" 📄 Полный текст: '{generated}'")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 Завершение работы...")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"❌ Ошибка: {e}")
|
||||
|
||||
|
||||
def main():
|
||||
"""Основная функция эксперимента."""
|
||||
# === Настройка эксперимента ===
|
||||
experiment_name = "Генерация текста GPT + BPE (только llm)"
|
||||
experiment_config = {
|
||||
"model": "GPT с BPE токенизатором",
|
||||
"стратегия": "автономная генерация",
|
||||
"вход": "промпты",
|
||||
"выход": "сгенерированный текст"
|
||||
}
|
||||
|
||||
print_experiment_info(experiment_name, experiment_config)
|
||||
ensure_directories()
|
||||
logger = ExperimentLogger(experiment_name)
|
||||
|
||||
try:
|
||||
# Загружаем модель и токенизатор
|
||||
model, tokenizer, model_config = load_model_and_tokenizer()
|
||||
|
||||
# === Анализ токенизации ===
|
||||
analysis_texts = [
|
||||
"Искусственный интеллект",
|
||||
"Нейронные сети",
|
||||
"Машинное обучение",
|
||||
]
|
||||
analyze_tokenization(tokenizer, analysis_texts)
|
||||
|
||||
# === Генерация с разными промптами ===
|
||||
print(f"\n🎯 Генерация текста с разными промптами")
|
||||
print("=" * 60)
|
||||
|
||||
for i, prompt in enumerate(TEST_PROMPTS):
|
||||
print(f"\n📝 Пример {i+1}/{len(TEST_PROMPTS)}")
|
||||
print("-" * 40)
|
||||
|
||||
try:
|
||||
generated = generate_text(model, tokenizer, prompt, GENERATION_CONFIG)
|
||||
|
||||
# Выделяем сгенерированную часть
|
||||
generated_part = generated[len(prompt):]
|
||||
|
||||
print(f"📤 Промпт: '{prompt}'")
|
||||
print(f"🎯 Сгенерировано: '{generated_part}'")
|
||||
print(f"📄 Полный текст: '{generated}'")
|
||||
print(f"📏 Длина: {len(generated)} символов")
|
||||
|
||||
# Логируем успешную генерацию
|
||||
logger.log_metric(f"generation_length_{i}", len(generated))
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Ошибка при генерации: {e}")
|
||||
continue
|
||||
|
||||
# === Сравнение стратегий генерации ===
|
||||
test_prompt = "Искусственный"
|
||||
test_different_strategies(model, tokenizer, test_prompt)
|
||||
|
||||
# === Интерактивная генерация ===
|
||||
interactive_generation(model, tokenizer)
|
||||
|
||||
# === Сохранение результатов ===
|
||||
logger.save_logs("checkpoints/llm_only_generation_logs.json")
|
||||
|
||||
print(f"\n🎉 Эксперимент генерации завершен успешно!")
|
||||
|
||||
except FileNotFoundError as e:
|
||||
print(f"❌ {e}")
|
||||
except Exception as e:
|
||||
print(f"❌ Ошибка в эксперименте: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
176
experiments/llm_only/run_llm_experiment.py
Normal file
176
experiments/llm_only/run_llm_experiment.py
Normal file
@@ -0,0 +1,176 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Универсальный скрипт для обучения и генерации LLM.
|
||||
Позволяет выбирать тип модели и действие через аргументы,
|
||||
а специальные параметры подавать отдельным JSON-конфигом.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
# Добавляем директорию shared среди импортируемых
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from llm.tokenizers import BPETokenizer
|
||||
from llm.datasets.text_dataset import TextDataset
|
||||
from llm.training.trainer import Trainer
|
||||
|
||||
from shared.data import (
|
||||
print_experiment_info,
|
||||
ensure_directories,
|
||||
load_training_data,
|
||||
ExperimentLogger,
|
||||
)
|
||||
|
||||
|
||||
def load_config(config_path):
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def load_model_class(model_name):
|
||||
if model_name.lower() == 'gpt':
|
||||
from llm.models.gpt import GPT
|
||||
return GPT
|
||||
elif model_name.lower() == 'gpt2':
|
||||
from llm.models.gpt import GPT2
|
||||
return GPT2
|
||||
elif model_name.lower() == 'llama':
|
||||
from llm.models.llama import Llama
|
||||
return Llama
|
||||
elif model_name.lower() == 'mistral':
|
||||
from llm.models.mistral import Mistral
|
||||
return Mistral
|
||||
elif model_name.lower() == 'mixtral':
|
||||
from llm.models.mixtral import Mixtral
|
||||
return Mixtral
|
||||
elif model_name.lower() == 'gemma':
|
||||
from llm.models.gemma import Gemma
|
||||
return Gemma
|
||||
else:
|
||||
raise ValueError(f"Модель '{model_name}' не поддерживается.")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Универсальный запуск обучения/генерации LLM.')
|
||||
parser.add_argument('--model', '-m', type=str, required=True, help='Название модели (gpt, gpt2, llama и т.д.).')
|
||||
parser.add_argument('--action', '-a', type=str, required=True, choices=['train', 'generate'], help='Действие: train или generate.')
|
||||
parser.add_argument('--config', '-c', type=str, required=True, help='Путь к JSON-конфигу с параметрами.')
|
||||
args = parser.parse_args()
|
||||
|
||||
config = load_config(args.config)
|
||||
ModelClass = load_model_class(args.model)
|
||||
logger = ExperimentLogger(f"{args.action}_{args.model}")
|
||||
|
||||
print_experiment_info(f"Эксперимент {args.action} {args.model}", config)
|
||||
ensure_directories()
|
||||
|
||||
# ==== Обучение ====
|
||||
if args.action == 'train':
|
||||
train_texts, val_texts = load_training_data()
|
||||
# --- Токенизатор ---
|
||||
if os.path.exists(config["bpe_tokenizer"]):
|
||||
print("📝 Загрузка обученного токенизатора...")
|
||||
tokenizer = BPETokenizer.load(config["bpe_tokenizer"])
|
||||
print(f"✅ Токенизатор загружен (vocab_size={tokenizer.get_vocab_size()})")
|
||||
else:
|
||||
print("🔧 Обучение BPE токенизатора...")
|
||||
tokenizer = BPETokenizer()
|
||||
tokenizer.train(
|
||||
texts=train_texts,
|
||||
vocab_size=config["bpe_vocab_size"],
|
||||
special_tokens=config["bpe_special_tokens"]
|
||||
)
|
||||
os.makedirs(os.path.dirname(config["bpe_tokenizer"]), exist_ok=True)
|
||||
tokenizer.save(config["bpe_tokenizer"])
|
||||
print(f"✅ BPE токенизатор обучен и сохранен: {config['bpe_tokenizer']}")
|
||||
|
||||
# Тестируем токенизатор (базово)
|
||||
for test_text in config.get("test_prompts", ["Тест"]):
|
||||
encoded = tokenizer.encode(test_text)
|
||||
decoded = tokenizer.decode(encoded)
|
||||
print(f"[TEST TOK] '{test_text}' → {encoded} → '{decoded}'")
|
||||
|
||||
# --- Модель ---
|
||||
model_config = config["model_config"]
|
||||
model_config["vocab_size"] = tokenizer.get_vocab_size()
|
||||
model = ModelClass(model_config)
|
||||
|
||||
# --- Датасет ---
|
||||
train_dataset = TextDataset(
|
||||
train_texts,
|
||||
tokenizer,
|
||||
block_size=model_config["max_position_embeddings"]
|
||||
)
|
||||
print(f" Размер train датасета: {len(train_dataset)} примеров")
|
||||
|
||||
# --- Trainer ---
|
||||
training = config["training"]
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
train_dataset=train_dataset,
|
||||
lr=training["learning_rate"],
|
||||
batch_size=training["batch_size"],
|
||||
num_epochs=training["num_epochs"],
|
||||
warmup_steps=training.get("warmup_steps", 0),
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
# --- Сохранение модели ---
|
||||
os.makedirs(os.path.dirname(config["model_weights"]), exist_ok=True)
|
||||
torch.save(model.state_dict(), config["model_weights"])
|
||||
with open(config["model_config_path"], "w", encoding="utf-8") as f:
|
||||
json.dump(model_config, f, indent=2, ensure_ascii=False)
|
||||
print(f"✅ Модель сохранена: {config['model_weights']}")
|
||||
|
||||
logger.save_logs(config.get("log_path", "checkpoints/llm_only_training_logs.json"))
|
||||
|
||||
# ==== Генерация ====
|
||||
elif args.action == 'generate':
|
||||
# --- Загрузка ---
|
||||
if not os.path.exists(config["model_weights"]):
|
||||
raise FileNotFoundError(f"Модель не найдена: {config['model_weights']}")
|
||||
if not os.path.exists(config["bpe_tokenizer"]):
|
||||
raise FileNotFoundError(f"Токенизатор не найден: {config['bpe_tokenizer']}")
|
||||
with open(config["model_config_path"], "r", encoding="utf-8") as f:
|
||||
model_config = json.load(f)
|
||||
tokenizer = BPETokenizer.load(config["bpe_tokenizer"])
|
||||
model = ModelClass(model_config)
|
||||
model.load_state_dict(torch.load(config["model_weights"], map_location="cpu"))
|
||||
model.eval()
|
||||
|
||||
def generate(prompt, gen_cfg):
|
||||
print(f"Промпт: {prompt}")
|
||||
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
input_tensor = torch.tensor([input_ids], dtype=torch.long)
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(
|
||||
x=input_tensor,
|
||||
max_new_tokens=gen_cfg["max_new_tokens"],
|
||||
do_sample=gen_cfg["do_sample"],
|
||||
temperature=gen_cfg["temperature"],
|
||||
top_k=gen_cfg.get("top_k"),
|
||||
top_p=gen_cfg.get("top_p"),
|
||||
)
|
||||
return tokenizer.decode(generated_ids[0].tolist())
|
||||
|
||||
prompts = config.get("test_prompts", ["Тестовый промпт"])
|
||||
gen_cfg = config.get("generation", {
|
||||
"max_new_tokens": 50,
|
||||
"temperature": 0.7,
|
||||
"do_sample": True,
|
||||
"top_k": None,
|
||||
"top_p": None
|
||||
})
|
||||
for prompt in prompts:
|
||||
generated = generate(prompt, gen_cfg)
|
||||
print(f"\n[RESULT] Prompt: '{prompt}'\n---\n{generated}\n{'='*60}")
|
||||
|
||||
logger.save_logs(config.get("log_path", "checkpoints/llm_only_generation_logs.json"))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,231 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Experiment: train_gpt_bpe.py
|
||||
Description: Обучение GPT модели с собственным BPE токенизатором.
|
||||
Использует только библиотеку llm без зависимостей от HuggingFace.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Добавляем путь к shared модулям
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from llm.models.gpt import GPT2
|
||||
from llm.tokenizers import BPETokenizer
|
||||
from llm.training.dataset import TextDataset
|
||||
from llm.training.trainer import Trainer
|
||||
|
||||
from shared.configs import (
|
||||
TRAIN_TEXTS, BASE_GPT_CONFIG, BPE_CONFIG,
|
||||
TRAINING_CONFIG, PATHS, TEST_PROMPTS
|
||||
)
|
||||
from shared.data import (
|
||||
load_training_data, ensure_directories,
|
||||
print_experiment_info, ExperimentLogger
|
||||
)
|
||||
|
||||
|
||||
def train_bpe_tokenizer(texts: list, config: dict) -> BPETokenizer:
|
||||
"""
|
||||
Обучает BPE токенизатор на текстах.
|
||||
|
||||
Args:
|
||||
texts: Список текстов для обучения
|
||||
config: Конфигурация токенизатора
|
||||
|
||||
Returns:
|
||||
BPETokenizer: Обученный токенизатор
|
||||
"""
|
||||
print("🔧 Обучение BPE токенизатора...")
|
||||
|
||||
tokenizer = BPETokenizer()
|
||||
tokenizer.train(
|
||||
texts=texts,
|
||||
vocab_size=config["vocab_size"],
|
||||
special_tokens=config["special_tokens"]
|
||||
)
|
||||
|
||||
# Сохраняем токенизатор
|
||||
os.makedirs(os.path.dirname(PATHS["bpe_tokenizer"]), exist_ok=True)
|
||||
tokenizer.save(PATHS["bpe_tokenizer"])
|
||||
|
||||
print(f"✅ BPE токенизатор обучен и сохранен: {PATHS['bpe_tokenizer']}")
|
||||
print(f"📊 Размер словаря: {tokenizer.get_vocab_size()}")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def test_tokenizer(tokenizer: BPETokenizer, texts: list):
|
||||
"""
|
||||
Тестирует токенизатор на примерах.
|
||||
|
||||
Args:
|
||||
tokenizer: Обученный токенизатор
|
||||
texts: Список тестовых текстов
|
||||
"""
|
||||
print("\n🧪 Тестирование токенизатора:")
|
||||
|
||||
for i, text in enumerate(texts[:3]):
|
||||
print(f"\nПример {i+1}:")
|
||||
print(f" Исходный текст: '{text}'")
|
||||
|
||||
# Кодирование
|
||||
tokens = tokenizer.encode(text)
|
||||
token_strings = tokenizer.tokenize(text)
|
||||
|
||||
print(f" Токены (ID): {tokens}")
|
||||
print(f" Токены (текст): {token_strings}")
|
||||
print(f" Количество токенов: {len(tokens)}")
|
||||
|
||||
# Декодирование
|
||||
decoded = tokenizer.decode(tokens)
|
||||
print(f" Декодированный: '{decoded}'")
|
||||
|
||||
if text == decoded:
|
||||
print(" ✅ Кодирование/декодирование корректно")
|
||||
else:
|
||||
print(" ⚠️ Небольшие расхождения")
|
||||
|
||||
|
||||
def main():
|
||||
"""Основная функция эксперимента."""
|
||||
# === Настройка эксперимента ===
|
||||
experiment_name = "Обучение GPT2 с BPE токенизатором (только llm)"
|
||||
experiment_config = {
|
||||
"model": "GPT2",
|
||||
"tokenizer": "BPE",
|
||||
"vocab_size": BPE_CONFIG["vocab_size"],
|
||||
"training_epochs": TRAINING_CONFIG["num_epochs"],
|
||||
"batch_size": TRAINING_CONFIG["batch_size"],
|
||||
"learning_rate": TRAINING_CONFIG["learning_rate"]
|
||||
}
|
||||
|
||||
print_experiment_info(experiment_name, experiment_config)
|
||||
ensure_directories()
|
||||
logger = ExperimentLogger(experiment_name)
|
||||
|
||||
try:
|
||||
# === Подготовка данных ===
|
||||
train_texts, val_texts = load_training_data()
|
||||
print(f"📊 Данные: {len(train_texts)} train, {len(val_texts)} validation")
|
||||
|
||||
# === Обучение токенизатора ===
|
||||
if os.path.exists(PATHS["bpe_tokenizer"]):
|
||||
print("📝 Загрузка предварительно обученного токенизатора...")
|
||||
tokenizer = BPETokenizer.load(PATHS["bpe_tokenizer"])
|
||||
print(f"✅ Токенизатор загружен (vocab_size={tokenizer.get_vocab_size()})")
|
||||
else:
|
||||
tokenizer = train_bpe_tokenizer(TRAIN_TEXTS, BPE_CONFIG)
|
||||
|
||||
# Тестируем токенизатор
|
||||
test_tokenizer(tokenizer, TEST_PROMPTS[:3])
|
||||
|
||||
# === Инициализация модели ===
|
||||
model_config = BASE_GPT_CONFIG.copy()
|
||||
model_config["vocab_size"] = tokenizer.get_vocab_size()
|
||||
|
||||
print(f"\n🔧 Инициализация GPT2 модели...")
|
||||
print(f" Размер словаря: {model_config['vocab_size']}")
|
||||
print(f" Размер эмбеддингов: {model_config['embed_dim']}")
|
||||
print(f" Количество слоев: {model_config['num_layers']}")
|
||||
print(f" Количество голов внимания: {model_config['num_heads']}")
|
||||
|
||||
model = GPT2(model_config)
|
||||
|
||||
# === Подготовка датасета ===
|
||||
print(f"\n📊 Подготовка датасета...")
|
||||
train_dataset = TextDataset(
|
||||
train_texts,
|
||||
tokenizer,
|
||||
block_size=model_config["max_position_embeddings"]
|
||||
)
|
||||
print(f" Размер train датасета: {len(train_dataset)} примеров")
|
||||
|
||||
# === Обучение модели ===
|
||||
print(f"\n🎯 Начало обучения GPT2 модели...")
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
train_dataset=train_dataset,
|
||||
lr=TRAINING_CONFIG["learning_rate"],
|
||||
batch_size=TRAINING_CONFIG["batch_size"],
|
||||
num_epochs=TRAINING_CONFIG["num_epochs"],
|
||||
warmup_steps=TRAINING_CONFIG["warmup_steps"]
|
||||
)
|
||||
|
||||
# Запускаем обучение
|
||||
trainer.train()
|
||||
|
||||
# === Сохранение модели ===
|
||||
print(f"\n💾 Сохранение модели...")
|
||||
os.makedirs(os.path.dirname(PATHS["gpt_bpe_model"]), exist_ok=True)
|
||||
|
||||
# Сохраняем модель
|
||||
torch.save(model.state_dict(), PATHS["gpt_bpe_model"])
|
||||
|
||||
# Сохраняем конфигурацию
|
||||
import json
|
||||
with open(PATHS["gpt_bpe_config"], 'w', encoding='utf-8') as f:
|
||||
json.dump(model_config, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"✅ Модель сохранена:")
|
||||
print(f" - {PATHS['gpt_bpe_model']}: веса модели")
|
||||
print(f" - {PATHS['gpt_bpe_config']}: конфигурация модели")
|
||||
print(f" - {PATHS['bpe_tokenizer']}: токенизатор")
|
||||
|
||||
# === Тестирование генерации ===
|
||||
print(f"\n🧪 Тестирование генерации текста...")
|
||||
model.eval()
|
||||
|
||||
for prompt in TEST_PROMPTS[:3]:
|
||||
print(f"\n🔤 Промпт: '{prompt}'")
|
||||
|
||||
try:
|
||||
# Кодируем промпт
|
||||
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
input_tensor = torch.tensor([input_ids], dtype=torch.long)
|
||||
|
||||
# Генерируем текст
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(
|
||||
x=input_tensor,
|
||||
max_new_tokens=20,
|
||||
do_sample=True,
|
||||
temperature=0.8
|
||||
)
|
||||
|
||||
# Декодируем результат
|
||||
generated_text = tokenizer.decode(generated_ids[0].tolist())
|
||||
generated_part = generated_text[len(prompt):]
|
||||
|
||||
print(f"🎯 Сгенерировано: '{generated_part}'")
|
||||
print(f"📄 Полный текст: '{generated_text}'")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Ошибка генерации: {e}")
|
||||
|
||||
# === Сохранение результатов ===
|
||||
results = {
|
||||
"experiment": experiment_name,
|
||||
"model_config": model_config,
|
||||
"training_config": TRAINING_CONFIG,
|
||||
"tokenizer_vocab_size": tokenizer.get_vocab_size(),
|
||||
"final_loss": "см. логи обучения" # В реальном эксперименте можно сохранить final loss
|
||||
}
|
||||
|
||||
logger.save_logs("checkpoints/llm_only_training_logs.json")
|
||||
|
||||
print(f"\n🎉 Эксперимент завершен успешно!")
|
||||
print(f"\n💡 Для использования обученной модели:")
|
||||
print(f" uv run python experiments/llm_only/generate_gpt_bpe.py")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Ошибка в эксперименте: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,231 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Experiment: train_gpt_bpe.py
|
||||
Description: Обучение GPT модели с собственным BPE токенизатором.
|
||||
Использует только библиотеку llm без зависимостей от HuggingFace.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Добавляем путь к shared модулям
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from llm.models.gpt import GPT
|
||||
from llm.tokenizers import BPETokenizer
|
||||
from llm.training.dataset import TextDataset
|
||||
from llm.training.trainer import Trainer
|
||||
|
||||
from shared.configs import (
|
||||
TRAIN_TEXTS, BASE_GPT_CONFIG, BPE_CONFIG,
|
||||
TRAINING_CONFIG, PATHS, TEST_PROMPTS
|
||||
)
|
||||
from shared.data import (
|
||||
load_training_data, ensure_directories,
|
||||
print_experiment_info, ExperimentLogger
|
||||
)
|
||||
|
||||
|
||||
def train_bpe_tokenizer(texts: list, config: dict) -> BPETokenizer:
|
||||
"""
|
||||
Обучает BPE токенизатор на текстах.
|
||||
|
||||
Args:
|
||||
texts: Список текстов для обучения
|
||||
config: Конфигурация токенизатора
|
||||
|
||||
Returns:
|
||||
BPETokenizer: Обученный токенизатор
|
||||
"""
|
||||
print("🔧 Обучение BPE токенизатора...")
|
||||
|
||||
tokenizer = BPETokenizer()
|
||||
tokenizer.train(
|
||||
texts=texts,
|
||||
vocab_size=config["vocab_size"],
|
||||
special_tokens=config["special_tokens"]
|
||||
)
|
||||
|
||||
# Сохраняем токенизатор
|
||||
os.makedirs(os.path.dirname(PATHS["bpe_tokenizer"]), exist_ok=True)
|
||||
tokenizer.save(PATHS["bpe_tokenizer"])
|
||||
|
||||
print(f"✅ BPE токенизатор обучен и сохранен: {PATHS['bpe_tokenizer']}")
|
||||
print(f"📊 Размер словаря: {tokenizer.get_vocab_size()}")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def test_tokenizer(tokenizer: BPETokenizer, texts: list):
|
||||
"""
|
||||
Тестирует токенизатор на примерах.
|
||||
|
||||
Args:
|
||||
tokenizer: Обученный токенизатор
|
||||
texts: Список тестовых текстов
|
||||
"""
|
||||
print("\n🧪 Тестирование токенизатора:")
|
||||
|
||||
for i, text in enumerate(texts[:3]):
|
||||
print(f"\nПример {i+1}:")
|
||||
print(f" Исходный текст: '{text}'")
|
||||
|
||||
# Кодирование
|
||||
tokens = tokenizer.encode(text)
|
||||
token_strings = tokenizer.tokenize(text)
|
||||
|
||||
print(f" Токены (ID): {tokens}")
|
||||
print(f" Токены (текст): {token_strings}")
|
||||
print(f" Количество токенов: {len(tokens)}")
|
||||
|
||||
# Декодирование
|
||||
decoded = tokenizer.decode(tokens)
|
||||
print(f" Декодированный: '{decoded}'")
|
||||
|
||||
if text == decoded:
|
||||
print(" ✅ Кодирование/декодирование корректно")
|
||||
else:
|
||||
print(" ⚠️ Небольшие расхождения")
|
||||
|
||||
|
||||
def main():
|
||||
"""Основная функция эксперимента."""
|
||||
# === Настройка эксперимента ===
|
||||
experiment_name = "Обучение GPT с BPE токенизатором (только llm)"
|
||||
experiment_config = {
|
||||
"model": "GPT",
|
||||
"tokenizer": "BPE",
|
||||
"vocab_size": BPE_CONFIG["vocab_size"],
|
||||
"training_epochs": TRAINING_CONFIG["num_epochs"],
|
||||
"batch_size": TRAINING_CONFIG["batch_size"],
|
||||
"learning_rate": TRAINING_CONFIG["learning_rate"]
|
||||
}
|
||||
|
||||
print_experiment_info(experiment_name, experiment_config)
|
||||
ensure_directories()
|
||||
logger = ExperimentLogger(experiment_name)
|
||||
|
||||
try:
|
||||
# === Подготовка данных ===
|
||||
train_texts, val_texts = load_training_data()
|
||||
print(f"📊 Данные: {len(train_texts)} train, {len(val_texts)} validation")
|
||||
|
||||
# === Обучение токенизатора ===
|
||||
if os.path.exists(PATHS["bpe_tokenizer"]):
|
||||
print("📝 Загрузка предварительно обученного токенизатора...")
|
||||
tokenizer = BPETokenizer.load(PATHS["bpe_tokenizer"])
|
||||
print(f"✅ Токенизатор загружен (vocab_size={tokenizer.get_vocab_size()})")
|
||||
else:
|
||||
tokenizer = train_bpe_tokenizer(TRAIN_TEXTS, BPE_CONFIG)
|
||||
|
||||
# Тестируем токенизатор
|
||||
test_tokenizer(tokenizer, TEST_PROMPTS[:3])
|
||||
|
||||
# === Инициализация модели ===
|
||||
model_config = BASE_GPT_CONFIG.copy()
|
||||
model_config["vocab_size"] = tokenizer.get_vocab_size()
|
||||
|
||||
print(f"\n🔧 Инициализация GPT модели...")
|
||||
print(f" Размер словаря: {model_config['vocab_size']}")
|
||||
print(f" Размер эмбеддингов: {model_config['embed_dim']}")
|
||||
print(f" Количество слоев: {model_config['num_layers']}")
|
||||
print(f" Количество голов внимания: {model_config['num_heads']}")
|
||||
|
||||
model = GPT(model_config)
|
||||
|
||||
# === Подготовка датасета ===
|
||||
print(f"\n📊 Подготовка датасета...")
|
||||
train_dataset = TextDataset(
|
||||
train_texts,
|
||||
tokenizer,
|
||||
block_size=model_config["max_position_embeddings"]
|
||||
)
|
||||
print(f" Размер train датасета: {len(train_dataset)} примеров")
|
||||
|
||||
# === Обучение модели ===
|
||||
print(f"\n🎯 Начало обучения GPT модели...")
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
train_dataset=train_dataset,
|
||||
lr=TRAINING_CONFIG["learning_rate"],
|
||||
batch_size=TRAINING_CONFIG["batch_size"],
|
||||
num_epochs=TRAINING_CONFIG["num_epochs"],
|
||||
warmup_steps=TRAINING_CONFIG["warmup_steps"]
|
||||
)
|
||||
|
||||
# Запускаем обучение
|
||||
trainer.train()
|
||||
|
||||
# === Сохранение модели ===
|
||||
print(f"\n💾 Сохранение модели...")
|
||||
os.makedirs(os.path.dirname(PATHS["gpt_bpe_model"]), exist_ok=True)
|
||||
|
||||
# Сохраняем модель
|
||||
torch.save(model.state_dict(), PATHS["gpt_bpe_model"])
|
||||
|
||||
# Сохраняем конфигурацию
|
||||
import json
|
||||
with open(PATHS["gpt_bpe_config"], 'w', encoding='utf-8') as f:
|
||||
json.dump(model_config, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"✅ Модель сохранена:")
|
||||
print(f" - {PATHS['gpt_bpe_model']}: веса модели")
|
||||
print(f" - {PATHS['gpt_bpe_config']}: конфигурация модели")
|
||||
print(f" - {PATHS['bpe_tokenizer']}: токенизатор")
|
||||
|
||||
# === Тестирование генерации ===
|
||||
print(f"\n🧪 Тестирование генерации текста...")
|
||||
model.eval()
|
||||
|
||||
for prompt in TEST_PROMPTS[:3]:
|
||||
print(f"\n🔤 Промпт: '{prompt}'")
|
||||
|
||||
try:
|
||||
# Кодируем промпт
|
||||
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
input_tensor = torch.tensor([input_ids], dtype=torch.long)
|
||||
|
||||
# Генерируем текст
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(
|
||||
x=input_tensor,
|
||||
max_new_tokens=20,
|
||||
do_sample=True,
|
||||
temperature=0.8
|
||||
)
|
||||
|
||||
# Декодируем результат
|
||||
generated_text = tokenizer.decode(generated_ids[0].tolist())
|
||||
generated_part = generated_text[len(prompt):]
|
||||
|
||||
print(f"🎯 Сгенерировано: '{generated_part}'")
|
||||
print(f"📄 Полный текст: '{generated_text}'")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Ошибка генерации: {e}")
|
||||
|
||||
# === Сохранение результатов ===
|
||||
results = {
|
||||
"experiment": experiment_name,
|
||||
"model_config": model_config,
|
||||
"training_config": TRAINING_CONFIG,
|
||||
"tokenizer_vocab_size": tokenizer.get_vocab_size(),
|
||||
"final_loss": "см. логи обучения" # В реальном эксперименте можно сохранить final loss
|
||||
}
|
||||
|
||||
logger.save_logs("checkpoints/llm_only_training_logs.json")
|
||||
|
||||
print(f"\n🎉 Эксперимент завершен успешно!")
|
||||
print(f"\n💡 Для использования обученной модели:")
|
||||
print(f" uv run python experiments/llm_only/generate_gpt_bpe.py")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Ошибка в эксперименте: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,231 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Experiment: train_gpt_bpe.py
|
||||
Description: Обучение GPT модели с собственным BPE токенизатором.
|
||||
Использует только библиотеку llm без зависимостей от HuggingFace.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Добавляем путь к shared модулям
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from llm.models.llama import Llama
|
||||
from llm.tokenizers import BPETokenizer
|
||||
from llm.training.dataset import TextDataset
|
||||
from llm.training.trainer import Trainer
|
||||
|
||||
from shared.configs import (
|
||||
TRAIN_TEXTS, BASE_GPT_CONFIG, BPE_CONFIG,
|
||||
TRAINING_CONFIG, PATHS, TEST_PROMPTS
|
||||
)
|
||||
from shared.data import (
|
||||
load_training_data, ensure_directories,
|
||||
print_experiment_info, ExperimentLogger
|
||||
)
|
||||
|
||||
|
||||
def train_bpe_tokenizer(texts: list, config: dict) -> BPETokenizer:
|
||||
"""
|
||||
Обучает BPE токенизатор на текстах.
|
||||
|
||||
Args:
|
||||
texts: Список текстов для обучения
|
||||
config: Конфигурация токенизатора
|
||||
|
||||
Returns:
|
||||
BPETokenizer: Обученный токенизатор
|
||||
"""
|
||||
print("🔧 Обучение BPE токенизатора...")
|
||||
|
||||
tokenizer = BPETokenizer()
|
||||
tokenizer.train(
|
||||
texts=texts,
|
||||
vocab_size=config["vocab_size"],
|
||||
special_tokens=config["special_tokens"]
|
||||
)
|
||||
|
||||
# Сохраняем токенизатор
|
||||
os.makedirs(os.path.dirname(PATHS["bpe_tokenizer"]), exist_ok=True)
|
||||
tokenizer.save(PATHS["bpe_tokenizer"])
|
||||
|
||||
print(f"✅ BPE токенизатор обучен и сохранен: {PATHS['bpe_tokenizer']}")
|
||||
print(f"📊 Размер словаря: {tokenizer.get_vocab_size()}")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def test_tokenizer(tokenizer: BPETokenizer, texts: list):
|
||||
"""
|
||||
Тестирует токенизатор на примерах.
|
||||
|
||||
Args:
|
||||
tokenizer: Обученный токенизатор
|
||||
texts: Список тестовых текстов
|
||||
"""
|
||||
print("\n🧪 Тестирование токенизатора:")
|
||||
|
||||
for i, text in enumerate(texts[:3]):
|
||||
print(f"\nПример {i+1}:")
|
||||
print(f" Исходный текст: '{text}'")
|
||||
|
||||
# Кодирование
|
||||
tokens = tokenizer.encode(text)
|
||||
token_strings = tokenizer.tokenize(text)
|
||||
|
||||
print(f" Токены (ID): {tokens}")
|
||||
print(f" Токены (текст): {token_strings}")
|
||||
print(f" Количество токенов: {len(tokens)}")
|
||||
|
||||
# Декодирование
|
||||
decoded = tokenizer.decode(tokens)
|
||||
print(f" Декодированный: '{decoded}'")
|
||||
|
||||
if text == decoded:
|
||||
print(" ✅ Кодирование/декодирование корректно")
|
||||
else:
|
||||
print(" ⚠️ Небольшие расхождения")
|
||||
|
||||
|
||||
def main():
|
||||
"""Основная функция эксперимента."""
|
||||
# === Настройка эксперимента ===
|
||||
experiment_name = "Обучение Llama с BPE токенизатором (только llm)"
|
||||
experiment_config = {
|
||||
"model": "Llama",
|
||||
"tokenizer": "BPE",
|
||||
"vocab_size": BPE_CONFIG["vocab_size"],
|
||||
"training_epochs": TRAINING_CONFIG["num_epochs"],
|
||||
"batch_size": TRAINING_CONFIG["batch_size"],
|
||||
"learning_rate": TRAINING_CONFIG["learning_rate"]
|
||||
}
|
||||
|
||||
print_experiment_info(experiment_name, experiment_config)
|
||||
ensure_directories()
|
||||
logger = ExperimentLogger(experiment_name)
|
||||
|
||||
try:
|
||||
# === Подготовка данных ===
|
||||
train_texts, val_texts = load_training_data()
|
||||
print(f"📊 Данные: {len(train_texts)} train, {len(val_texts)} validation")
|
||||
|
||||
# === Обучение токенизатора ===
|
||||
if os.path.exists(PATHS["bpe_tokenizer"]):
|
||||
print("📝 Загрузка предварительно обученного токенизатора...")
|
||||
tokenizer = BPETokenizer.load(PATHS["bpe_tokenizer"])
|
||||
print(f"✅ Токенизатор загружен (vocab_size={tokenizer.get_vocab_size()})")
|
||||
else:
|
||||
tokenizer = train_bpe_tokenizer(TRAIN_TEXTS, BPE_CONFIG)
|
||||
|
||||
# Тестируем токенизатор
|
||||
test_tokenizer(tokenizer, TEST_PROMPTS[:3])
|
||||
|
||||
# === Инициализация модели ===
|
||||
model_config = BASE_GPT_CONFIG.copy()
|
||||
model_config["vocab_size"] = tokenizer.get_vocab_size()
|
||||
|
||||
print(f"\n🔧 Инициализация Llama модели...")
|
||||
print(f" Размер словаря: {model_config['vocab_size']}")
|
||||
print(f" Размер эмбеддингов: {model_config['embed_dim']}")
|
||||
print(f" Количество слоев: {model_config['num_layers']}")
|
||||
print(f" Количество голов внимания: {model_config['num_heads']}")
|
||||
|
||||
model = Llama(model_config)
|
||||
|
||||
# === Подготовка датасета ===
|
||||
print(f"\n📊 Подготовка датасета...")
|
||||
train_dataset = TextDataset(
|
||||
train_texts,
|
||||
tokenizer,
|
||||
block_size=model_config["max_position_embeddings"]
|
||||
)
|
||||
print(f" Размер train датасета: {len(train_dataset)} примеров")
|
||||
|
||||
# === Обучение модели ===
|
||||
print(f"\n🎯 Начало обучения Llama модели...")
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
train_dataset=train_dataset,
|
||||
lr=TRAINING_CONFIG["learning_rate"],
|
||||
batch_size=TRAINING_CONFIG["batch_size"],
|
||||
num_epochs=TRAINING_CONFIG["num_epochs"],
|
||||
warmup_steps=TRAINING_CONFIG["warmup_steps"]
|
||||
)
|
||||
|
||||
# Запускаем обучение
|
||||
trainer.train()
|
||||
|
||||
# === Сохранение модели ===
|
||||
print(f"\n💾 Сохранение модели...")
|
||||
os.makedirs(os.path.dirname(PATHS["gpt_bpe_model"]), exist_ok=True)
|
||||
|
||||
# Сохраняем модель
|
||||
torch.save(model.state_dict(), PATHS["gpt_bpe_model"])
|
||||
|
||||
# Сохраняем конфигурацию
|
||||
import json
|
||||
with open(PATHS["gpt_bpe_config"], 'w', encoding='utf-8') as f:
|
||||
json.dump(model_config, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"✅ Модель сохранена:")
|
||||
print(f" - {PATHS['gpt_bpe_model']}: веса модели")
|
||||
print(f" - {PATHS['gpt_bpe_config']}: конфигурация модели")
|
||||
print(f" - {PATHS['bpe_tokenizer']}: токенизатор")
|
||||
|
||||
# === Тестирование генерации ===
|
||||
print(f"\n🧪 Тестирование генерации текста...")
|
||||
model.eval()
|
||||
|
||||
for prompt in TEST_PROMPTS[:3]:
|
||||
print(f"\n🔤 Промпт: '{prompt}'")
|
||||
|
||||
try:
|
||||
# Кодируем промпт
|
||||
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
input_tensor = torch.tensor([input_ids], dtype=torch.long)
|
||||
|
||||
# Генерируем текст
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(
|
||||
x=input_tensor,
|
||||
max_new_tokens=20,
|
||||
do_sample=True,
|
||||
temperature=0.8
|
||||
)
|
||||
|
||||
# Декодируем результат
|
||||
generated_text = tokenizer.decode(generated_ids[0].tolist())
|
||||
generated_part = generated_text[len(prompt):]
|
||||
|
||||
print(f"🎯 Сгенерировано: '{generated_part}'")
|
||||
print(f"📄 Полный текст: '{generated_text}'")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Ошибка генерации: {e}")
|
||||
|
||||
# === Сохранение результатов ===
|
||||
results = {
|
||||
"experiment": experiment_name,
|
||||
"model_config": model_config,
|
||||
"training_config": TRAINING_CONFIG,
|
||||
"tokenizer_vocab_size": tokenizer.get_vocab_size(),
|
||||
"final_loss": "см. логи обучения" # В реальном эксперименте можно сохранить final loss
|
||||
}
|
||||
|
||||
logger.save_logs("checkpoints/llm_only_training_logs.json")
|
||||
|
||||
print(f"\n🎉 Эксперимент завершен успешно!")
|
||||
print(f"\n💡 Для использования обученной модели:")
|
||||
print(f" uv run python experiments/llm_only/generate_gpt_bpe.py")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Ошибка в эксперименте: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -30,7 +30,7 @@ BASE_GPT_CONFIG = {
|
||||
"num_heads": 4,
|
||||
"num_layers": 4,
|
||||
"max_position_embeddings": 128,
|
||||
"dropout": 0.1
|
||||
"dropout": 0.1,
|
||||
}
|
||||
|
||||
# Конфигурация для маленькой модели (быстрое тестирование)
|
||||
@@ -40,7 +40,7 @@ SMALL_GPT_CONFIG = {
|
||||
"num_heads": 2,
|
||||
"num_layers": 2,
|
||||
"max_position_embeddings": 64,
|
||||
"dropout": 0.1
|
||||
"dropout": 0.1,
|
||||
}
|
||||
|
||||
# Конфигурация для большой модели (качественное обучение)
|
||||
@@ -50,13 +50,13 @@ LARGE_GPT_CONFIG = {
|
||||
"num_heads": 8,
|
||||
"num_layers": 6,
|
||||
"max_position_embeddings": 256,
|
||||
"dropout": 0.1
|
||||
"dropout": 0.1,
|
||||
}
|
||||
|
||||
# === Конфигурации токенизатора ===
|
||||
BPE_CONFIG = {
|
||||
"vocab_size": 1000,
|
||||
"special_tokens": ["<pad>", "<unk>", "<bos>", "<eos>"]
|
||||
"special_tokens": ["<pad>", "<unk>", "<bos>", "<eos>"],
|
||||
}
|
||||
|
||||
# === Конфигурации обучения ===
|
||||
@@ -65,7 +65,7 @@ TRAINING_CONFIG = {
|
||||
"batch_size": 2,
|
||||
"num_epochs": 3,
|
||||
"warmup_steps": 50,
|
||||
"gradient_clip": 1.0
|
||||
"gradient_clip": 1.0,
|
||||
}
|
||||
|
||||
# === Конфигурации генерации ===
|
||||
@@ -74,7 +74,7 @@ GENERATION_CONFIG = {
|
||||
"temperature": 0.7,
|
||||
"do_sample": True,
|
||||
"top_k": None,
|
||||
"top_p": None
|
||||
"top_p": None,
|
||||
}
|
||||
|
||||
# === Пути для сохранения ===
|
||||
@@ -84,7 +84,7 @@ PATHS = {
|
||||
"gpt_bpe_config": "checkpoints/gpt-bpe/config.json",
|
||||
"hf_tokenizer": "checkpoints/hf-bpe-tokenizer",
|
||||
"hf_model": "checkpoints/hf-trained",
|
||||
"hf_proxy_model": "checkpoints/hf-trained-proxy"
|
||||
"hf_proxy_model": "checkpoints/hf-trained-proxy",
|
||||
}
|
||||
|
||||
# === Тестовые промпты ===
|
||||
|
||||
@@ -10,17 +10,17 @@ from .configs import TRAIN_TEXTS, PATHS
|
||||
def load_training_data(split_ratio: float = 0.8) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
Загружает данные для обучения и разделяет на train/validation.
|
||||
|
||||
|
||||
Args:
|
||||
split_ratio: Доля данных для обучения
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple: (train_texts, val_texts)
|
||||
"""
|
||||
train_size = int(len(TRAIN_TEXTS) * split_ratio)
|
||||
train_data = TRAIN_TEXTS[:train_size]
|
||||
val_data = TRAIN_TEXTS[train_size:]
|
||||
|
||||
|
||||
return train_data, val_data
|
||||
|
||||
|
||||
@@ -28,13 +28,13 @@ def ensure_directories():
|
||||
"""Создает необходимые директории если они не существуют."""
|
||||
directories = [
|
||||
"checkpoints",
|
||||
"checkpoints/gpt-bpe",
|
||||
"checkpoints/gpt-bpe",
|
||||
"checkpoints/hf-bpe-tokenizer",
|
||||
"checkpoints/hf-trained",
|
||||
"checkpoints/hf-trained-proxy",
|
||||
"logs"
|
||||
"logs",
|
||||
]
|
||||
|
||||
|
||||
for directory in directories:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
@@ -42,33 +42,34 @@ def ensure_directories():
|
||||
def get_model_paths(experiment_type: str = "llm_only") -> dict:
|
||||
"""
|
||||
Возвращает пути для конкретного типа эксперимента.
|
||||
|
||||
|
||||
Args:
|
||||
experiment_type: Тип эксперимента ('llm_only' или 'hf_integration')
|
||||
|
||||
|
||||
Returns:
|
||||
dict: Словарь с путями
|
||||
"""
|
||||
base_paths = PATHS.copy()
|
||||
|
||||
|
||||
if experiment_type == "hf_integration":
|
||||
base_paths.update({
|
||||
"model": base_paths["hf_model"],
|
||||
"tokenizer": base_paths["hf_tokenizer"]
|
||||
})
|
||||
base_paths.update(
|
||||
{"model": base_paths["hf_model"], "tokenizer": base_paths["hf_tokenizer"]}
|
||||
)
|
||||
else: # llm_only
|
||||
base_paths.update({
|
||||
"model": base_paths["gpt_bpe_model"],
|
||||
"tokenizer": base_paths["bpe_tokenizer"]
|
||||
})
|
||||
|
||||
base_paths.update(
|
||||
{
|
||||
"model": base_paths["gpt_bpe_model"],
|
||||
"tokenizer": base_paths["bpe_tokenizer"],
|
||||
}
|
||||
)
|
||||
|
||||
return base_paths
|
||||
|
||||
|
||||
def print_experiment_info(experiment_name: str, config: dict):
|
||||
"""
|
||||
Выводит информацию о запускаемом эксперименте.
|
||||
|
||||
|
||||
Args:
|
||||
experiment_name: Название эксперимента
|
||||
config: Конфигурация эксперимента
|
||||
@@ -85,35 +86,35 @@ def print_experiment_info(experiment_name: str, config: dict):
|
||||
def save_experiment_results(results: dict, filepath: str):
|
||||
"""
|
||||
Сохраняет результаты эксперимента в файл.
|
||||
|
||||
|
||||
Args:
|
||||
results: Словарь с результатами
|
||||
filepath: Путь для сохранения
|
||||
"""
|
||||
import json
|
||||
|
||||
with open(filepath, 'w', encoding='utf-8') as f:
|
||||
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(results, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
print(f"✅ Результаты эксперимента сохранены: {filepath}")
|
||||
|
||||
|
||||
def load_experiment_results(filepath: str) -> dict:
|
||||
"""
|
||||
Загружает результаты эксперимента из файла.
|
||||
|
||||
|
||||
Args:
|
||||
filepath: Путь к файлу с результатами
|
||||
|
||||
|
||||
Returns:
|
||||
dict: Загруженные результаты
|
||||
"""
|
||||
import json
|
||||
|
||||
|
||||
if not os.path.exists(filepath):
|
||||
return {}
|
||||
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
@@ -121,42 +122,39 @@ class ExperimentLogger:
|
||||
"""
|
||||
Логгер для экспериментов.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, experiment_name: str):
|
||||
self.experiment_name = experiment_name
|
||||
self.metrics = {}
|
||||
|
||||
|
||||
def log_metric(self, name: str, value: float):
|
||||
"""Логирует метрику."""
|
||||
if name not in self.metrics:
|
||||
self.metrics[name] = []
|
||||
self.metrics[name].append(value)
|
||||
print(f"📈 {name}: {value:.4f}")
|
||||
|
||||
|
||||
def log_step(self, step: int, loss: float, **kwargs):
|
||||
"""Логирует шаг обучения."""
|
||||
print(f"📊 Step {step}: loss={loss:.4f}", end="")
|
||||
for key, value in kwargs.items():
|
||||
print(f", {key}={value:.4f}", end="")
|
||||
print()
|
||||
|
||||
|
||||
def log_epoch(self, epoch: int, train_loss: float, val_loss: float = None):
|
||||
"""Логирует завершение эпохи."""
|
||||
print(f"🎯 Epoch {epoch}: train_loss={train_loss:.4f}", end="")
|
||||
if val_loss is not None:
|
||||
print(f", val_loss={val_loss:.4f}", end="")
|
||||
print()
|
||||
|
||||
|
||||
def save_logs(self, filepath: str):
|
||||
"""Сохраняет логи эксперимента."""
|
||||
import json
|
||||
|
||||
logs = {
|
||||
"experiment_name": self.experiment_name,
|
||||
"metrics": self.metrics
|
||||
}
|
||||
|
||||
with open(filepath, 'w', encoding='utf-8') as f:
|
||||
|
||||
logs = {"experiment_name": self.experiment_name, "metrics": self.metrics}
|
||||
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(logs, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
print(f"✅ Логи эксперимента сохранены: {filepath}")
|
||||
|
||||
@@ -27,16 +27,13 @@ __all__ = [
|
||||
# Основные классы адаптера
|
||||
"HFAdapter",
|
||||
"HFGPTAdapter",
|
||||
|
||||
# Конфигурации
|
||||
"HFAdapterConfig",
|
||||
"HFAdapterConfig",
|
||||
"HFPretrainedConfig",
|
||||
|
||||
# Адаптеры токенизаторов
|
||||
"HFTokenizerAdapter",
|
||||
"create_hf_tokenizer",
|
||||
"create_hf_tokenizer",
|
||||
"convert_to_hf_format",
|
||||
|
||||
# Утилиты
|
||||
"HFUtils",
|
||||
"TokenizerWrapper",
|
||||
|
||||
@@ -6,12 +6,12 @@ import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional, Tuple, Union, List
|
||||
from transformers import (
|
||||
PreTrainedModel,
|
||||
PreTrainedModel,
|
||||
GPT2LMHeadModel,
|
||||
GPT2Config,
|
||||
GenerationConfig,
|
||||
LogitsProcessorList,
|
||||
StoppingCriteriaList
|
||||
StoppingCriteriaList,
|
||||
)
|
||||
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
||||
|
||||
@@ -24,38 +24,39 @@ class HFGPTAdapter(PreTrainedModel):
|
||||
Адаптер для модели GPT из библиотеки llm.
|
||||
Позволяет использовать кастомные GPT модели с HuggingFace Transformers.
|
||||
"""
|
||||
|
||||
config_class = HFPretrainedConfig
|
||||
|
||||
|
||||
def __init__(self, config: HFPretrainedConfig, llm_model: Optional[GPT] = None):
|
||||
"""
|
||||
Инициализация адаптера.
|
||||
|
||||
|
||||
Args:
|
||||
config: Конфигурация HuggingFace
|
||||
llm_model: Опционально, предварительно созданная модель llm
|
||||
"""
|
||||
super().__init__(config)
|
||||
|
||||
|
||||
# Преобразуем HF конфигурацию в формат llm
|
||||
llm_config = self._hf_to_llm_config(config)
|
||||
|
||||
|
||||
# Создаем или используем переданную модель
|
||||
if llm_model is None:
|
||||
self.llm_model = GPT(llm_config)
|
||||
else:
|
||||
self.llm_model = llm_model
|
||||
|
||||
|
||||
# Устанавливаем веса если они есть в конфигурации
|
||||
if hasattr(config, 'state_dict') and config.state_dict is not None:
|
||||
if hasattr(config, "state_dict") and config.state_dict is not None:
|
||||
self.llm_model.load_state_dict(config.state_dict)
|
||||
|
||||
|
||||
def _hf_to_llm_config(self, hf_config: HFPretrainedConfig) -> dict:
|
||||
"""
|
||||
Преобразует конфигурацию HF в формат llm.
|
||||
|
||||
|
||||
Args:
|
||||
hf_config: Конфигурация HuggingFace
|
||||
|
||||
|
||||
Returns:
|
||||
dict: Конфигурация для llm модели
|
||||
"""
|
||||
@@ -67,7 +68,7 @@ class HFGPTAdapter(PreTrainedModel):
|
||||
"max_position_embeddings": hf_config.max_position_embeddings,
|
||||
"dropout": hf_config.hidden_dropout_prob,
|
||||
}
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
@@ -78,11 +79,11 @@ class HFGPTAdapter(PreTrainedModel):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||
"""
|
||||
Прямой проход модели.
|
||||
|
||||
|
||||
Args:
|
||||
input_ids: Входные токены [batch_size, seq_len]
|
||||
attention_mask: Маска внимания [batch_size, seq_len]
|
||||
@@ -92,38 +93,39 @@ class HFGPTAdapter(PreTrainedModel):
|
||||
output_attentions: Возвращать веса внимания
|
||||
output_hidden_states: Возвращать скрытые состояния
|
||||
return_dict: Возвращать словарь вместо кортежа
|
||||
|
||||
|
||||
Returns:
|
||||
CausalLMOutputWithCrossAttentions или кортеж
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# Основной forward pass
|
||||
outputs = self.llm_model(input_ids)
|
||||
if isinstance(outputs, tuple):
|
||||
logits = outputs[0]
|
||||
else:
|
||||
logits = outputs
|
||||
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Сдвигаем логиты и метки для языкового моделирования
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
|
||||
|
||||
# Вычисляем cross-entropy loss
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
loss = loss_fct(
|
||||
shift_logits.view(-1, shift_logits.size(-1)),
|
||||
shift_labels.view(-1)
|
||||
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
|
||||
)
|
||||
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,)
|
||||
if loss is not None:
|
||||
output = (loss,) + output
|
||||
return output
|
||||
|
||||
|
||||
return CausalLMOutputWithCrossAttentions(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
@@ -132,30 +134,27 @@ class HFGPTAdapter(PreTrainedModel):
|
||||
attentions=None,
|
||||
cross_attentions=None,
|
||||
)
|
||||
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
past_key_values: Optional[Tuple] = None,
|
||||
**kwargs
|
||||
self, input_ids: torch.Tensor, past_key_values: Optional[Tuple] = None, **kwargs
|
||||
) -> dict:
|
||||
"""
|
||||
Подготавливает входные данные для генерации.
|
||||
|
||||
|
||||
Args:
|
||||
input_ids: Входные токены
|
||||
past_key_values: Кешированные ключи и значения
|
||||
|
||||
|
||||
Returns:
|
||||
dict: Подготовленные входные данные
|
||||
"""
|
||||
# Наша простая реализация пока не поддерживает past_key_values
|
||||
return {"input_ids": input_ids}
|
||||
|
||||
|
||||
def can_generate(self) -> bool:
|
||||
"""Проверяет, может ли модель генерировать текст."""
|
||||
return True
|
||||
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
@@ -163,32 +162,32 @@ class HFGPTAdapter(PreTrainedModel):
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Генерация текста с поддержкой HuggingFace интерфейса.
|
||||
|
||||
|
||||
Args:
|
||||
input_ids: Входные токены
|
||||
attention_mask: Маска внимания
|
||||
generation_config: Конфигурация генерации
|
||||
logits_processor: Процессоры логитов
|
||||
stopping_criteria: Критерии остановки
|
||||
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Сгенерированные токены
|
||||
"""
|
||||
# Извлекаем обязательные параметры из kwargs или используем значения по умолчанию
|
||||
max_new_tokens = kwargs.pop('max_new_tokens', 50)
|
||||
do_sample = kwargs.pop('do_sample', True)
|
||||
|
||||
max_new_tokens = kwargs.pop("max_new_tokens", 50)
|
||||
do_sample = kwargs.pop("do_sample", True)
|
||||
|
||||
# Используем встроенную генерацию llm модели
|
||||
return self.llm_model.generate(
|
||||
x=input_ids,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=do_sample,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -196,64 +195,66 @@ class HFAdapter:
|
||||
"""
|
||||
Основной класс адаптера для преобразования моделей llm в формат HuggingFace.
|
||||
"""
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_llm_model(
|
||||
llm_model: GPT,
|
||||
hf_config: Optional[HFAdapterConfig] = None
|
||||
llm_model: GPT, hf_config: Optional[HFAdapterConfig] = None
|
||||
) -> HFGPTAdapter:
|
||||
"""
|
||||
Создает адаптер из существующей llm модели.
|
||||
|
||||
|
||||
Args:
|
||||
llm_model: Обученная модель из библиотеки llm
|
||||
hf_config: Конфигурация для HuggingFace
|
||||
|
||||
|
||||
Returns:
|
||||
HFGPTAdapter: Адаптированная модель
|
||||
"""
|
||||
if hf_config is None:
|
||||
# Создаем конфигурацию из модели llm
|
||||
hf_config = HFAdapterConfig.from_llm_config(llm_model.config)
|
||||
|
||||
|
||||
# Преобразуем в PretrainedConfig
|
||||
pretrained_config = HFPretrainedConfig(**hf_config.to_dict())
|
||||
|
||||
|
||||
return HFGPTAdapter(pretrained_config, llm_model)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
model_path: str,
|
||||
hf_config: Optional[HFAdapterConfig] = None
|
||||
model_path: str, hf_config: Optional[HFAdapterConfig] = None
|
||||
) -> HFGPTAdapter:
|
||||
"""
|
||||
Загружает модель из чекпоинта и создает адаптер.
|
||||
|
||||
|
||||
Args:
|
||||
model_path: Путь к сохраненной модели
|
||||
hf_config: Конфигурация для HuggingFace
|
||||
|
||||
|
||||
Returns:
|
||||
HFGPTAdapter: Адаптированная модель
|
||||
"""
|
||||
# Загружаем состояние модели
|
||||
state_dict = torch.load(model_path, map_location='cpu')
|
||||
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
|
||||
# Определяем конфигурацию из состояния модели или используем переданную
|
||||
if hf_config is None:
|
||||
# Пытаемся определить конфигурацию из состояния модели
|
||||
# Это упрощенный подход - в реальности нужно сохранять конфигурацию отдельно
|
||||
vocab_size = state_dict.get('_token_embeddings._embedding.weight', torch.zeros(50257, 768)).shape[0]
|
||||
embed_dim = state_dict.get('_token_embeddings._embedding.weight', torch.zeros(50257, 768)).shape[1]
|
||||
|
||||
vocab_size = state_dict.get(
|
||||
"_token_embeddings._embedding.weight", torch.zeros(50257, 768)
|
||||
).shape[0]
|
||||
embed_dim = state_dict.get(
|
||||
"_token_embeddings._embedding.weight", torch.zeros(50257, 768)
|
||||
).shape[1]
|
||||
|
||||
hf_config = HFAdapterConfig(
|
||||
vocab_size=vocab_size,
|
||||
hidden_size=embed_dim,
|
||||
# Остальные параметры можно установить по умолчанию
|
||||
)
|
||||
|
||||
|
||||
pretrained_config = HFPretrainedConfig(**hf_config.to_dict())
|
||||
|
||||
|
||||
# Создаем модель llm и загружаем веса
|
||||
llm_config = {
|
||||
"vocab_size": hf_config.vocab_size,
|
||||
@@ -263,21 +264,17 @@ class HFAdapter:
|
||||
"max_position_embeddings": hf_config.max_position_embeddings,
|
||||
"dropout": hf_config.hidden_dropout_prob,
|
||||
}
|
||||
|
||||
|
||||
llm_model = GPT(llm_config)
|
||||
llm_model.load_state_dict(state_dict)
|
||||
|
||||
|
||||
return HFGPTAdapter(pretrained_config, llm_model)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def save_pretrained(
|
||||
model: HFGPTAdapter,
|
||||
save_directory: str,
|
||||
**kwargs
|
||||
):
|
||||
def save_pretrained(model: HFGPTAdapter, save_directory: str, **kwargs):
|
||||
"""
|
||||
Сохраняет адаптированную модель в формате HuggingFace.
|
||||
|
||||
|
||||
Args:
|
||||
model: Адаптированная модель
|
||||
save_directory: Директория для сохранения
|
||||
@@ -285,19 +282,19 @@ class HFAdapter:
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
|
||||
|
||||
# Создаем директорию если не существует
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
|
||||
# Сохраняем конфигурацию
|
||||
config_path = os.path.join(save_directory, "config.json")
|
||||
with open(config_path, 'w', encoding='utf-8') as f:
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(model.config.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
# Сохраняем веса модели
|
||||
model_path = os.path.join(save_directory, "pytorch_model.bin")
|
||||
torch.save(model.llm_model.state_dict(), model_path)
|
||||
|
||||
|
||||
# Сохраняем токенизатор если передан
|
||||
if hasattr(kwargs, 'tokenizer') and kwargs['tokenizer'] is not None:
|
||||
kwargs['tokenizer'].save_pretrained(save_directory)
|
||||
if hasattr(kwargs, "tokenizer") and kwargs["tokenizer"] is not None:
|
||||
kwargs["tokenizer"].save_pretrained(save_directory)
|
||||
|
||||
@@ -6,11 +6,12 @@ from dataclasses import dataclass, field
|
||||
from typing import Dict, Any, Optional
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class HFAdapterConfig:
|
||||
"""
|
||||
Конфигурация для адаптера HuggingFace.
|
||||
|
||||
|
||||
Параметры:
|
||||
model_type: Тип модели (gpt, llama, etc.)
|
||||
vocab_size: Размер словаря
|
||||
@@ -28,6 +29,7 @@ class HFAdapterConfig:
|
||||
eos_token_id: ID токена конца строки
|
||||
bos_token_id: ID токена начала строки
|
||||
"""
|
||||
|
||||
model_type: str = "gpt"
|
||||
vocab_size: int = 50257
|
||||
hidden_size: int = 768
|
||||
@@ -43,49 +45,50 @@ class HFAdapterConfig:
|
||||
pad_token_id: int = 50256
|
||||
eos_token_id: int = 50256
|
||||
bos_token_id: int = 50256
|
||||
|
||||
|
||||
# Дополнительные параметры для совместимости
|
||||
architectures: list = field(default_factory=lambda: ["GPT2LMHeadModel"])
|
||||
torch_dtype: str = "float32"
|
||||
transformers_version: str = "4.44.0"
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Преобразует конфигурацию в словарь."""
|
||||
return {
|
||||
k: v for k, v in self.__dict__.items()
|
||||
if not k.startswith('_') and not callable(v)
|
||||
k: v
|
||||
for k, v in self.__dict__.items()
|
||||
if not k.startswith("_") and not callable(v)
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_llm_config(cls, llm_config: Dict[str, Any]) -> "HFAdapterConfig":
|
||||
"""
|
||||
Создает конфигурацию HF из конфигурации llm.
|
||||
|
||||
|
||||
Args:
|
||||
llm_config: Конфигурация модели из библиотеки llm
|
||||
|
||||
|
||||
Returns:
|
||||
HFAdapterConfig: Конфигурация для HuggingFace
|
||||
"""
|
||||
# Маппинг параметров из llm в HF формат
|
||||
mapping = {
|
||||
"embed_dim": "hidden_size",
|
||||
"num_layers": "num_hidden_layers",
|
||||
"num_layers": "num_hidden_layers",
|
||||
"num_heads": "num_attention_heads",
|
||||
"max_position_embeddings": "max_position_embeddings",
|
||||
"dropout": "hidden_dropout_prob",
|
||||
"vocab_size": "vocab_size"
|
||||
"vocab_size": "vocab_size",
|
||||
}
|
||||
|
||||
|
||||
hf_config_dict = {}
|
||||
for llm_key, hf_key in mapping.items():
|
||||
if llm_key in llm_config:
|
||||
hf_config_dict[hf_key] = llm_config[llm_key]
|
||||
|
||||
|
||||
# Устанавливаем промежуточный размер (обычно 4x hidden_size)
|
||||
if "hidden_size" in hf_config_dict:
|
||||
hf_config_dict["intermediate_size"] = hf_config_dict["hidden_size"] * 4
|
||||
|
||||
|
||||
return cls(**hf_config_dict)
|
||||
|
||||
|
||||
@@ -94,8 +97,9 @@ class HFPretrainedConfig(PretrainedConfig):
|
||||
Конфигурация для предобученных моделей HuggingFace.
|
||||
Наследуется от PretrainedConfig для полной совместимости.
|
||||
"""
|
||||
|
||||
model_type = "gpt"
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=50257,
|
||||
@@ -112,15 +116,15 @@ class HFPretrainedConfig(PretrainedConfig):
|
||||
pad_token_id=50256,
|
||||
eos_token_id=50256,
|
||||
bos_token_id=50256,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
|
||||
@@ -12,84 +12,82 @@ class HFTokenizerAdapter:
|
||||
Упрощенный адаптер для кастомных токенизаторов llm.
|
||||
Предоставляет совместимый с HuggingFace интерфейс.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, llm_tokenizer: BaseTokenizer):
|
||||
"""
|
||||
Инициализация адаптера.
|
||||
|
||||
|
||||
Args:
|
||||
llm_tokenizer: Кастомный токенизатор из llm
|
||||
"""
|
||||
self.llm_tokenizer = llm_tokenizer
|
||||
|
||||
|
||||
# Получаем словарь и размер
|
||||
self._vocab = llm_tokenizer.get_vocab()
|
||||
self.vocab_size = llm_tokenizer.get_vocab_size()
|
||||
|
||||
|
||||
# Устанавливаем специальные токены
|
||||
self.pad_token = getattr(llm_tokenizer, 'pad_token', '<pad>')
|
||||
self.unk_token = getattr(llm_tokenizer, 'unk_token', '<unk>')
|
||||
self.bos_token = getattr(llm_tokenizer, 'bos_token', '<bos>')
|
||||
self.eos_token = getattr(llm_tokenizer, 'eos_token', '<eos>')
|
||||
|
||||
self.pad_token = getattr(llm_tokenizer, "pad_token", "<pad>")
|
||||
self.unk_token = getattr(llm_tokenizer, "unk_token", "<unk>")
|
||||
self.bos_token = getattr(llm_tokenizer, "bos_token", "<bos>")
|
||||
self.eos_token = getattr(llm_tokenizer, "eos_token", "<eos>")
|
||||
|
||||
# Сохраняем ID специальных токенов
|
||||
self.pad_token_id = getattr(llm_tokenizer, 'pad_token_id', 0)
|
||||
self.unk_token_id = getattr(llm_tokenizer, 'unk_token_id', 1)
|
||||
self.bos_token_id = getattr(llm_tokenizer, 'bos_token_id', 2)
|
||||
self.eos_token_id = getattr(llm_tokenizer, 'eos_token_id', 3)
|
||||
|
||||
self.pad_token_id = getattr(llm_tokenizer, "pad_token_id", 0)
|
||||
self.unk_token_id = getattr(llm_tokenizer, "unk_token_id", 1)
|
||||
self.bos_token_id = getattr(llm_tokenizer, "bos_token_id", 2)
|
||||
self.eos_token_id = getattr(llm_tokenizer, "eos_token_id", 3)
|
||||
|
||||
def __call__(self, text: str, **kwargs):
|
||||
"""
|
||||
Вызов токенизатора с параметрами как у HuggingFace.
|
||||
|
||||
|
||||
Args:
|
||||
text: Входной текст
|
||||
**kwargs: Параметры токенизации
|
||||
|
||||
|
||||
Returns:
|
||||
dict: Словарь с токенами
|
||||
"""
|
||||
return_tensors = kwargs.get('return_tensors', None)
|
||||
padding = kwargs.get('padding', False)
|
||||
truncation = kwargs.get('truncation', False)
|
||||
max_length = kwargs.get('max_length', None)
|
||||
add_special_tokens = kwargs.get('add_special_tokens', True)
|
||||
|
||||
return_tensors = kwargs.get("return_tensors", None)
|
||||
padding = kwargs.get("padding", False)
|
||||
truncation = kwargs.get("truncation", False)
|
||||
max_length = kwargs.get("max_length", None)
|
||||
add_special_tokens = kwargs.get("add_special_tokens", True)
|
||||
|
||||
# Кодируем текст
|
||||
#input_ids = self.llm_tokenizer.encode(
|
||||
# text,
|
||||
# input_ids = self.llm_tokenizer.encode(
|
||||
# text,
|
||||
# add_special_tokens=add_special_tokens
|
||||
#)
|
||||
# )
|
||||
if isinstance(text, str):
|
||||
input_ids = self.llm_tokenizer.encode(
|
||||
text,
|
||||
add_special_tokens=add_special_tokens
|
||||
text, add_special_tokens=add_special_tokens
|
||||
)
|
||||
input_ids = [input_ids] # <-- оборачиваем в batch
|
||||
else:
|
||||
# Список строк, батч-режим!
|
||||
input_ids = [
|
||||
self.llm_tokenizer.encode(
|
||||
t,
|
||||
add_special_tokens=add_special_tokens
|
||||
) for t in text
|
||||
self.llm_tokenizer.encode(t, add_special_tokens=add_special_tokens)
|
||||
for t in text
|
||||
]
|
||||
|
||||
|
||||
# Применяем truncation
|
||||
if truncation and max_length is not None and len(input_ids) > max_length:
|
||||
input_ids = input_ids[:max_length]
|
||||
|
||||
|
||||
# Применяем padding
|
||||
if padding and max_length is not None and len(input_ids) < max_length:
|
||||
input_ids = input_ids + [self.pad_token_id] * (max_length - len(input_ids))
|
||||
|
||||
|
||||
# Конвертируем в тензоры если нужно
|
||||
if return_tensors == "pt":
|
||||
import torch
|
||||
|
||||
input_ids = torch.tensor([input_ids])
|
||||
|
||||
|
||||
return {"input_ids": input_ids}
|
||||
|
||||
|
||||
def encode(
|
||||
self,
|
||||
text: str,
|
||||
@@ -99,11 +97,11 @@ class HFTokenizerAdapter:
|
||||
truncation: bool = False,
|
||||
max_length: Optional[int] = None,
|
||||
return_tensors: Optional[str] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> Union[List[int], List[List[int]]]:
|
||||
"""
|
||||
Кодирует текст в последовательность токенов.
|
||||
|
||||
|
||||
Args:
|
||||
text: Входной текст
|
||||
text_pair: Второй текст (для парных задач)
|
||||
@@ -112,84 +110,91 @@ class HFTokenizerAdapter:
|
||||
truncation: Обрезать последовательность
|
||||
max_length: Максимальная длина
|
||||
return_tensors: Возвращать тензоры
|
||||
|
||||
|
||||
Returns:
|
||||
Список токенов или список списков токенов
|
||||
"""
|
||||
# Кодируем основной текст
|
||||
token_ids = self.llm_tokenizer.encode(
|
||||
text,
|
||||
add_special_tokens=add_special_tokens
|
||||
text, add_special_tokens=add_special_tokens
|
||||
)
|
||||
|
||||
|
||||
# Обрабатываем text_pair если есть
|
||||
if text_pair is not None:
|
||||
pair_ids = self.llm_tokenizer.encode(
|
||||
text_pair,
|
||||
add_special_tokens=False
|
||||
)
|
||||
pair_ids = self.llm_tokenizer.encode(text_pair, add_special_tokens=False)
|
||||
token_ids.extend(pair_ids)
|
||||
|
||||
|
||||
# Применяем truncation
|
||||
if truncation and max_length is not None and len(token_ids) > max_length:
|
||||
token_ids = token_ids[:max_length]
|
||||
|
||||
|
||||
# Применяем padding
|
||||
if padding and max_length is not None and len(token_ids) < max_length:
|
||||
token_ids = token_ids + [self.pad_token_id] * (max_length - len(token_ids))
|
||||
|
||||
|
||||
# Конвертируем в тензоры если нужно
|
||||
if return_tensors == "pt":
|
||||
import torch
|
||||
|
||||
return torch.tensor([token_ids])
|
||||
elif return_tensors == "np":
|
||||
import numpy as np
|
||||
|
||||
return np.array([token_ids])
|
||||
|
||||
|
||||
return token_ids
|
||||
|
||||
|
||||
def decode(
|
||||
self,
|
||||
token_ids: Union[int, List[int], List[List[int]]],
|
||||
skip_special_tokens: bool = True,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
Декодирует последовательность токенов в текст.
|
||||
|
||||
|
||||
Args:
|
||||
token_ids: ID токенов
|
||||
skip_special_tokens: Пропускать специальные токены
|
||||
|
||||
|
||||
Returns:
|
||||
str: Декодированный текст
|
||||
"""
|
||||
# Обрабатываем разные форматы входных данных
|
||||
if isinstance(token_ids, int):
|
||||
token_ids = [token_ids]
|
||||
elif isinstance(token_ids, list) and len(token_ids) > 0 and isinstance(token_ids[0], list):
|
||||
elif (
|
||||
isinstance(token_ids, list)
|
||||
and len(token_ids) > 0
|
||||
and isinstance(token_ids[0], list)
|
||||
):
|
||||
# Список списков - берем первый элемент
|
||||
token_ids = token_ids[0]
|
||||
|
||||
|
||||
# Фильтруем специальные токены если нужно
|
||||
if skip_special_tokens:
|
||||
special_ids = {self.pad_token_id, self.unk_token_id, self.bos_token_id, self.eos_token_id}
|
||||
special_ids = {
|
||||
self.pad_token_id,
|
||||
self.unk_token_id,
|
||||
self.bos_token_id,
|
||||
self.eos_token_id,
|
||||
}
|
||||
token_ids = [tid for tid in token_ids if tid not in special_ids]
|
||||
|
||||
|
||||
return self.llm_tokenizer.decode(token_ids)
|
||||
|
||||
|
||||
def tokenize(self, text: str, **kwargs) -> List[str]:
|
||||
"""
|
||||
Токенизирует текст в список строковых токенов.
|
||||
|
||||
|
||||
Args:
|
||||
text: Входной текст
|
||||
|
||||
|
||||
Returns:
|
||||
List[str]: Список токенов
|
||||
"""
|
||||
return self.llm_tokenizer.tokenize(text)
|
||||
|
||||
|
||||
def pad(
|
||||
self,
|
||||
encoded_inputs,
|
||||
@@ -202,7 +207,7 @@ class HFTokenizerAdapter:
|
||||
):
|
||||
"""
|
||||
Pad a list of encoded inputs.
|
||||
|
||||
|
||||
Args:
|
||||
encoded_inputs: List of encoded inputs
|
||||
padding: Padding strategy
|
||||
@@ -211,7 +216,7 @@ class HFTokenizerAdapter:
|
||||
return_attention_mask: Return attention mask
|
||||
return_tensors: Return tensors
|
||||
verbose: Verbose mode
|
||||
|
||||
|
||||
Returns:
|
||||
Padded inputs
|
||||
"""
|
||||
@@ -224,47 +229,62 @@ class HFTokenizerAdapter:
|
||||
# Обрабатываем разные типы данных
|
||||
if isinstance(input_ids, int):
|
||||
seq_len = 1
|
||||
elif hasattr(input_ids, 'shape'):
|
||||
seq_len = input_ids.shape[-1] if len(input_ids.shape) > 1 else len(input_ids)
|
||||
elif hasattr(input_ids, "shape"):
|
||||
seq_len = (
|
||||
input_ids.shape[-1]
|
||||
if len(input_ids.shape) > 1
|
||||
else len(input_ids)
|
||||
)
|
||||
else:
|
||||
seq_len = len(input_ids)
|
||||
max_len = max(max_len, seq_len)
|
||||
|
||||
|
||||
if max_length is not None:
|
||||
max_len = min(max_len, max_length)
|
||||
|
||||
|
||||
# Применяем padding
|
||||
for item in encoded_inputs:
|
||||
input_ids = item["input_ids"]
|
||||
|
||||
|
||||
# Получаем текущую длину
|
||||
if isinstance(input_ids, int):
|
||||
current_len = 1
|
||||
elif hasattr(input_ids, 'shape'):
|
||||
current_len = input_ids.shape[-1] if len(input_ids.shape) > 1 else len(input_ids)
|
||||
elif hasattr(input_ids, "shape"):
|
||||
current_len = (
|
||||
input_ids.shape[-1]
|
||||
if len(input_ids.shape) > 1
|
||||
else len(input_ids)
|
||||
)
|
||||
else:
|
||||
current_len = len(input_ids)
|
||||
|
||||
|
||||
if current_len < max_len:
|
||||
# Дополняем pad_token_id
|
||||
padding_length = max_len - current_len
|
||||
|
||||
|
||||
# Обрабатываем разные типы данных
|
||||
if isinstance(input_ids, int):
|
||||
item["input_ids"] = [input_ids] + [self.pad_token_id] * padding_length
|
||||
elif hasattr(input_ids, 'shape'):
|
||||
item["input_ids"] = [input_ids] + [
|
||||
self.pad_token_id
|
||||
] * padding_length
|
||||
elif hasattr(input_ids, "shape"):
|
||||
import torch
|
||||
padding_tensor = torch.full((padding_length,), self.pad_token_id, dtype=input_ids.dtype)
|
||||
|
||||
padding_tensor = torch.full(
|
||||
(padding_length,), self.pad_token_id, dtype=input_ids.dtype
|
||||
)
|
||||
item["input_ids"] = torch.cat([input_ids, padding_tensor])
|
||||
else:
|
||||
item["input_ids"] = input_ids + [self.pad_token_id] * padding_length
|
||||
|
||||
item["input_ids"] = (
|
||||
input_ids + [self.pad_token_id] * padding_length
|
||||
)
|
||||
|
||||
# Добавляем attention_mask если требуется
|
||||
if "attention_mask" in item:
|
||||
mask = item["attention_mask"]
|
||||
if isinstance(mask, int):
|
||||
item["attention_mask"] = [mask] + [0] * padding_length
|
||||
elif hasattr(mask, 'shape'):
|
||||
elif hasattr(mask, "shape"):
|
||||
padding_mask = torch.zeros(padding_length, dtype=mask.dtype)
|
||||
item["attention_mask"] = torch.cat([mask, padding_mask])
|
||||
else:
|
||||
@@ -272,44 +292,49 @@ class HFTokenizerAdapter:
|
||||
elif return_attention_mask:
|
||||
if isinstance(input_ids, int):
|
||||
item["attention_mask"] = [1] + [0] * padding_length
|
||||
elif hasattr(input_ids, 'shape'):
|
||||
elif hasattr(input_ids, "shape"):
|
||||
attention_mask = torch.ones(current_len, dtype=torch.long)
|
||||
padding_mask = torch.zeros(padding_length, dtype=torch.long)
|
||||
item["attention_mask"] = torch.cat([attention_mask, padding_mask])
|
||||
item["attention_mask"] = torch.cat(
|
||||
[attention_mask, padding_mask]
|
||||
)
|
||||
else:
|
||||
item["attention_mask"] = [1] * current_len + [0] * padding_length
|
||||
|
||||
item["attention_mask"] = [1] * current_len + [
|
||||
0
|
||||
] * padding_length
|
||||
|
||||
# Конвертируем в тензоры если требуется
|
||||
if return_tensors == "pt":
|
||||
import torch
|
||||
|
||||
for key in list(encoded_inputs[0].keys()):
|
||||
if isinstance(encoded_inputs[0][key], list):
|
||||
for i in range(len(encoded_inputs)):
|
||||
encoded_inputs[i][key] = torch.tensor(encoded_inputs[i][key])
|
||||
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
|
||||
def get_vocab(self) -> Dict[str, int]:
|
||||
"""Возвращает словарь токенизатора."""
|
||||
return self._vocab
|
||||
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Возвращает размер словаря."""
|
||||
return self.vocab_size
|
||||
|
||||
|
||||
def save_pretrained(self, save_directory: str, **kwargs):
|
||||
"""
|
||||
Сохраняет токенизатор в формате HuggingFace.
|
||||
|
||||
|
||||
Args:
|
||||
save_directory: Директория для сохранения
|
||||
**kwargs: Дополнительные параметры
|
||||
"""
|
||||
import os
|
||||
|
||||
|
||||
# Создаем директорию если не существует
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
|
||||
# Сохраняем конфигурацию токенизатора
|
||||
tokenizer_config = {
|
||||
"tokenizer_class": self.__class__.__name__,
|
||||
@@ -324,77 +349,81 @@ class HFTokenizerAdapter:
|
||||
"bos_token_id": self.bos_token_id,
|
||||
"eos_token_id": self.eos_token_id,
|
||||
}
|
||||
|
||||
|
||||
config_path = os.path.join(save_directory, "tokenizer_config.json")
|
||||
with open(config_path, 'w', encoding='utf-8') as f:
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(tokenizer_config, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
# Сохраняем словарь
|
||||
vocab_path = os.path.join(save_directory, "vocab.json")
|
||||
with open(vocab_path, 'w', encoding='utf-8') as f:
|
||||
with open(vocab_path, "w", encoding="utf-8") as f:
|
||||
json.dump(self._vocab, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
print(f"✅ Токенизатор сохранен в {save_directory}")
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
|
||||
"""
|
||||
Загружает адаптированный токенизатор.
|
||||
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path: Путь к сохраненному токенизатору
|
||||
**kwargs: Дополнительные параметры
|
||||
|
||||
|
||||
Returns:
|
||||
HFTokenizerAdapter: Загруженный адаптер
|
||||
"""
|
||||
import os
|
||||
|
||||
|
||||
# Проверяем, является ли путь директорией с файлами токенизатора
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
# Загружаем из директории
|
||||
config_path = os.path.join(pretrained_model_name_or_path, "tokenizer_config.json")
|
||||
config_path = os.path.join(
|
||||
pretrained_model_name_or_path, "tokenizer_config.json"
|
||||
)
|
||||
vocab_path = os.path.join(pretrained_model_name_or_path, "vocab.json")
|
||||
|
||||
|
||||
if not os.path.exists(config_path) or not os.path.exists(vocab_path):
|
||||
raise FileNotFoundError(
|
||||
f"Файлы токенизатора не найдены в {pretrained_model_name_or_path}"
|
||||
)
|
||||
|
||||
|
||||
# Загружаем конфигурацию
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
|
||||
|
||||
# Определяем тип токенизатора llm
|
||||
llm_tokenizer_type = config.get("llm_tokenizer_type", "BPETokenizer")
|
||||
|
||||
|
||||
if llm_tokenizer_type == "BPETokenizer":
|
||||
# Создаем BPETokenizer и загружаем словарь
|
||||
llm_tokenizer = BPETokenizer()
|
||||
|
||||
|
||||
# Загружаем словарь
|
||||
with open(vocab_path, 'r', encoding='utf-8') as f:
|
||||
with open(vocab_path, "r", encoding="utf-8") as f:
|
||||
vocab = json.load(f)
|
||||
|
||||
|
||||
llm_tokenizer.vocab = vocab
|
||||
llm_tokenizer.inverse_vocab = {v: k for k, v in vocab.items()}
|
||||
llm_tokenizer.vocab_size = len(vocab)
|
||||
|
||||
|
||||
# Устанавливаем специальные токены
|
||||
llm_tokenizer.pad_token = config.get("pad_token", "<pad>")
|
||||
llm_tokenizer.unk_token = config.get("unk_token", "<unk>")
|
||||
llm_tokenizer.bos_token = config.get("bos_token", "<bos>")
|
||||
llm_tokenizer.eos_token = config.get("eos_token", "<eos>")
|
||||
|
||||
|
||||
llm_tokenizer.pad_token_id = config.get("pad_token_id", 0)
|
||||
llm_tokenizer.unk_token_id = config.get("unk_token_id", 1)
|
||||
llm_tokenizer.bos_token_id = config.get("bos_token_id", 2)
|
||||
llm_tokenizer.eos_token_id = config.get("eos_token_id", 3)
|
||||
|
||||
|
||||
return cls(llm_tokenizer, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Неподдерживаемый тип токенизатора: {llm_tokenizer_type}")
|
||||
|
||||
raise ValueError(
|
||||
f"Неподдерживаемый тип токенизатора: {llm_tokenizer_type}"
|
||||
)
|
||||
|
||||
else:
|
||||
# Пытаемся загрузить как файл llm токенизатора
|
||||
try:
|
||||
@@ -409,10 +438,10 @@ class HFTokenizerAdapter:
|
||||
def create_hf_tokenizer(llm_tokenizer: BaseTokenizer) -> HFTokenizerAdapter:
|
||||
"""
|
||||
Создает адаптер HuggingFace для кастомного токенизатора.
|
||||
|
||||
|
||||
Args:
|
||||
llm_tokenizer: Токенизатор из библиотеки llm
|
||||
|
||||
|
||||
Returns:
|
||||
HFTokenizerAdapter: Адаптированный токенизатор
|
||||
"""
|
||||
@@ -422,7 +451,7 @@ def create_hf_tokenizer(llm_tokenizer: BaseTokenizer) -> HFTokenizerAdapter:
|
||||
def convert_to_hf_format(llm_tokenizer: BaseTokenizer, save_directory: str):
|
||||
"""
|
||||
Конвертирует кастомный токенизатор в формат HuggingFace.
|
||||
|
||||
|
||||
Args:
|
||||
llm_tokenizer: Токенизатор из llm
|
||||
save_directory: Директория для сохранения
|
||||
|
||||
@@ -14,55 +14,57 @@ class HFUtils:
|
||||
"""
|
||||
Утилиты для работы с HuggingFace адаптером.
|
||||
"""
|
||||
|
||||
|
||||
@staticmethod
|
||||
def create_hf_config_from_llm(llm_config: Dict[str, Any]) -> HFPretrainedConfig:
|
||||
"""
|
||||
Создает конфигурацию HuggingFace из конфигурации llm.
|
||||
|
||||
|
||||
Args:
|
||||
llm_config: Конфигурация модели из библиотеки llm
|
||||
|
||||
|
||||
Returns:
|
||||
HFPretrainedConfig: Конфигурация для HuggingFace
|
||||
"""
|
||||
adapter_config = HFAdapterConfig.from_llm_config(llm_config)
|
||||
return HFPretrainedConfig(**adapter_config.to_dict())
|
||||
|
||||
|
||||
@staticmethod
|
||||
def convert_to_hf_format(
|
||||
llm_model,
|
||||
tokenizer = None,
|
||||
model_name: str = "custom-gpt"
|
||||
llm_model, tokenizer=None, model_name: str = "custom-gpt"
|
||||
) -> tuple:
|
||||
"""
|
||||
Конвертирует llm модель в формат HuggingFace.
|
||||
|
||||
|
||||
Args:
|
||||
llm_model: Модель из библиотеки llm
|
||||
tokenizer: Токенизатор (HF или кастомный)
|
||||
model_name: Имя модели для сохранения
|
||||
|
||||
|
||||
Returns:
|
||||
tuple: (адаптированная модель, токенизатор)
|
||||
"""
|
||||
# Создаем адаптер
|
||||
hf_model = HFAdapter.from_llm_model(llm_model)
|
||||
|
||||
|
||||
# Если токенизатор не передан, создаем стандартный
|
||||
if tokenizer is None:
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
# Устанавливаем специальные токены
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif hasattr(tokenizer, '__class__') and 'BPETokenizer' in str(tokenizer.__class__):
|
||||
elif hasattr(tokenizer, "__class__") and "BPETokenizer" in str(
|
||||
tokenizer.__class__
|
||||
):
|
||||
# Если передан наш кастомный токенизатор, создаем адаптер
|
||||
from .hf_tokenizer import create_hf_tokenizer
|
||||
|
||||
tokenizer = create_hf_tokenizer(tokenizer)
|
||||
|
||||
|
||||
return hf_model, tokenizer
|
||||
|
||||
|
||||
@staticmethod
|
||||
def push_to_hub(
|
||||
model: HFGPTAdapter,
|
||||
@@ -70,11 +72,11 @@ class HFUtils:
|
||||
repo_name: str,
|
||||
organization: Optional[str] = None,
|
||||
private: bool = False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Загружает модель в HuggingFace Hub.
|
||||
|
||||
|
||||
Args:
|
||||
model: Адаптированная модель
|
||||
tokenizer: Токенизатор
|
||||
@@ -85,23 +87,23 @@ class HFUtils:
|
||||
"""
|
||||
try:
|
||||
from huggingface_hub import HfApi, ModelCard, create_repo
|
||||
|
||||
|
||||
# Создаем репозиторий
|
||||
if organization:
|
||||
repo_id = f"{organization}/{repo_name}"
|
||||
else:
|
||||
repo_id = repo_name
|
||||
|
||||
|
||||
create_repo(repo_id, private=private, exist_ok=True)
|
||||
|
||||
|
||||
# Сохраняем модель локально
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Сохраняем модель
|
||||
HFAdapter.save_pretrained(model, tmp_dir, tokenizer=tokenizer)
|
||||
|
||||
|
||||
# Создаем Model Card
|
||||
card = ModelCard.from_template(
|
||||
model_name=repo_name,
|
||||
@@ -110,46 +112,43 @@ class HFUtils:
|
||||
tags=["llm", "gpt", "custom"],
|
||||
)
|
||||
card.save(os.path.join(tmp_dir, "README.md"))
|
||||
|
||||
|
||||
# Загружаем в Hub
|
||||
api = HfApi()
|
||||
api.upload_folder(
|
||||
folder_path=tmp_dir,
|
||||
repo_id=repo_id,
|
||||
commit_message="Initial commit with custom GPT model"
|
||||
commit_message="Initial commit with custom GPT model",
|
||||
)
|
||||
|
||||
|
||||
print(f"✅ Модель успешно загружена в HuggingFace Hub: {repo_id}")
|
||||
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Для загрузки в HuggingFace Hub установите huggingface_hub: "
|
||||
"pip install huggingface_hub"
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def load_from_hub(
|
||||
repo_id: str,
|
||||
**kwargs
|
||||
) -> tuple:
|
||||
def load_from_hub(repo_id: str, **kwargs) -> tuple:
|
||||
"""
|
||||
Загружает модель из HuggingFace Hub.
|
||||
|
||||
|
||||
Args:
|
||||
repo_id: ID репозитория
|
||||
**kwargs: Дополнительные параметры
|
||||
|
||||
|
||||
Returns:
|
||||
tuple: (модель, токенизатор)
|
||||
"""
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
# Загружаем токенизатор
|
||||
tokenizer = AutoTokenizer.from_pretrained(repo_id, **kwargs)
|
||||
|
||||
|
||||
# Загружаем конфигурацию
|
||||
config = AutoConfig.from_pretrained(repo_id, **kwargs)
|
||||
|
||||
|
||||
# Создаем модель llm на основе конфигурации
|
||||
llm_config = {
|
||||
"vocab_size": config.vocab_size,
|
||||
@@ -159,63 +158,56 @@ class HFUtils:
|
||||
"max_position_embeddings": config.max_position_embeddings,
|
||||
"dropout": config.hidden_dropout_prob,
|
||||
}
|
||||
|
||||
|
||||
# Загружаем модель через адаптер
|
||||
model = HFAdapter.from_pretrained(
|
||||
f"{repo_id}/pytorch_model.bin",
|
||||
HFAdapterConfig.from_llm_config(llm_config)
|
||||
f"{repo_id}/pytorch_model.bin", HFAdapterConfig.from_llm_config(llm_config)
|
||||
)
|
||||
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
@staticmethod
|
||||
def compare_with_hf_model(
|
||||
llm_model,
|
||||
hf_model_name: str = "gpt2",
|
||||
test_input: str = "Hello world"
|
||||
llm_model, hf_model_name: str = "gpt2", test_input: str = "Hello world"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Сравнивает llm модель с эталонной моделью из HuggingFace.
|
||||
|
||||
|
||||
Args:
|
||||
llm_model: Модель из библиотеки llm
|
||||
hf_model_name: Имя модели HuggingFace для сравнения
|
||||
test_input: Тестовый вход
|
||||
|
||||
|
||||
Returns:
|
||||
Dict: Результаты сравнения
|
||||
"""
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
# Загружаем эталонную модель
|
||||
hf_tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
|
||||
hf_model = AutoModelForCausalLM.from_pretrained(hf_model_name)
|
||||
|
||||
|
||||
# Подготавливаем входные данные
|
||||
inputs = hf_tokenizer(test_input, return_tensors="pt")
|
||||
|
||||
|
||||
# Получаем логиты от обеих моделей
|
||||
with torch.no_grad():
|
||||
hf_logits = hf_model(**inputs).logits
|
||||
llm_logits = llm_model(inputs['input_ids'])
|
||||
|
||||
llm_logits = llm_model(inputs["input_ids"])
|
||||
|
||||
# Сравниваем результаты
|
||||
hf_probs = torch.softmax(hf_logits[0, -1], dim=-1)
|
||||
llm_probs = torch.softmax(llm_logits[0, -1], dim=-1)
|
||||
|
||||
|
||||
# Вычисляем метрики
|
||||
kl_divergence = torch.nn.functional.kl_div(
|
||||
torch.log(llm_probs + 1e-8),
|
||||
hf_probs,
|
||||
reduction='batchmean'
|
||||
torch.log(llm_probs + 1e-8), hf_probs, reduction="batchmean"
|
||||
)
|
||||
|
||||
|
||||
cosine_similarity = torch.nn.functional.cosine_similarity(
|
||||
hf_logits.flatten(),
|
||||
llm_logits.flatten(),
|
||||
dim=0
|
||||
hf_logits.flatten(), llm_logits.flatten(), dim=0
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"kl_divergence": kl_divergence.item(),
|
||||
"cosine_similarity": cosine_similarity.item(),
|
||||
@@ -228,58 +220,52 @@ class TokenizerWrapper:
|
||||
"""
|
||||
Обертка для токенизатора с дополнительными утилитами.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, tokenizer):
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
|
||||
def encode_batch(self, texts: List[str], **kwargs) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Кодирует батч текстов.
|
||||
|
||||
|
||||
Args:
|
||||
texts: Список текстов
|
||||
**kwargs: Дополнительные параметры токенизации
|
||||
|
||||
|
||||
Returns:
|
||||
Dict: Токенизированные данные
|
||||
"""
|
||||
return self.tokenizer(
|
||||
texts,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
**kwargs
|
||||
texts, padding=True, truncation=True, return_tensors="pt", **kwargs
|
||||
)
|
||||
|
||||
|
||||
def decode_batch(self, token_ids: torch.Tensor, **kwargs) -> List[str]:
|
||||
"""
|
||||
Декодирует батч токенов.
|
||||
|
||||
|
||||
Args:
|
||||
token_ids: Тензор с токенами
|
||||
**kwargs: Дополнительные параметры декодирования
|
||||
|
||||
|
||||
Returns:
|
||||
List[str]: Декодированные тексты
|
||||
"""
|
||||
if token_ids.dim() == 1:
|
||||
token_ids = token_ids.unsqueeze(0)
|
||||
|
||||
|
||||
texts = []
|
||||
for i in range(token_ids.size(0)):
|
||||
text = self.tokenizer.decode(
|
||||
token_ids[i],
|
||||
skip_special_tokens=True,
|
||||
**kwargs
|
||||
token_ids[i], skip_special_tokens=True, **kwargs
|
||||
)
|
||||
texts.append(text)
|
||||
|
||||
|
||||
return texts
|
||||
|
||||
|
||||
def get_vocab_size(self) -> int:
|
||||
"""Возвращает размер словаря."""
|
||||
return len(self.tokenizer)
|
||||
|
||||
|
||||
def get_special_tokens(self) -> Dict[str, int]:
|
||||
"""Возвращает специальные токены."""
|
||||
return {
|
||||
@@ -290,36 +276,27 @@ class TokenizerWrapper:
|
||||
}
|
||||
|
||||
|
||||
def create_hf_pipeline(
|
||||
llm_model,
|
||||
tokenizer=None,
|
||||
device: str = "auto",
|
||||
**kwargs
|
||||
):
|
||||
def create_hf_pipeline(llm_model, tokenizer=None, device: str = "auto", **kwargs):
|
||||
"""
|
||||
Создает HuggingFace pipeline из llm модели.
|
||||
|
||||
|
||||
Args:
|
||||
llm_model: Модель из библиотеки llm
|
||||
tokenizer: Токенизатор
|
||||
device: Устройство для вычислений
|
||||
**kwargs: Дополнительные параметры pipeline
|
||||
|
||||
|
||||
Returns:
|
||||
transformers.Pipeline: Готовый pipeline
|
||||
"""
|
||||
from transformers import pipeline
|
||||
|
||||
|
||||
# Конвертируем модель в HF формат
|
||||
hf_model, tokenizer = HFUtils.convert_to_hf_format(llm_model, tokenizer)
|
||||
|
||||
|
||||
# Создаем pipeline
|
||||
pipe = pipeline(
|
||||
"text-generation",
|
||||
model=hf_model,
|
||||
tokenizer=tokenizer,
|
||||
device=device,
|
||||
**kwargs
|
||||
"text-generation", model=hf_model, tokenizer=tokenizer, device=device, **kwargs
|
||||
)
|
||||
|
||||
|
||||
return pipe
|
||||
|
||||
@@ -27,14 +27,19 @@ llm/
|
||||
│ │ ├── gpt.py # Базовая GPT
|
||||
│ │ ├── gpt2.py # GPT-2 реализация
|
||||
│ │ └── __init__.py
|
||||
│ └── llama/ # LLaMA архитектура
|
||||
│ ├── llama.py # LLaMA реализация
|
||||
│ ├── llama/ # LLaMA архитектура
|
||||
│ │ ├── llama.py # LLaMA реализация
|
||||
│ │ └── __init__.py
|
||||
│ └── mistral/ # Mistral архитектура
|
||||
│ ├── mistral.py # Mistral реализация
|
||||
│ └── __init__.py
|
||||
├── tokenizers/ # Токенизаторы
|
||||
│ ├── base_tokenizer.py # Базовый интерфейс
|
||||
│ └── bpe_tokenizer.py # BPE токенизатор
|
||||
├── datasets/ # Работа с датасетами
|
||||
│ ├── text_dataset.py # Стандартный датасет
|
||||
│ └── streaming_text_dataset.py # Стриминговый датасет
|
||||
└── training/ # Утилиты обучения
|
||||
├── dataset.py # Датасеты
|
||||
├── trainer.py # Тренировочный цикл
|
||||
├── optimizer.py # Оптимизаторы
|
||||
└── scheduler.py # Планировщики обучения
|
||||
@@ -175,13 +180,12 @@ generated = model.generate(input_ids, max_length=100)
|
||||
- ✅ Learned positional embeddings
|
||||
- ✅ Базовая архитектура трансформер-декодера
|
||||
|
||||
### GPT-2 Особенности
|
||||
- ✅ Улучшенная версия оригинальной GPT
|
||||
### GPT-2 Особенности
|
||||
- ✅ Layer Normalization (перед вниманием и FFN)
|
||||
- ✅ GELU активация
|
||||
- ✅ Learned positional embeddings
|
||||
- ✅ Кэширование для эффективной генерации
|
||||
- ✅ Оптимизированные веса инициализации
|
||||
- ✅ Кэширование KV для быстрой генерации
|
||||
- ✅ Улучшенная инициализация слоёв
|
||||
|
||||
### LLaMA Особенности
|
||||
- ✅ Rotary Positional Embeddings (RoPE)
|
||||
@@ -190,6 +194,21 @@ generated = model.generate(input_ids, max_length=100)
|
||||
- ✅ Оптимизированная структура декодера
|
||||
- ✅ Эффективное кэширование KV-памяти
|
||||
|
||||
### Mistral Особенности
|
||||
- ✅ Sliding Window Attention (оконное внимание)
|
||||
- ✅ Grouped Query Attention (GQA)
|
||||
- ✅ RoPE
|
||||
- ✅ RMSNorm
|
||||
- ✅ Разделённая архитектура на блоки с эффективным управлением памятью
|
||||
- ✅ Совместимость с HuggingFace через hf-proxy
|
||||
|
||||
## 🤝 Интеграция с HuggingFace и BPE
|
||||
|
||||
- Встроенная поддержка собственных BPE токенизаторов и экспериментальная поддержка токенизаторов через HuggingFace (см. hf-proxy).
|
||||
- hf-proxy — экспериментальный модуль! Совместимость с будущими версиями Transformers не гарантируется; API может меняться.
|
||||
- Допускается загрузка/конвертация моделей в формат HF для использования экосистемы Transformers.
|
||||
- Для запуска моделей с токенизаторами HF используйте `hf-proxy` и соответствующие эксперименты из `experiments/hf_integration/`.
|
||||
|
||||
## 🧪 Тестирование
|
||||
|
||||
Запуск всех тестов:
|
||||
@@ -198,7 +217,7 @@ cd llm
|
||||
python -m pytest tests/ -v
|
||||
```
|
||||
|
||||
**Статус тестов:** ✅ 101 тест пройден
|
||||
**Статус тестов:** ✅ 101+ тест, охвачены все основные компоненты (ядро, ядро-токенизация, архитектуры, обучение)
|
||||
|
||||
## 📚 Научные концепции
|
||||
|
||||
|
||||
@@ -19,23 +19,25 @@ from abc import ABC, abstractmethod
|
||||
from typing import Optional, Tuple
|
||||
import torch
|
||||
|
||||
|
||||
class BaseModel(nn.Module, ABC):
|
||||
"""
|
||||
Абстрактный класс — стандарт для всех архитектур LLM.
|
||||
|
||||
|
||||
Научная идея:
|
||||
Реализация унифицированного входа/выхода для поддержки построения и обучения любых современных языковых моделей.
|
||||
|
||||
|
||||
Args:
|
||||
config (dict): Параметры архитектуры (размерность эмбеддингов, число слоев, heads и т.д.)
|
||||
|
||||
|
||||
Attributes:
|
||||
config (dict): Конфиг модели
|
||||
"""
|
||||
|
||||
def __init__(self, config: dict):
|
||||
"""
|
||||
Инициализация модели.
|
||||
|
||||
|
||||
Args:
|
||||
config (dict): Настройки архитектуры модели (размеры слоев, типы блоков и т.д.)
|
||||
"""
|
||||
@@ -43,10 +45,12 @@ class BaseModel(nn.Module, ABC):
|
||||
self.config = config
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def forward(
|
||||
self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Прямой проход — получение логитов для входных токенов.
|
||||
|
||||
|
||||
Args:
|
||||
input_ids (Tensor[int]): Индексы токенов [batch, seq_len]
|
||||
attention_mask (Optional[Tensor[bool]]): Маска разрешенных позиций (если требуется) [batch, seq_len]
|
||||
@@ -59,7 +63,7 @@ class BaseModel(nn.Module, ABC):
|
||||
def generate(self, input_ids: torch.Tensor, max_length: int = 50) -> torch.Tensor:
|
||||
"""
|
||||
Генерация текста (авторегрессивно, greedy или sampling).
|
||||
|
||||
|
||||
Args:
|
||||
input_ids (Tensor[int]): Начальные токены [batch, start_len]
|
||||
max_length (int): Максимальная длина последовательности
|
||||
|
||||
@@ -6,36 +6,51 @@ from .feed_forward import FeedForward
|
||||
from .multi_head_attention import MultiHeadAttention
|
||||
from .rope import RoPE
|
||||
|
||||
|
||||
class CachedDecoder(nn.Module):
|
||||
"""
|
||||
Универсальный декодерный блок для современных LLM (GPT, LLaMA, др.), поддерживает кэширование key-value для эффективной генерации.
|
||||
CachedDecoder — Transformer-декодер с key/value-кэшированием (реализация накладывающегося masked multi-head attention).
|
||||
|
||||
Научная идея:
|
||||
Автопагрессивная авторегрессия в трансформерах требует быстрого доступа к ранее вычисленным self-attention ключам/значениям — этот класс позволяет прозрачно кэшировать такие состояния для быстрой инференс-генерации.
|
||||
Назначение:
|
||||
-----------
|
||||
Позволяет быстро и эффективно реализовывать autoregressive генерацию текста в стиле GPT-2/3/4:
|
||||
- На шаге генерации используются только нужные токены, “прошлые” key/value значения не пересчитываются, а подаются из кэша.
|
||||
- Позволяет значительно ускорять inferece (особенно на длинных последовательностях).
|
||||
- Вдохновлено реализациями в HuggingFace transformers, GPT-2/3 и других LLM.
|
||||
|
||||
Алгоритм:
|
||||
- Input -> LayerNorm -> Многоголовое внимание с кэшем (может быть RoPE)
|
||||
- Суммируем residual
|
||||
- LayerNorm -> FeedForward (любой, например SwiGLU) -> Residual
|
||||
- Возвращается кортеж (output, kvcache)
|
||||
Архитектурные особенности:
|
||||
--------------------------
|
||||
- Использует классическую multi-head attention (с causal mask — запрещает видеть “будущее”).
|
||||
- Предусматривает передачу и накопление KV-cache для каждого слоя (hidden state attention).
|
||||
- Поддерживает передачу внимания через стек attention-блоков.
|
||||
- Применяется layernorm и feed-forward block (GELU).
|
||||
|
||||
Параметры конструктора:
|
||||
-----------------------
|
||||
num_heads : int — число attention heads
|
||||
emb_size : int — embedding размерность
|
||||
head_size : int — размер каждой attention head (обычно emb_size // num_heads)
|
||||
feed_forward_layer : nn.Module — feedforward блок (mLP), может быть любым PyTorch-слоем
|
||||
max_seq_len : int — максимально допустимая длина последовательности
|
||||
dropout : float — dropout на attention/ffn
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> from llm.core.feed_forward import FeedForward
|
||||
>>> ff_block = FeedForward(emb_size=256, dropout=0.1, activation=\"gelu\")
|
||||
>>> decoder = CachedDecoder(num_heads=4, emb_size=256, head_size=64, feed_forward_layer=ff_block, max_seq_len=2048, dropout=0.1)
|
||||
>>> x = torch.randn(2, 100, 256)
|
||||
>>> y, kv_cache = decoder(x, use_cache=True, cache=None)
|
||||
>>> print(y.shape) # torch.Size([2, 100, 256])
|
||||
|
||||
Подробнее:
|
||||
----------
|
||||
- GPT-2: https://cdn.openai.com/better-language-models/language-models.pdf
|
||||
- HuggingFace cache mechanics: https://huggingface.co/docs/transformers/main/en/model_doc/gpt2
|
||||
- Объяснения autoregressive cache: https://jalammar.github.io/illustrated-gpt2/
|
||||
|
||||
Args:
|
||||
feed_forward_layer (nn.Module): FeedForward или SwiGLU слой
|
||||
num_heads (int): Количество голов внимания
|
||||
emb_size (int): Размерность эмбеддингов
|
||||
head_size (int): Размерность головы внимания
|
||||
max_seq_len (int): Максимальная длина
|
||||
norm_layer (тип nn.Module): Normalization слой (LayerNorm или RMSNorm)
|
||||
dropout (float): Dropout
|
||||
rope (RoPE|None): Экземпляр RoPE (для LLaMA)
|
||||
|
||||
Пример (GPT2 style):
|
||||
>>> decoder = CachedDecoder(
|
||||
... feed_forward_layer=FeedForward(...),
|
||||
... norm_layer=nn.LayerNorm,
|
||||
... num_heads=4, emb_size=256, head_size=64, max_seq_len=128)
|
||||
>>> out, cache = decoder(x, use_cache=True)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feed_forward_layer: nn.Module,
|
||||
@@ -48,20 +63,22 @@ class CachedDecoder(nn.Module):
|
||||
rope: RoPE = None,
|
||||
):
|
||||
"""
|
||||
Инициализация декодера с кэшированием.
|
||||
|
||||
Поведение аналогично блоку TransformerDecoderLayer,
|
||||
но с гибкой возможностью подмены любых подкомпонент (активация, norm, позиции).
|
||||
Конструктор CachedDecoder.
|
||||
|
||||
Args:
|
||||
feed_forward_layer: Слой feed-forward (должен быть экземпляром, а не классом)
|
||||
num_heads: Количество голов внимания
|
||||
emb_size: Размерность эмбеддингов
|
||||
head_size: Размерность каждой головы
|
||||
max_seq_len: Максимальная длина последовательности
|
||||
norm_layer: Класс нормализации (по умолчанию LayerNorm)
|
||||
dropout: Вероятность dropout
|
||||
rope: Rotary Positional Embeddings (опционально)
|
||||
Аргументы:
|
||||
----------
|
||||
num_heads : int
|
||||
Сколько attention heads используется в каждом attention слое.
|
||||
emb_size : int
|
||||
Размерность входного вектора x.
|
||||
head_size : int
|
||||
Размерность каждой attention head; emb_size = num_heads * head_size должно быть True!
|
||||
feed_forward_layer : nn.Module
|
||||
Feed-forward слой (например, обычный двухслойный MLP), который применяется после нормы и внимания, и после второй нормы.
|
||||
max_seq_len : int
|
||||
Максимальная поддерживаемая длина последовательности (выделяет буфер для causal-маски).
|
||||
dropout : float, default=0.1
|
||||
Dropout после внимания и/или feedforward.
|
||||
"""
|
||||
super().__init__()
|
||||
self._heads = MultiHeadAttention(
|
||||
@@ -84,19 +101,30 @@ class CachedDecoder(nn.Module):
|
||||
cache: list = None,
|
||||
):
|
||||
"""
|
||||
Прямой проход с поддержкой кэша.
|
||||
|
||||
Args:
|
||||
x (Tensor[float]): [batch, seq_len, emb_size] — скрытые состояния
|
||||
mask (Optional[Tensor]): маска внимания (или causal mask), shape [seq_len, seq_len]
|
||||
use_cache (bool): использовать кэширование KV
|
||||
cache (list): кэш self-attention для быстрого авторегрессива
|
||||
Returns:
|
||||
output (Tensor[float]): выходные состояния [batch, seq_len, emb_size]
|
||||
kv_caches (list): обновленный кэш, если use_cache
|
||||
Пример:
|
||||
>>> out, new_cache = decoder(x, use_cache=True, cache=old_cache)
|
||||
>>> out.shape # [batch, seq_len, emb_size]
|
||||
Прямой проход через Decoder Block с поддержкой KV-кэша.
|
||||
|
||||
В этом методе применяется:
|
||||
- Causal multi-head attention (masked, не смотрит вперёд)
|
||||
- Быстрая обработка длинных последовательностей за счёт сохранения и передачи KV-кэша
|
||||
- LayerNorm перед каждым блоком
|
||||
- Feed-forward блок и вторая LayerNorm
|
||||
- Dropout
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Вход [batch, seq_len, emb_size]
|
||||
use_cache : bool, по умолчанию True
|
||||
Включать ли накопление и возврат KV-кэша для autoregressive inferece.
|
||||
cache : list, опционально
|
||||
Список предыдущего KV-кеша для attention.
|
||||
|
||||
Возвращает:
|
||||
-----------
|
||||
x_ff_out : torch.Tensor
|
||||
Результат после attention, модуля и их рез. связей (shape == x)
|
||||
new_cache : new KV-cache (или None)
|
||||
|
||||
"""
|
||||
norm1_out = self._norm1(x)
|
||||
# Передаём все cache/use_cache дальше в attention
|
||||
@@ -111,4 +139,4 @@ class CachedDecoder(nn.Module):
|
||||
if use_cache:
|
||||
return (result, kv_caches)
|
||||
else:
|
||||
return (result, None)
|
||||
return (result, None)
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
from .feed_forward import FeedForward
|
||||
from .multi_head_attention import MultiHeadAttention
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""
|
||||
Базовый автогерессивный блок-декодер трансформера (без кэша KV).
|
||||
|
||||
Научная суть:
|
||||
- Осуществляет посимвольное предсказание: каждый токен видит только предыдущие (masked attention)
|
||||
- Состоит из self-attention + feedforward + residual + нормализация
|
||||
- Residual connection и normalization дают стабильность и градиентный “flow” при обучении
|
||||
- Механизм предложен в Vaswani et al., "Attention is All You Need", 2017
|
||||
Args:
|
||||
num_heads (int): количество attention-голов
|
||||
emb_size (int): размер эмбеддинга
|
||||
head_size (int): размер одной attention-головы
|
||||
max_seq_len (int): максимальная длина последовательности
|
||||
dropout (float): вероятность dropout
|
||||
Пример:
|
||||
>>> decoder = Decoder(num_heads=8, emb_size=512, head_size=64, max_seq_len=1024)
|
||||
>>> x = torch.randn(1, 10, 512)
|
||||
>>> out = decoder(x)
|
||||
>>> print(out.shape) # torch.Size([1, 10, 512])
|
||||
"""
|
||||
def __init__(self,
|
||||
num_heads: int,
|
||||
emb_size: int,
|
||||
head_size: int,
|
||||
max_seq_len: int,
|
||||
dropout: float = 0.1
|
||||
):
|
||||
"""
|
||||
Инициализация декодера.
|
||||
|
||||
Параметры:
|
||||
num_heads: int - количество голов внимания
|
||||
emb_size: int - размерность эмбеддингов
|
||||
head_size: int - размерность каждой головы внимания
|
||||
max_seq_len: int - максимальная длина последовательности
|
||||
dropout: float (default=0.1) - вероятность dropout
|
||||
"""
|
||||
super().__init__()
|
||||
self._heads = MultiHeadAttention(
|
||||
num_heads=num_heads,
|
||||
emb_size=emb_size,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len,
|
||||
dropout=dropout
|
||||
)
|
||||
self._ff = FeedForward(emb_size=emb_size, dropout=dropout)
|
||||
self._norm1 = nn.LayerNorm(emb_size)
|
||||
self._norm2 = nn.LayerNorm(emb_size)
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
|
||||
"""
|
||||
Прямой проход через декодер.
|
||||
|
||||
Вход:
|
||||
x: torch.Tensor - входной тензор [batch_size, seq_len, emb_size]
|
||||
mask: torch.Tensor (optional) - маска внимания [seq_len, seq_len]
|
||||
|
||||
Возвращает:
|
||||
torch.Tensor - выходной тензор [batch_size, seq_len, emb_size]
|
||||
|
||||
Алгоритм forward:
|
||||
1. Применяем MultiHeadAttention к входу
|
||||
2. Добавляем residual connection и LayerNorm
|
||||
3. Применяем FeedForward сеть
|
||||
4. Добавляем residual connection и LayerNorm
|
||||
"""
|
||||
# Self-Attention блок
|
||||
attention, _ = self._heads(x, mask, use_cache=False, cache=None)
|
||||
out = self._norm1(attention + x)
|
||||
|
||||
# FeedForward блок
|
||||
ffn_out = self._ff(out)
|
||||
return self._norm2(ffn_out + out)
|
||||
@@ -6,52 +6,71 @@ from .gelu import GELU
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
"""
|
||||
Классический слой прямого распространения (FeedForward, или FFN) для архитектуры Transformer.
|
||||
FeedForward — классический позиционно-независимый блок для Transformer, применяется к каждому токену отдельно.
|
||||
|
||||
Этот слой состоит из двух линейных преобразований с расширением внутренней размерности
|
||||
в 4 раза и механизмом dropout для регуляризации. Между линейными слоями применяется
|
||||
активация ReLU.
|
||||
Назначение и роль:
|
||||
------------------
|
||||
- Реализует двухслойную (или более сложную) нейронную сеть, которая обрабатывает каждый токен ПОРЯДОЧНО независимо (по последней измерении).
|
||||
- Дает модели "нелинейную мощность": любой токен может быть переосмыслен вне глобального контекста.
|
||||
- После слоя внимания (MHA) FFN помогает связать смысл локальных (внутри токена) “скрытых” значений.
|
||||
|
||||
Научная суть:
|
||||
- После внимания каждому токену применяется одинаковая двухслойная нейросеть.
|
||||
- Дает глубокую нелинейность; позволяет модели не только сопоставлять, но и моделировать сложные связи между токенами.
|
||||
- Изначально предложен в «Attention is All You Need» (Vaswani et al., 2017).
|
||||
|
||||
Формула:
|
||||
FFN(x) = Dropout(W2·act(W1·x))
|
||||
где act — ReLU, GELU и др., обычно expansion x4.
|
||||
Архитектурные детали:
|
||||
---------------------
|
||||
- Обычно используется блок: (Linear → Activation → Dropout → Linear → Dropout)
|
||||
- В современных LLM обычно в 4 раза расширяют скрытый слой (inner_dim = 4 * emb_size).
|
||||
- Активация часто GELU или SiLU (Swish), иногда SwiGLU, ReGLU, GeGLU (см. PaLM, Llama).
|
||||
|
||||
Алгоритм работы:
|
||||
1. Входной тензор x (размерность: [batch_size, seq_len, emb_size])
|
||||
2. Линейное преобразование: emb_size -> 4*emb_size
|
||||
3. Активация ReLU
|
||||
4. Линейное преобразование: 4*emb_size -> emb_size
|
||||
5. Применение dropout
|
||||
6. Возврат результата (размерность: [batch_size, seq_len, emb_size])
|
||||
Формула (обычная версия):
|
||||
-------------------------
|
||||
FFN(x) = Linear2(Dropout(Activation(Linear1(x))))
|
||||
где Linear1: [emb_size → 4*emb_size], Activation: GELU/SiLU, Linear2: [4*emb_size → emb_size]
|
||||
|
||||
Параметры конструктора:
|
||||
-----------------------
|
||||
emb_size: int — размерность входа/выхода токена
|
||||
inner_dim: int (необязательно) — размер скрытого слоя (по умолчанию 4*emb_size)
|
||||
activation: str — тип активации ('gelu', 'silu', 'relu', ...), см. варианты ниже
|
||||
dropout: float — dropout после каждой линейной проекции
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> ffn = FeedForward(emb_size=256, dropout=0.1, activation='gelu')
|
||||
>>> x = torch.randn(2, 32, 256) # [batch, seq_len, emb_size]
|
||||
>>> y = ffn(x)
|
||||
>>> print(y.shape) # torch.Size([2, 32, 256])
|
||||
|
||||
Пояснения:
|
||||
----------
|
||||
- FeedForward не использует позицию токена — это МLP, применяемый к каждому токену независимо.
|
||||
- Длина последовательности и размер батча не имеют значения (broadcast/reshape по [-2, -1]).
|
||||
- Используется во всех декодерах/энкодерах трансформеров.
|
||||
|
||||
Подробнее смотри:
|
||||
-----------------
|
||||
- Vaswani et al., "Attention is All You Need": https://arxiv.org/abs/1706.03762
|
||||
- GELU: https://arxiv.org/abs/1606.08415
|
||||
- SwiGLU (PaLM, Llama): https://arxiv.org/abs/2002.05202
|
||||
|
||||
Предназначение:
|
||||
- Добавляет нелинейность в архитектуру трансформера
|
||||
- Обеспечивает взаимодействие между различными размерностями эмбеддингов
|
||||
- Работает независимо для каждого токена в последовательности
|
||||
|
||||
Args:
|
||||
emb_size (int): размерность входных эмбеддингов
|
||||
dropout (float): вероятность(dropout)
|
||||
activation (str): нелинейная функция (relu, gelu, gelu_exact)
|
||||
|
||||
Пример:
|
||||
>>> ff = FeedForward(emb_size=512, dropout=0.1)
|
||||
>>> x = torch.randn(32, 10, 512)
|
||||
>>> output = ff(x)
|
||||
>>> print(output.shape) # torch.Size([32, 10, 512])
|
||||
"""
|
||||
|
||||
def __init__(self, emb_size: int, dropout: float = 0.1, activation: str = "relu"):
|
||||
"""
|
||||
Инициализация слоя Feed Forward Network.
|
||||
|
||||
Args:
|
||||
emb_size: Размерность входных эмбеддингов
|
||||
dropout: Вероятность dropout для регуляризации (по умолчанию: 0.1)
|
||||
Инициализация FeedForward блока для трансформера.
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
emb_size: int
|
||||
Размерность входного и выходного эмбеддинга модели.
|
||||
dropout: float, по умолчанию 0.1
|
||||
Dropout после линии и/или активации (уменьшает переобучение).
|
||||
activation: str, по умолчанию 'gelu'
|
||||
Какая нелинейность использовать ('gelu', 'silu', 'relu' и т.д.).
|
||||
inner_dim: int, опционально
|
||||
Размер скрытого слоя (по умолчанию 4 * emb_size, как в оригинальном Transformer).
|
||||
|
||||
Внутри:
|
||||
-------
|
||||
- Задает структуру: Linear → Activation → Dropout → Linear → Dropout.
|
||||
"""
|
||||
super().__init__()
|
||||
# Первый линейный слой (расширение размерности)
|
||||
@@ -72,24 +91,34 @@ class FeedForward(nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
Прямой проход через слой Feed Forward Network.
|
||||
|
||||
Args:
|
||||
x: Входной тензор размерности [batch_size, seq_len, emb_size]
|
||||
|
||||
Returns:
|
||||
Тензор той же размерности, что и входной
|
||||
Прямой проход через FeedForward блок.
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Входной тензор формы [..., emb_size] (используется на каждом токене отдельно!)
|
||||
|
||||
Возвращает:
|
||||
-----------
|
||||
torch.Tensor — выход такой же формы, как вход (только последняя размерность сохраняется).
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> ffn = FeedForward(emb_size=256)
|
||||
>>> x = torch.randn(8, 16, 256)
|
||||
>>> y = ffn(x)
|
||||
>>> y.shape # [8, 16, 256]
|
||||
"""
|
||||
# Сохраняем dtype входных данных
|
||||
input_dtype = x.dtype
|
||||
|
||||
|
||||
# Приводим веса к нужному типу если необходимо
|
||||
if input_dtype != self._layer1.weight.dtype:
|
||||
self._layer1 = self._layer1.to(dtype=input_dtype)
|
||||
self._layer2 = self._layer2.to(dtype=input_dtype)
|
||||
|
||||
|
||||
# Пропустим тензор x по очереди через все созданные слои
|
||||
x = self._layer1(x)
|
||||
x = self._activation(x)
|
||||
x = self._layer2(x)
|
||||
return self._dropout(x)
|
||||
return self._dropout(x)
|
||||
|
||||
140
llm/src/llm/core/geglu.py
Normal file
140
llm/src/llm/core/geglu.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from llm.core.gelu import GELU
|
||||
|
||||
class GeGLU(nn.Module):
|
||||
"""
|
||||
GeGLU (Gated GELU Linear Unit) — эффективная нелинейность для feed-forward блоков в современных трансформерах.
|
||||
|
||||
Назначение:
|
||||
-----------
|
||||
GeGLU — это вариант GLU (Gated Linear Unit), где «шлюз» реализован через GELU-активацию,
|
||||
а затем поэлементно перемножается с другим линейным преобразованием. Такой gating-механизм позволяет повысить
|
||||
выразительность MLP-блока и ускорить обучение, что подтверждено экспериментами на LLM (см. PaLM, LLaMA, T5).
|
||||
|
||||
Формула:
|
||||
--------
|
||||
GeGLU(x) = GELU(W_g x + b_g) ⊙ (W_u x + b_u) W_d + b_d
|
||||
(здесь W_g, W_u, W_d — матрицы весов; GELU применяется к одной ветке, ⊙ — поэлементное умножение)
|
||||
|
||||
Структура блока:
|
||||
----------------
|
||||
1. gate = GELU(Linear_gate(x)) # ветка gating-а, shape [batch, seq, 4×emb]
|
||||
2. up = Linear_up(x) # ветка передачи, shape [batch, seq, 4×emb]
|
||||
3. out = gate * up # поэлементно, реализует динамическую фильтрацию информации
|
||||
4. out = Linear_down(out) # проекция обратно в исходное пространство
|
||||
5. out = Dropout(out) # регуляризация
|
||||
|
||||
Основные преимущества:
|
||||
----------------------
|
||||
- Позволяет эффективно обучать глубокие трансформеры (см. PaLM, LLaMA).
|
||||
- Обеспечивает плавные градиенты за счёт GELU и gating-эффекта.
|
||||
- Используется во многих современных LLM вместо обычных FFN или простых GLU.
|
||||
|
||||
Аргументы конструктора:
|
||||
-----------------------
|
||||
emb_size : int
|
||||
Размер эмбеддинга (input и output).
|
||||
dropout : float, по умолчанию 0.1
|
||||
Dropout к финальному выходу (примерно 0.1-0.2 для регуляризации).
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> geglu = GeGLU(emb_size=512, dropout=0.1)
|
||||
>>> x = torch.randn(8, 16, 512)
|
||||
>>> y = geglu(x)
|
||||
>>> print(y.shape) # torch.Size([8, 16, 512])
|
||||
|
||||
Литература:
|
||||
-----------
|
||||
- Shazeer N., "GLU Variants Improve Transformer", 2020: https://arxiv.org/abs/2002.05202
|
||||
- PaLM: https://arxiv.org/abs/2204.02311
|
||||
- LLaMA: https://arxiv.org/abs/2302.13971
|
||||
- T5: https://arxiv.org/abs/1910.10683
|
||||
"""
|
||||
def __init__(self, emb_size: int, dropout: float = 0.1):
|
||||
"""
|
||||
Инициализация блока GeGLU.
|
||||
|
||||
Создаёт три последовательных линейных слоя и задаёт GELU в качестве активации для ветки gating,
|
||||
а также финальный dropout. Все размеры согласованы так, чтобы реализовать формулу GeGLU (см. описание класса).
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
emb_size : int
|
||||
Размерность входного и выходного скрытого пространства (hidden size).
|
||||
Данная величина определяет размерность эмбеддинга для всех внутренних вычислений.
|
||||
Обычно равна размеру скрытого слоя трансформера.
|
||||
|
||||
dropout : float, по умолчанию 0.1
|
||||
Вероятность отключения нейронов после выхода из блока (регуляризация).
|
||||
Рекомендуемое значение: 0.1 (или чуть больше для небольших моделей).
|
||||
|
||||
Внутри:
|
||||
-------
|
||||
- self._gate: Linear слой размерности [emb_size, 4 * emb_size], ветка gating (проходит через GELU)
|
||||
- self._up: Linear слой размерности [emb_size, 4 * emb_size], ветка передачи ("пропускная")
|
||||
- self._down: Linear слой сжатия обратно к emb_size
|
||||
- self._activation: Активация GELU для gating-ветки
|
||||
- self._dropout: Dropout для выходного тензора
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> block = GeGLU(emb_size=256, dropout=0.1)
|
||||
>>> print(block)
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self._gate = nn.Linear(emb_size, 4 * emb_size)
|
||||
self._up = nn.Linear(emb_size, 4 * emb_size)
|
||||
self._down = nn.Linear(4 * emb_size, emb_size)
|
||||
self._activation = GELU()
|
||||
self._dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
Прямой проход (forward) через блок GeGLU.
|
||||
|
||||
Для входного тензора скрытых состояний x реализует последовательность операций:
|
||||
1. Gating-ветка: линейное преобразование → GELU-активация
|
||||
2. Пропускная ветка: линейное преобразование
|
||||
3. Поэлементное умножение результатов обеих веток (gating)
|
||||
4. Проекция через Linear обратно к emb_size
|
||||
5. Dropout результата для регуляризации
|
||||
|
||||
Математически:
|
||||
--------------
|
||||
gate = GELU(W_g·x + b_g)
|
||||
up = W_u·x + b_u
|
||||
out = gate * up
|
||||
out = W_d·out + b_d
|
||||
out = Dropout(out)
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Входной тензор формы [batch_size, seq_len, emb_size]
|
||||
(или любой совместимой формы, где последняя ось — emb_size).
|
||||
|
||||
Возвращает:
|
||||
-----------
|
||||
torch.Tensor :
|
||||
Тензор той же формы [batch_size, seq_len, emb_size], прошедший через структуру GeGLU.
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> y = geglu(x)
|
||||
>>> print(y.shape) # [batch_size, seq_len, emb_size]
|
||||
|
||||
Примечания:
|
||||
-----------
|
||||
- Ветка gating строит masк для динамической фильтрации информации.
|
||||
- Такой тип блока эффективно используется как замена обычного FFN в современных LLM.
|
||||
"""
|
||||
gate_out = self._gate(x) # [batch, seq, 4*emb]
|
||||
activation_out = self._activation(gate_out) # [batch, seq, 4*emb]
|
||||
up_out = self._up(x) # [batch, seq, 4*emb]
|
||||
out = up_out * activation_out # поэлементное!
|
||||
out = self._down(out) # [batch, seq, emb]
|
||||
return self._dropout(out)
|
||||
|
||||
@@ -1,27 +1,72 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import math
|
||||
|
||||
class GELU(nn.Module):
|
||||
"""
|
||||
Гауссовская Эрф-активация (GELU, Gaussian Error Linear Unit).
|
||||
GELU (Gaussian Error Linear Unit) — современная сглаженная функция активации для нейросетей.
|
||||
|
||||
Научная суть:
|
||||
- Одна из самых популярных smooth активаций для трансформеров.
|
||||
- Дает более гибкие аппроксимации, чем ReLU/SiLU, улучшает flow градиентов для больших LLM.
|
||||
- Используется в BERT, GPT, GPT2 и почти всех современных NLP-моделях.
|
||||
Формула:
|
||||
GELU(x) = 0.5 * x * (1 + tanh(\sqrt{2/π} * (x + 0.044715 x³)))
|
||||
Подробнее: Hendrycks & Gimpel, "Gaussian Error Linear Units (GELUs)", arXiv:1606.08415
|
||||
Пример:
|
||||
Мотивация и назначение:
|
||||
-----------------------
|
||||
- GELU используется во всех современных трансформерах (BERT, GPT, Llama) вместо ReLU, поскольку лучше передает градиенты и даёт более "мягкое" обучение.
|
||||
- Формирует плавный переход между активированным и неактивированным состоянием, что улучшает устойчивость и общую производительность больших моделей.
|
||||
- Дает возможность обучению «решать», насколько сильно и в каких диапазонах нужно передавать сигнал (в отличие от жёсткого ReLU).
|
||||
|
||||
Математическая формула:
|
||||
-----------------------
|
||||
GELU(x) = 0.5 * x * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3) ))
|
||||
- Статья (Hendrycks & Gimpel, 2016): https://arxiv.org/abs/1606.08415
|
||||
- В PyTorch с версии 1.4+ встроена как torch.nn.functional.gelu и torch.nn.GELU.
|
||||
|
||||
Как это работает:
|
||||
-----------------
|
||||
- Для каждого входного значения x:
|
||||
- x при больших значениях (большие положительные) почти полностью передается дальше.
|
||||
- x при малых (или сильно отрицательных) "заглушается" к нулю.
|
||||
- На промежуточных значениях — плавный переход.
|
||||
- Является аппроксимацией случайного бинома с гауссовским шумом.
|
||||
|
||||
Args:
|
||||
-----
|
||||
Нет learnable параметров — GELU работает одинаково для всех входов.
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> gelu = GELU()
|
||||
>>> y = gelu(torch.tensor([-1.0, 0.0, 1.0]))
|
||||
>>> print(y)
|
||||
>>> x = torch.tensor([-2.0, 0.0, 2.0])
|
||||
>>> print(gelu(x)) # тензор из плавно переходящих значений
|
||||
|
||||
References:
|
||||
-----------
|
||||
- Hendrycks & Gimpel: https://arxiv.org/abs/1606.08415
|
||||
- BERT, GPT-2 papers (везде используется GELU)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)
|
||||
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return 0.5 * x * (1 + torch.tanh(
|
||||
self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))
|
||||
))
|
||||
"""
|
||||
Прямой проход через GELU-активацию.
|
||||
|
||||
Args:
|
||||
-----
|
||||
x : torch.Tensor
|
||||
Любой входной тензор.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
torch.Tensor — тензор той же формы, где к каждому элементу применён GELU.
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> gelu = GELU()
|
||||
>>> x = torch.linspace(-3, 3, 7)
|
||||
>>> y = gelu(x)
|
||||
"""
|
||||
return (
|
||||
0.5
|
||||
* x
|
||||
* (1 + torch.tanh(self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))))
|
||||
)
|
||||
|
||||
188
llm/src/llm/core/gemma_decoder.py
Normal file
188
llm/src/llm/core/gemma_decoder.py
Normal file
@@ -0,0 +1,188 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from llm.core.rope import RoPE
|
||||
from llm.core.multi_query_attention import MultiQueryAttention
|
||||
from llm.core.rms_norm import RMSNorm
|
||||
from llm.core.geglu import GeGLU
|
||||
|
||||
class GemmaDecoder(nn.Module):
|
||||
"""
|
||||
GemmaDecoder — декодерный блок архитектуры Gemma (Google DeepMind, 2024).
|
||||
|
||||
Назначение:
|
||||
-----------
|
||||
Данный блок реализует одну «ячейку» декодерного стека в модели Gemma. Архитектура схожа с современными LLM (Llama/Mistral),
|
||||
но имеет уникальные особенности attention и feed-forward слоёв, соответствующие спецификации Gemma.
|
||||
|
||||
Архитектурные компоненты:
|
||||
-------------------------
|
||||
- LayerNorm или RMSNorm
|
||||
- Multi-head self-attention (обычно Multi-Query Attention)
|
||||
- Skip connection (остаточное сложение)
|
||||
- Feed-forward блок (может включать SwiGLU, GeGLU или классический FFN)
|
||||
- Повторная нормализация
|
||||
- Dropout (регуляризация на уровне attention и feed-forward)
|
||||
|
||||
Алгоритм прямого прохода:
|
||||
-------------------------
|
||||
1. norm1_out = LayerNorm(x)
|
||||
2. attention_out = Attention(norm1_out, ...)
|
||||
3. resid1 = attention_out + x
|
||||
4. norm2_out = LayerNorm(resid1)
|
||||
5. ffn_out = FeedForward(norm2_out)
|
||||
6. output = ffn_out + resid1
|
||||
|
||||
Теоретические детали:
|
||||
---------------------
|
||||
- В Gemma используются техники оптимизации памяти и ускорения инференса (например, shared K/V-головы, Rope, кастомные FFN).
|
||||
- Поддержка кэширования attention для ускорения генерации (KV cache).
|
||||
- Блок проектирован для использования в стеке, повторяется N раз во всей LLM.
|
||||
|
||||
Аргументы конструктора:
|
||||
----------------------
|
||||
num_q_heads : int
|
||||
Число голов query (Query Heads) для attention.
|
||||
num_kv_heads : int
|
||||
Число ключевых/значенческих голов (Key/Value Heads).
|
||||
emb_size : int
|
||||
Размерность скрытого пространства (embedding dim).
|
||||
head_size : int
|
||||
Размерность одной attention-головы.
|
||||
max_seq_len : int
|
||||
Максимальная длина последовательности (ограничение на causal mask).
|
||||
dropout : float, optional
|
||||
Dropout для регуляризации (примерно 0.0–0.1).
|
||||
rope : RoPE, optional
|
||||
Позиционное кодирование Rotary Position Embedding.
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> decoder = GemmaDecoder(
|
||||
... num_q_heads=8,
|
||||
... num_kv_heads=2,
|
||||
... emb_size=256,
|
||||
... head_size=32,
|
||||
... max_seq_len=1024,
|
||||
... dropout=0.1,
|
||||
... rope=rope_obj
|
||||
... )
|
||||
>>> x = torch.randn(2, 24, 256)
|
||||
>>> out, cache = decoder(x, mask=None, use_cache=True, cache=None)
|
||||
>>> print(out.shape) # torch.Size([2, 24, 256])
|
||||
|
||||
Литература и ссылки:
|
||||
--------------------
|
||||
- Gemma (официальный релиз): https://ai.google.dev/gemma
|
||||
- Gemma paper: https://arxiv.org/abs/2403.07794
|
||||
- Rotary Embedding: https://arxiv.org/abs/2104.09864
|
||||
- Multi-Query Attention: https://arxiv.org/abs/1911.02150
|
||||
- Llama: https://arxiv.org/abs/2302.13971
|
||||
"""
|
||||
def __init__(self,
|
||||
num_q_heads: int,
|
||||
emb_size: int,
|
||||
head_size: int,
|
||||
max_seq_len: int,
|
||||
rope: RoPE,
|
||||
dropout: float = 0.1
|
||||
):
|
||||
"""
|
||||
Конструктор слоя GemmaDecoder.
|
||||
|
||||
Производит инициализацию всех подслоёв (нормализация, multi-head или multi-query attention, feed-forward блок, Dropout)
|
||||
согласно архитектуре декодера Gemma. Обеспечивает поддержку rotary-позиционирования, обучения и inference с caching.
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
num_q_heads : int
|
||||
Количество query-голов в attention (определяет степень параллелизма внимания).
|
||||
emb_size : int
|
||||
Размер пространства эмбеддинга (embedding dim, input/output размерность слоя).
|
||||
head_size : int
|
||||
Размерность одной attention-головы. Обычно emb_size // num_q_heads.
|
||||
max_seq_len : int
|
||||
Максимальная длина последовательности, для которой поддерживается attention и маскирование.
|
||||
rope : RoPE
|
||||
Объект для rotary positional encoding (позиционное кодирование для attention).
|
||||
dropout : float, default=0.1
|
||||
Dropout после attention и feed-forward для регуляризации (обычно 0.0–0.1).
|
||||
|
||||
Внутри:
|
||||
-------
|
||||
- Инициализируются все слои norm, attention, rope, FFN, остаточные соединения.
|
||||
- Строится causal-маска автоагрессивного attention (если требуется).
|
||||
- Гибко поддерживает работу как на training, так и для быстрых inference/генерации.
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> decoder = GemmaDecoder(
|
||||
... num_q_heads=8, emb_size=512, head_size=64, max_seq_len=1024, rope=rope_obj, dropout=0.05
|
||||
... )
|
||||
"""
|
||||
super().__init__()
|
||||
self._heads = MultiQueryAttention(
|
||||
num_q_heads=num_q_heads,
|
||||
emb_size=emb_size,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len,
|
||||
rope=rope,
|
||||
dropout=dropout
|
||||
)
|
||||
self._ff = GeGLU(emb_size=emb_size, dropout=dropout)
|
||||
self._norm1 = RMSNorm(emb_size)
|
||||
self._norm2 = RMSNorm(emb_size)
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
|
||||
"""
|
||||
Прямой проход (forward) через GemmaDecoder.
|
||||
|
||||
Последовательно реализует:
|
||||
- Нормализацию входа (обычно RMSNorm или LayerNorm)
|
||||
- Self-attention (multi-query или multi-head, с опциональной маской и кэшем)
|
||||
- Остаточное сложение (skip connection)
|
||||
- Вторую нормализацию
|
||||
- Feed-Forward-блок (например, GeGLU/SwiGLU)
|
||||
- Ещё одно residual сложение
|
||||
|
||||
Поддерживает autoregressive режим с caching (KV-слоты attention для ускорения генерации).
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Входной скрытый тензор формы [batch_size, seq_length, emb_size].
|
||||
mask : torch.Tensor, optional
|
||||
Attention mask (например, causal или padding mask). Если None, используется встроенная causal mask.
|
||||
use_cache : bool, по умолчанию True
|
||||
Если True — возвращается кэш KV для ускорения autoregressive генерации.
|
||||
cache : list, optional
|
||||
Кэш предыдущих ключей/значений attention (если используется при инференсе).
|
||||
|
||||
Возвращает:
|
||||
-----------
|
||||
Tuple[torch.Tensor, cache]:
|
||||
- Выход декодера с той же формой [batch_size, seq_length, emb_size]
|
||||
- Кэш attention (если use_cache=True), иначе None
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> out, new_cache = decoder(x, mask=att_mask, use_cache=True, cache=old_cache)
|
||||
>>> out.shape # [batch_size, seq_len, emb_size]
|
||||
|
||||
Примечания:
|
||||
-----------
|
||||
- mask используется для ограничения внимания (напр., каузальный режим GPT/LLM).
|
||||
- Для ускорения в режиме генерации рекомендуется использовать use_cache=True + передавать cache.
|
||||
|
||||
"""
|
||||
norm1_out = self._norm1(x)
|
||||
attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
|
||||
out = attention + x
|
||||
|
||||
norm2_out = self._norm2(out)
|
||||
ffn_out = self._ff(norm2_out)
|
||||
|
||||
if use_cache is True:
|
||||
return (ffn_out + out, kv_caches)
|
||||
else:
|
||||
return (ffn_out + out, None)
|
||||
138
llm/src/llm/core/gpt_decoder.py
Normal file
138
llm/src/llm/core/gpt_decoder.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
from .feed_forward import FeedForward
|
||||
from .multi_head_attention import MultiHeadAttention
|
||||
|
||||
|
||||
class GptDecoder(nn.Module):
|
||||
"""
|
||||
Decoder — базовый transformer decoder block (pre-LN), классический строительный блок современных языковых моделей.
|
||||
|
||||
Назначение:
|
||||
-----------
|
||||
- Инкапсулирует архитектуру: norm → multi-head self-attention → residual → norm → feed-forward → residual
|
||||
- Подходит как для LLM/GPT, так и для любых autoregressive sequence моделей.
|
||||
- Использует masked self-attention: каждый токен видит только предыдущие (никакого \"заглядывания в будущее\").
|
||||
- Стабильность обеспечивается через residual connections и LayerNorm после каждого sub-layer.
|
||||
|
||||
Почему это важно?
|
||||
-----------------
|
||||
- Все современные языковые модели состоят из подобных блоков, соединённых в стек.
|
||||
- Алгоритм residual+norm позволяет проще обучать очень глубокие сети.
|
||||
- Разделение на attention+FFN дает и локальные, и глобальные взаимодействия между токенами.
|
||||
|
||||
Формула работы (псевдокод):
|
||||
---------------------------
|
||||
y1 = norm1(x)
|
||||
attn_out = Attention(y1)
|
||||
x2 = x + attn_out # residual
|
||||
y2 = norm2(x2)
|
||||
ffn_out = FFN(y2)
|
||||
out = x2 + ffn_out # residual
|
||||
|
||||
Архитектурные особенности:
|
||||
--------------------------
|
||||
- Поддержка внимания с маской (causal mask или произвольная attention mask)
|
||||
- Residual connections для каждого блока (attention, FFN)
|
||||
- Pre-LN (norm перед каждым подблоком)
|
||||
- Зависит от переданных блоков self_attention и feed_forward, а не их реализации
|
||||
|
||||
References:
|
||||
-----------
|
||||
- Vaswani et al., \"Attention is All You Need\" (2017): https://arxiv.org/abs/1706.03762
|
||||
- Illustrated Transformer: https://jalammar.github.io/illustrated-transformer/
|
||||
- Transformer Circuits (дружественное описание): https://transformer-circuits.pub/2021/framework/index.html
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> decoder = Decoder(num_heads=8, emb_size=512, head_size=64, max_seq_len=1024)
|
||||
>>> x = torch.randn(1, 10, 512)
|
||||
>>> out = decoder(x)
|
||||
>>> print(out.shape) # torch.Size([1, 10, 512])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
emb_size: int,
|
||||
head_size: int,
|
||||
max_seq_len: int,
|
||||
dropout: float = 0.1,
|
||||
):
|
||||
"""
|
||||
Инициализация стандартного decoder-блока для Transformer.
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
num_heads: int
|
||||
Количество attention голов (как делить emb_size на heads)
|
||||
emb_size: int
|
||||
Размерность эмбеддингов (и входа и выхода)
|
||||
head_size: int
|
||||
Размерность одной attention-головы (emb_size = num_heads * head_size)
|
||||
max_seq_len: int
|
||||
Максимальная длина последовательности (важно для mask)
|
||||
dropout: float, default=0.1
|
||||
Dropout после внимания и FFN
|
||||
|
||||
Внутри:
|
||||
-------
|
||||
- Создаёт слой MultiHeadAttention (masked/casual)
|
||||
- Создаёт двухслойный FeedForward (SwiGLU или GELU)
|
||||
- Применяет 2 слоя LayerNorm для стабилизации градиентов
|
||||
- Все блоки реализованы как PyTorch-модули
|
||||
"""
|
||||
super().__init__()
|
||||
self._heads = MultiHeadAttention(
|
||||
num_heads=num_heads,
|
||||
emb_size=emb_size,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len,
|
||||
dropout=dropout,
|
||||
)
|
||||
self._ff = FeedForward(emb_size=emb_size, dropout=dropout)
|
||||
self._norm1 = nn.LayerNorm(emb_size)
|
||||
self._norm2 = nn.LayerNorm(emb_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
use_cache: bool = False,
|
||||
cache: list = None,
|
||||
attention_mask=None
|
||||
) -> tuple:
|
||||
"""
|
||||
Один прямой проход через Transformer decoder block.
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Входной тензор [batch_size, seq_len, emb_size]
|
||||
mask : torch.Tensor, optional
|
||||
Attention/causal mask (по умолчанию None, тогда будет casual mask по длине seq_len)
|
||||
|
||||
Возвращает:
|
||||
-----------
|
||||
out : torch.Tensor
|
||||
Выходной тензор той же формы, что и x
|
||||
|
||||
Алгоритм:
|
||||
---------
|
||||
- Применяем attention к нормализованному входу (layernorm)
|
||||
- Добавляем residual-связь (attention + исходный вход)
|
||||
- Применяем FFN к нормализованному результату (layernorm)
|
||||
- Добавляем residual-связь (ffn + предыдущий выход)
|
||||
"""
|
||||
|
||||
# Self-Attention блок
|
||||
attention, kv_caches = self._heads(x, attention_mask, use_cache=use_cache, cache=cache)
|
||||
out = self._norm1(attention + x)
|
||||
|
||||
# FeedForward блок
|
||||
ffn_out = self._ff(out)
|
||||
result = self._norm2(ffn_out + out)
|
||||
|
||||
if use_cache:
|
||||
return (result, kv_caches)
|
||||
else:
|
||||
return (result, None)
|
||||
413
llm/src/llm/core/group_query_attention.py
Normal file
413
llm/src/llm/core/group_query_attention.py
Normal file
@@ -0,0 +1,413 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from llm.core.rope import RoPE
|
||||
|
||||
class GroupedQueryAttention(nn.Module):
|
||||
"""
|
||||
Grouped Query Attention (GQA)
|
||||
=============================
|
||||
|
||||
Что такое Grouped Query Attention?
|
||||
----------------------------------
|
||||
Это разновидность многоголового внимания (multi-head), где для Q (query) голов может быть больше, чем для K/V (key/value) голов:
|
||||
вместо стандартного MHA (num_q_heads == num_kv_heads) — меньшее число K/V разделяет информацию для всех Q.
|
||||
Такой подход экономит память и ускоряет инференс, сохраняя высокое качество внимания (используется например в Mistral, Llama-2, GPT-4 и др.).
|
||||
|
||||
Зачем это нужно?
|
||||
----------------
|
||||
- Сокращает количество вычислений и размер KV-кэша в больших LLM.
|
||||
- Позволяет эффективно масштабировать число attention-глав для моделирования сложных связей, не увеличивая размер всех матриц.
|
||||
|
||||
Как работает?
|
||||
-------------
|
||||
1. Q формируется для каждого query-head (их много)
|
||||
2. K и V вычисляется только для меньшего числа KV-heads (обычно в 2-4 раза меньше, чем Q)
|
||||
3. К/V heads дублируются (repeat) так, чтобы на каждую Q-head был свой набор
|
||||
4. Всё внимание (Q,K,V) — стандартное scaled dot-product, только более эффективно и с компрессией
|
||||
|
||||
Поддержка дополнительных фич:
|
||||
-----------------------------
|
||||
- Rotary Position Encoding (RoPE) для Q и K (для относительной позиции)
|
||||
- Sliding-window attention mask (можно ограничить исторический контекст, как в Mistral)
|
||||
- Кэширование Q/K/V (ускоряет генерацию автоагретивно)
|
||||
|
||||
Аргументы конструктора:
|
||||
-----------------------
|
||||
num_q_heads: int — количество query голов (Q)
|
||||
num_kv_heads: int — количество key/value голов (обычно меньше Q)
|
||||
emb_size: int — embedding размерность
|
||||
head_size: int — размер каждой attention-head
|
||||
max_seq_len: int — максимальная длина последовательности
|
||||
window_size: int — размер sliding window (макс. количество токенов в контексте внимания)
|
||||
rope: RoPE (по желанию) — если задан, то будет применяться RoPE для Q и K
|
||||
dropout: float — dropout после линейной проекции
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> gqa = GroupedQueryAttention(num_q_heads=8, num_kv_heads=2, emb_size=256, head_size=32, max_seq_len=1024, window_size=256)
|
||||
>>> x = torch.randn(2, 128, 256)
|
||||
>>> y, cache = gqa(x)
|
||||
>>> print(y.shape) # torch.Size([2, 128, 256])
|
||||
|
||||
Где прочитать подробнее:
|
||||
------------------------
|
||||
- LlamaV2 (Section 2.3): https://arxiv.org/abs/2307.09288
|
||||
- Mistral: https://arxiv.org/abs/2310.06825
|
||||
- \"Self-attention with linear complexity\" (Vila et al.): https://arxiv.org/abs/2302.05442
|
||||
- Обзор: https://huggingface.co/blog/mistral
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_q_heads: int,
|
||||
num_kv_heads: int,
|
||||
emb_size: int,
|
||||
head_size: int,
|
||||
max_seq_len: int,
|
||||
window_size: int,
|
||||
rope: RoPE = None,
|
||||
dropout: float = 0.1,
|
||||
):
|
||||
"""
|
||||
Инициализация слоя Grouped Query Attention (GQA).
|
||||
|
||||
Этот конструктор задаёт архитектуру эффективного внимания, где Q-голов может быть больше, чем KV-голов.
|
||||
Это экономит память/вычисления и позволяет реализовать сдвигающееся "окно" внимания (Mistral-style).
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
num_q_heads : int
|
||||
Количество Query attention heads (чаще всего кратно num_kv_heads, напр. 8/2, 12/4).
|
||||
Чем больше — тем богаче контекстное окно каждой позиции.
|
||||
num_kv_heads : int
|
||||
Количество Key/Value attention heads (обычно 2-4, иногда меньше, чем Query).
|
||||
В современных LLM принято уменьшать их число для оптимизации скорости/кэша.
|
||||
emb_size : int
|
||||
Размерность входного embedding (общий размер вектора на токен).
|
||||
head_size : int
|
||||
Размерность одной головы внимания.
|
||||
Требуется: num_q_heads * head_size == emb_size (иначе ошибка).
|
||||
max_seq_len : int
|
||||
Максимальная поддерживаемая длина входной последовательности; определяет размер триангулярной (causal/sliding window) маски.
|
||||
window_size : int
|
||||
Размер "скользящего окна" истории — сколько токенов учитывается при слепом внимании (как у Mistral).
|
||||
Чем меньше значение, тем локальнее работает внимание (и меньше память/время).
|
||||
rope : RoPE, опционально
|
||||
Если задан — применяется Rotary Positional Encoding к Q и K для относительного позиционного кодирования.
|
||||
dropout : float, по умолчанию 0.1
|
||||
Dropout после линейной проекции attention (обычно 0.1, помогает борьбе с переобучением).
|
||||
|
||||
Что создаётся внутри:
|
||||
---------------------
|
||||
- Линейные слои для получения Q, K, V из embedding.
|
||||
- Буфер для causal/sliding window mask (матрица масок в зависимости от window_size и max_seq_len).
|
||||
- Линейный слой для финального преобразования (объединение всех голов и возврат к emb_size).
|
||||
- Dropout перед возвратом.
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> attn = GroupedQueryAttention(
|
||||
... num_q_heads=8, num_kv_heads=2, emb_size=256, head_size=32,
|
||||
... max_seq_len=1024, window_size=256, dropout=0.1)
|
||||
"""
|
||||
super().__init__()
|
||||
self._num_heads = num_q_heads
|
||||
self._num_kv_heads = num_kv_heads
|
||||
self._head_size = head_size
|
||||
self._max_seq_len = max_seq_len
|
||||
self._rope = rope
|
||||
self._window_size = window_size
|
||||
|
||||
self._q = nn.Linear(emb_size, self._num_heads * head_size)
|
||||
self._k = nn.Linear(emb_size, num_kv_heads * head_size)
|
||||
self._v = nn.Linear(emb_size, num_kv_heads * head_size)
|
||||
|
||||
# Создание causal маски
|
||||
mask = self._create_sliding_window_mask(max_seq_len, self._window_size)
|
||||
self.register_buffer(
|
||||
"_tril_mask", mask.bool() if hasattr(torch, "bool") else mask.byte()
|
||||
)
|
||||
|
||||
self._layer = nn.Linear(head_size * self._num_heads, emb_size)
|
||||
self._dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: torch.Tensor = None,
|
||||
use_cache: bool = True,
|
||||
cache: list = None,
|
||||
):
|
||||
"""
|
||||
Шаг внимания в режиме Grouped Query Attention —
|
||||
реализует эффективное многооконное внимание с раздельными Q/KV и sliding/casual mask.
|
||||
|
||||
Что происходит в этом методе:
|
||||
-----------------------------
|
||||
- Преобразует входной тензор x (токеновые эмбеддинги) в Q, K, V-матрицы с учётом разного числа голов для Q и KV.
|
||||
- Формирует attention "маску" для sliding window, если нужно ограничить историю.
|
||||
- Применяет RoPE (если задан) к Q и K, вносит позиционную информацию.
|
||||
- При работе с кэшем дополняет ключи и значения предыдущими (ускоряет генерацию).
|
||||
- Повторяет K/V головы для соответствия количеству Q (чтобы на каждую Q-head приходился свой KV).
|
||||
- Считает обычное scaled dot-product внимание, применяет маску (не даёт видеть будущее, как и в autoregressive).
|
||||
- Softmax, смешивание V на основе attention, объединение всех голов.
|
||||
- Dropout и финальное линейное преобразование обратно к emb_size.
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Входной тензор размера [batch, seq_len, emb_size]
|
||||
mask : torch.Tensor, по умолчанию None
|
||||
Матричная маска для внимания (можно передать внешнюю или использовать встроенную sliding window mask)
|
||||
use_cache : bool, по умолчанию True
|
||||
Нужно ли использовать/возвращать кэш KV для быстрых автогенераций.
|
||||
cache : list, опционально
|
||||
Ранее сохранённый кэш KV (используется для инференса по одному токену)
|
||||
|
||||
Возвращает:
|
||||
-----------
|
||||
- output: torch.Tensor формы [batch, seq_len, emb_size]
|
||||
- kv_cache: кэш новых KV (если use_cache=True), иначе None
|
||||
|
||||
Важно:
|
||||
-------
|
||||
- Реализует Mistral-style attention: к каждой Q-head в итоге “приписан” собственный (но потенциально дублированный) KV-head.
|
||||
- Sliding window ограничивает область вижимости в attention (ускоряет генерацию на длинных последовательностях).
|
||||
- Использование RoPE опционально — но необходимо для современных архитектур LLM.
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> attn = GroupedQueryAttention(num_q_heads=8, num_kv_heads=2, emb_size=256, head_size=32, max_seq_len=1024, window_size=256)
|
||||
>>> x = torch.randn(2, 128, 256)
|
||||
>>> y, kv_cache = attn(x)
|
||||
>>> print(y.shape) # torch.Size([2, 128, 256])
|
||||
"""
|
||||
batch_size, seq_len, emb_size = x.shape
|
||||
|
||||
if seq_len > self._max_seq_len:
|
||||
raise ValueError(
|
||||
f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}"
|
||||
)
|
||||
|
||||
# Пропустите тензор x через матрицы Wq, Wk , Wv, чтобы получить матрицы запроса, ключа и значения.
|
||||
k = self._k(x) # [B, T, hs]
|
||||
q = self._q(x) # [B, T, hs]
|
||||
v = self._v(x) # [B, T, hs]
|
||||
|
||||
# Шаг 2: Изменение формы для multi-head
|
||||
# [batch_size, seq_len, num_heads * head_size]
|
||||
# -> [batch_size, seq_len, num_heads, head_size]
|
||||
# Измените форму запроса (query) на batch_size × num_q_heads × seq_len × head_size.
|
||||
q = q.reshape(batch_size, seq_len, self._num_heads, self._head_size)
|
||||
|
||||
# Измените форму ключа (key) и значения (value) на batch_size × num_kv_heads × seq_len × head_size.
|
||||
k = k.reshape(batch_size, seq_len, self._num_kv_heads, self._head_size)
|
||||
v = v.reshape(batch_size, seq_len, self._num_kv_heads, self._head_size)
|
||||
|
||||
|
||||
# 3. Transpose: [B, T, H, hs] -> [B, H, T, hs]
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
start_pos = 0
|
||||
if cache is not None:
|
||||
k_cache, v_cache = cache
|
||||
cache_len = k_cache.shape[2]
|
||||
start_pos = cache_len
|
||||
|
||||
# Пропустите матрицы запроса и ключа через экземпляр rope, чтобы выполнить поворот.
|
||||
if self._rope is not None:
|
||||
# Применяем RoPE к Q и K (НЕ к V!)
|
||||
q = self._rope(q, start_pos=start_pos) # [B, T, hs]
|
||||
k = self._rope(k, start_pos=start_pos) # [B, T, hs]
|
||||
|
||||
# Если cache пришел, то объединяем кэш и одну строку из ключа и значения. Это будут новые key и value для последующих вычислений.
|
||||
# 5. Кэширование (для autoregressive generation)
|
||||
if cache is not None:
|
||||
k_cache, v_cache = cache
|
||||
k = torch.cat([k_cache, k], dim=2) # Concat по seq_len (dim=2)
|
||||
v = torch.cat([v_cache, v], dim=2)
|
||||
|
||||
# Если use_cache == True, то сохраните матрицы ключа и значения для кэша (это нужно сделать до дублирования голов).
|
||||
#if use_cache == True:
|
||||
# # Обрезаем до последних window_size токенов
|
||||
# k_to_cache = k[:, :, -self._window_size:, :]
|
||||
# v_to_cache = v[:, :, -self._window_size:, :]
|
||||
# kv_cache = (k_to_cache, v_to_cache)
|
||||
|
||||
# Продублируйте головы в тензорах ключа (key) и значения (value), чтобы получился тензор размера на batch_size × num_q_heads × seq_len × head_size.
|
||||
#k = self._repeat_kv_heads(k, self._num_heads, self._num_kv_heads)
|
||||
#v = self._repeat_kv_heads(v, self._num_heads, self._num_kv_heads)
|
||||
k_expanded = self._repeat_kv_heads(k, self._num_heads, self._num_kv_heads)
|
||||
v_expanded = self._repeat_kv_heads(v, self._num_heads, self._num_kv_heads)
|
||||
|
||||
# Перемножим матрицы запроса и ключа (транспонированную), чтобы вычислить матрицу внимания.
|
||||
# И разделить все значения в матрице внимания на корень из head_size.
|
||||
scores = q @ k_expanded.transpose(-2, -1) / (self._head_size ** 0.5)
|
||||
|
||||
# 8. Применение маски
|
||||
k_seq_len = k_expanded.size(2) # Длина K после concat с кэшем
|
||||
|
||||
if cache is None:
|
||||
# Случай 1: Без кэша - полная квадратная маска
|
||||
# scores: [B, H, seq_len, seq_len]
|
||||
# Применяем маску [:seq_len, :seq_len]
|
||||
scores = scores.masked_fill(
|
||||
~self._tril_mask[:seq_len, :seq_len],
|
||||
float("-inf")
|
||||
)
|
||||
|
||||
# Применить к матрице внимания (построчно) функцию Softmax.
|
||||
weights = F.softmax(scores, dim=-1)
|
||||
|
||||
# Перемножим матрицу внимания и матрицу значения.
|
||||
x_out = weights @ v_expanded # [B, T, hs]
|
||||
|
||||
# Измените форму тензора на batch_size × seq_len × num_heads*head_size.
|
||||
# Transpose обратно и concatenate heads
|
||||
x_out = x_out.transpose(1, 2) # [B, T_q, H, hs]
|
||||
x_out = x_out.contiguous() # Важно для reshape!
|
||||
concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size)
|
||||
|
||||
#concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size)
|
||||
|
||||
# Пропустите получившийся тензор через последний линейный слой.
|
||||
# 3. Проецируем в пространство эмбеддингов
|
||||
projected_output = self._layer(concatenated_attention)
|
||||
|
||||
# 4. Применяем dropout для регуляризации
|
||||
output = self._dropout(projected_output)
|
||||
|
||||
if use_cache:
|
||||
# Обрезаем оригинальный K и V (до дублирования)
|
||||
k_to_cache = k[:, :, -self._window_size:, :]
|
||||
v_to_cache = v[:, :, -self._window_size:, :]
|
||||
kv_cache = (k_to_cache, v_to_cache)
|
||||
return output, kv_cache
|
||||
else:
|
||||
return output, None
|
||||
|
||||
def _repeat_kv_heads(
|
||||
self,
|
||||
kv: torch.Tensor,
|
||||
num_q_heads: int,
|
||||
num_kv_heads: int
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Приводит число голов K/V к числу голов Q путём поэлементного повторения (tile) KV-голов.
|
||||
|
||||
Зачем это нужно?
|
||||
----------------
|
||||
В Grouped Query Attention (Mistral, Llama-2, GPT-4 и др.) обычно num_kv_heads < num_q_heads.
|
||||
Чтобы каждая Query-head могла смотреть на свою собственную (пусть и общую) KV, мы "нарезаем" или повторяем KV столько раз, сколько требуется — это экономит память и ускоряет генерацию.
|
||||
|
||||
Алгоритм:
|
||||
---------
|
||||
- kv имеет форму [batch_size, num_kv_heads, seq_len, head_size]
|
||||
- Для каждого KV-head делается n_repeat = num_q_heads // num_kv_heads по head-axis (обычно целое)
|
||||
- На выходе форма [batch_size, num_q_heads, seq_len, head_size], где каждый KV-head дублирован для нужного количества Q-heads.
|
||||
|
||||
Args:
|
||||
-----
|
||||
kv : torch.Tensor
|
||||
Входной тензор KV (обычно после linear layer on эмбеддинги), размер [batch_size, num_kv_heads, seq_len, head_size]
|
||||
num_q_heads : int
|
||||
Сколько должно быть Q-голов (их больше!)
|
||||
num_kv_heads : int
|
||||
Сколько KV-голов было (их меньше!)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
torch.Tensor формы [batch_size, num_q_heads, seq_len, head_size], где KV-головы повторены как требуется.
|
||||
|
||||
Пример:
|
||||
-------
|
||||
num_q_heads = 8, num_kv_heads = 2
|
||||
[KV0, KV1] -> [KV0, KV0, KV0, KV0, KV1, KV1, KV1, KV1]
|
||||
# Каждый KV-head дублируется 4 раза, чтобы покрыть все 8 Q-heads.
|
||||
"""
|
||||
batch_size, num_kv_heads, seq_len, head_size = kv.shape
|
||||
|
||||
if num_q_heads == num_kv_heads:
|
||||
# Нет необходимости дублировать
|
||||
return kv
|
||||
|
||||
# Вычисляем сколько раз нужно повторить каждую голову
|
||||
num_repeats = num_q_heads // num_kv_heads
|
||||
|
||||
# repeat_interleave дублирует каждую голову num_repeats раз
|
||||
# [B, num_kv_heads, S, hs] -> [B, num_q_heads, S, hs]
|
||||
# [B, num_kv_heads, S, hs] -> [B, num_kv_heads, 1, S, hs]
|
||||
kv = kv.unsqueeze(2)
|
||||
|
||||
# [B, num_kv_heads, 1, S, hs] -> [B, num_kv_heads, num_repeats, S, hs]
|
||||
kv = kv.repeat(1, 1, num_repeats, 1, 1)
|
||||
|
||||
# [B, num_kv_heads, num_repeats, S, hs] -> [B, num_q_heads, S, hs]
|
||||
kv = kv.reshape(batch_size, num_q_heads, seq_len, head_size)
|
||||
|
||||
|
||||
return kv
|
||||
|
||||
def _create_sliding_window_mask(
|
||||
self,
|
||||
max_seq_len: int,
|
||||
window_size: int,
|
||||
device: torch.device = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Создаёт маску для Sliding Window Attention (ограниченного окна внимания).
|
||||
|
||||
Зачем нужна эта маска?
|
||||
----------------------
|
||||
В современных LLM (например, Mistral) self-attention работает не по всей истории, а только в узком "скользящем окне":
|
||||
каждый токен видит только предшествующие (или соседние) токены на расстоянии window_size.
|
||||
Это ускоряет инференс на длинных текстах и экономит память, но сохраняет ключевые зависимости в пределах окна.
|
||||
|
||||
Как работает алгоритм:
|
||||
----------------------
|
||||
- Для каждого токена mask[i, j] == True только если токен j находится СЛЕВА и не дальше, чем window_size позиций (или сам i).
|
||||
- Главное: mask всегда "нижнетреугольная" (causal), плюс полоса шириной window_size вдоль главной диагонали.
|
||||
- Всё за пределами окна — False (attention нельзя).
|
||||
|
||||
Args:
|
||||
-----
|
||||
max_seq_len : int
|
||||
Максимальная длина последовательности (размер будущей attention-матрицы).
|
||||
window_size : int
|
||||
Сколько предыдущих токенов доступно для внимания у каждого шага (вкл. сам себя).
|
||||
device : torch.device, опционально
|
||||
На каком устройстве (cpu/gpu) создавать маску.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
torch.Tensor
|
||||
Маска внимания формы [max_seq_len, max_seq_len], где True — допускается внимание (иначе False).
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> mask = create_sliding_window_mask(8, 3)
|
||||
>>> print(mask.int())
|
||||
tensor([[1, 0, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0, 0, 0],
|
||||
[0, 1, 1, 1, 0, 0, 0, 0],
|
||||
[0, 0, 1, 1, 1, 0, 0, 0],
|
||||
[0, 0, 0, 1, 1, 1, 0, 0],
|
||||
[0, 0, 0, 0, 1, 1, 1, 0],
|
||||
[0, 0, 0, 0, 0, 1, 1, 1]])
|
||||
"""
|
||||
row_indices = torch.arange(max_seq_len, device=device).unsqueeze(1) # [max_seq_len, 1]
|
||||
col_indices = torch.arange(max_seq_len, device=device).unsqueeze(0) # [1, max_seq_len]
|
||||
|
||||
causal_mask = col_indices <= row_indices
|
||||
|
||||
window_mask = (row_indices - col_indices) <= window_size
|
||||
|
||||
mask = causal_mask & window_mask
|
||||
|
||||
return mask
|
||||
@@ -1,103 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from math import sqrt
|
||||
from .rope import RoPE
|
||||
|
||||
class HeadAttention(nn.Module):
|
||||
"""
|
||||
Одноголовый механизм внимания (scaled dot-product attention) — фундаментальный строительный блок всех современных Transformer.
|
||||
|
||||
Научная суть:
|
||||
- Attention учит модель самостоятельно "выбирать" важные связи между словами, независимо от их положения.
|
||||
- Механизм causal mask гарантирует невозможность "заглядывания в будущее" при генерации (авторегрессия).
|
||||
|
||||
Формула:
|
||||
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) · V
|
||||
(Q — запросы, K — ключи, V — значения; d_k — размерность ключа)
|
||||
|
||||
Поддерживает Rotary Position Encoding (RoPE) для относительного позиционного кодирования.
|
||||
|
||||
Args:
|
||||
emb_size (int): размер входного эмбеддинга
|
||||
head_size (int): размерность attention-головы
|
||||
max_seq_len (int): максимальная длина последовательности
|
||||
rope (RoPE, optional): экземпляр RoPE для позиций
|
||||
|
||||
Примечания:
|
||||
- Использует нижнетреугольную маску для предотвращения "заглядывания в будущее"
|
||||
- Автоматически адаптируется к разным версиям PyTorch
|
||||
- Поддерживает batch-обработку входных данных
|
||||
|
||||
Пример использования:
|
||||
>>> attention = HeadAttention(emb_size=64, head_size=32, max_seq_len=128)
|
||||
>>> x = torch.randn(1, 10, 64)
|
||||
>>> output, _ = attention(x)
|
||||
>>> print(output.shape) # torch.Size([1, 10, 32])
|
||||
"""
|
||||
def __init__(self, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE = None):
|
||||
super().__init__()
|
||||
self._emb_size = emb_size
|
||||
self._head_size = head_size
|
||||
self._max_seq_len = max_seq_len
|
||||
self._rope = rope
|
||||
|
||||
# Линейные преобразования для Q, K, V
|
||||
self._k = nn.Linear(emb_size, head_size)
|
||||
self._q = nn.Linear(emb_size, head_size)
|
||||
self._v = nn.Linear(emb_size, head_size)
|
||||
|
||||
# Создание causal маски
|
||||
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
|
||||
self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())
|
||||
|
||||
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: tuple = None) -> tuple:
|
||||
"""
|
||||
Прямой проход через слой внимания.
|
||||
|
||||
Аргументы:
|
||||
x (torch.Tensor): Входной тензор формы [batch_size, seq_len, emb_size]
|
||||
|
||||
Возвращает:
|
||||
torch.Tensor: Выходной тензор формы [batch_size, seq_len, head_size]
|
||||
|
||||
Исключения:
|
||||
ValueError: Если длина последовательности превышает max_seq_len
|
||||
|
||||
Пример внутренних преобразований:
|
||||
Для входа x.shape = [2, 5, 64]:
|
||||
1. Q/K/V преобразования -> [2, 5, 32]
|
||||
2. Scores = Q·K^T -> [2, 5, 5]
|
||||
3. После маски и softmax -> [2, 5, 5]
|
||||
4. Умножение на V -> [2, 5, 32]
|
||||
"""
|
||||
seq_len = x.shape[1]
|
||||
if seq_len > self._max_seq_len:
|
||||
raise ValueError(f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}")
|
||||
|
||||
k = self._k(x) # [B, T, hs]
|
||||
q = self._q(x) # [B, T, hs]
|
||||
v = self._v(x) # [B, T, hs]
|
||||
|
||||
if self._rope is not None:
|
||||
# ✅ Применяем RoPE к Q и K (НЕ к V!)
|
||||
q = self._rope(q) # [B, T, hs]
|
||||
k = self._rope(k) # [B, T, hs]
|
||||
|
||||
if cache is not None:
|
||||
k_cache, v_cache = cache
|
||||
k = torch.cat([k_cache, k], dim=1) # [B, cache_len + T, hs]
|
||||
v = torch.cat([v_cache, v], dim=1) # [B, cache_len + T, hs]
|
||||
|
||||
scores = q @ k.transpose(-2, -1) / sqrt(self._head_size)
|
||||
|
||||
if cache is None:
|
||||
scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))
|
||||
|
||||
weights = F.softmax(scores, dim=-1)
|
||||
x_out = weights @ v # [B, T, hs]
|
||||
|
||||
if use_cache is True:
|
||||
return (x_out, (k, v))
|
||||
else:
|
||||
return (x_out, None)
|
||||
134
llm/src/llm/core/mistral_decoder.py
Normal file
134
llm/src/llm/core/mistral_decoder.py
Normal file
@@ -0,0 +1,134 @@
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from llm.core.rms_norm import RMSNorm
|
||||
from llm.core.swi_glu import SwiGLU
|
||||
from llm.core.rope import RoPE
|
||||
from llm.core.group_query_attention import GroupedQueryAttention
|
||||
|
||||
class MistralDecoder(nn.Module):
|
||||
"""
|
||||
MistralDecoder — стек декодирующих блоков, реализующий архитектуру Mistral-style Transformer.
|
||||
|
||||
Назначение:
|
||||
-----------
|
||||
Этот класс описывает один или несколько блоков декодера, включающих Grouped Query Attention (GQA),
|
||||
sliding window attention и SwiGLU feed-forward, как реализовано в моделях Mistral и Llama 2.
|
||||
|
||||
Ключевые особенности архитектуры:
|
||||
---------------------------------
|
||||
- Использует GQA: для каждого токена вычисляется attention c раздельным числом Q и KV голов (сильно ускоряет LLM).
|
||||
- Sliding Window Attention: внимание ограничено окном из window_size элементов (ускоряет обработку длинных текстов).
|
||||
- Rotary Positional Embedding (RoPE): позиционная информация интегрируется вращением Q/K.
|
||||
- RMSNorm перед и после внимания и FFN (устойчивое обучение).
|
||||
- SwiGLU в качестве нелинейности вместо стандартного GELU (больше capacity в модели).
|
||||
|
||||
Аргументы конструктора:
|
||||
-----------------------
|
||||
num_layers : int — сколько блоков-декодеров в стеке
|
||||
параметры GQA: num_q_heads, num_kv_heads, emb_size, head_size, max_seq_len, window_size, rope, dropout
|
||||
- все они идут в каждый слой (блок) декодера
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> decoder = MistralDecoder(
|
||||
... num_q_heads=8, num_kv_heads=2, emb_size=256, head_size=32,
|
||||
... max_seq_len=4096, window_size=256, rope=rope, dropout=0.1)
|
||||
>>> x = torch.randn(2, 512, 256)
|
||||
>>> out, cache = decoder(x)
|
||||
>>> print(out.shape) # torch.Size([2, 512, 256])
|
||||
|
||||
Подробнее:
|
||||
----------
|
||||
- Mistral: https://arxiv.org/abs/2310.06825
|
||||
- Llama 2: https://arxiv.org/abs/2307.09288
|
||||
- Open LLM обзор: https://huggingface.co/blog/mistral
|
||||
|
||||
"""
|
||||
def __init__(self,
|
||||
num_q_heads: int,
|
||||
num_kv_heads: int,
|
||||
emb_size: int,
|
||||
head_size: int,
|
||||
max_seq_len: int,
|
||||
window_size: int,
|
||||
rope: RoPE,
|
||||
dropout: float = 0.1
|
||||
):
|
||||
"""
|
||||
Инициализация стека декодеров MistralDecoder.
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
num_layers : int
|
||||
Сколько слоёв (декодеров/GQA-блоков) собрать в стек.
|
||||
num_q_heads : int
|
||||
Количество Query-heads в attention (их больше, экономит память).
|
||||
num_kv_heads : int
|
||||
Количество Key/Value-heads в attention (их меньше для быстрой генерации).
|
||||
emb_size : int
|
||||
Размерность embedding (должна делиться на num_q_heads без остатка).
|
||||
head_size : int
|
||||
Размер одного attention head.
|
||||
max_seq_len : int
|
||||
Максимально обрабатываемая длина последовательности.
|
||||
window_size : int
|
||||
Размер окна для sliding window attention.
|
||||
rope : RoPE
|
||||
Rotary Positional Embedding для Q/K.
|
||||
dropout : float, опционально
|
||||
Dropout на каждом attention/FFN (по умолчанию 0.1).
|
||||
|
||||
Внутри:
|
||||
-------
|
||||
- Собираются num_layers Sequential-блоков из GQA + SwiGLU + RMSNorm.
|
||||
- Все параметры передаются в каждый слой (блок).
|
||||
"""
|
||||
super().__init__()
|
||||
self._heads = GroupedQueryAttention(
|
||||
num_q_heads=num_q_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
emb_size=emb_size,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len,
|
||||
window_size=window_size,
|
||||
rope=rope,
|
||||
dropout=dropout
|
||||
)
|
||||
self._ff = SwiGLU(emb_size=emb_size, dropout=dropout)
|
||||
self._norm1 = RMSNorm(emb_size)
|
||||
self._norm2 = RMSNorm(emb_size)
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
|
||||
"""
|
||||
Прямой проход через стек MistralDecoder.
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Входные эмбеддинги (обычно shape [batch, seq_len, emb_size]).
|
||||
use_cache : bool, по умолчанию True
|
||||
Включить ли кэширование для ускорения генерации (авторегрессия).
|
||||
cache : list, опционально
|
||||
Предыдущий кеш attention-блоков (или None).
|
||||
|
||||
Возвращает:
|
||||
-----------
|
||||
out : torch.Tensor
|
||||
Тензор после декодирования (shape соответствует x).
|
||||
new_cache : list (или None)
|
||||
Новый кэш attention для дальнейшей генерации (или None, если use_cache=False).
|
||||
|
||||
"""
|
||||
norm1_out = self._norm1(x)
|
||||
attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
|
||||
out = attention + x
|
||||
|
||||
norm2_out = self._norm2(out)
|
||||
ffn_out = self._ff(norm2_out)
|
||||
|
||||
if use_cache is True:
|
||||
return (ffn_out + out, kv_caches)
|
||||
else:
|
||||
return (ffn_out + out, None)
|
||||
211
llm/src/llm/core/mixtral_decoder.py
Normal file
211
llm/src/llm/core/mixtral_decoder.py
Normal file
@@ -0,0 +1,211 @@
|
||||
|
||||
|
||||
from torch import nn
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from llm.core.rope import RoPE
|
||||
from llm.core.group_query_attention import GroupedQueryAttention
|
||||
from llm.core.moe import MoE
|
||||
from llm.core.rms_norm import RMSNorm
|
||||
|
||||
class MixtralDecoder(nn.Module):
|
||||
"""
|
||||
MixtralDecoder — декодерный блок для Mixtral/MoE-трансформеров (см. Mixtral 8x7B, Mistral v0.2 и др.).
|
||||
|
||||
Назначение:
|
||||
-----------
|
||||
MixtralDecoder реализует один модульный слой глубокой трансформерной архитектуры с Mixture-of-Experts (MoE) Feed-Forward Network и Grouped Query Attention (GQA).
|
||||
Поддерживает разреженную активацию и масштабируемое количество экспертов, оптимально для больших LLM.
|
||||
|
||||
Архитектура блока:
|
||||
------------------
|
||||
- RMSNorm -> Grouped Query Attention (GQA)
|
||||
- skip-connection
|
||||
- RMSNorm -> MoE (SwiGLU-эксперты)
|
||||
- skip-connection
|
||||
|
||||
Для входа `x` проходит:
|
||||
1. norm1_out = RMSNorm(x)
|
||||
2. attention, kv_caches = GQA(norm1_out, ...)
|
||||
3. out = attention + x # residual connection
|
||||
4. norm2_out = RMSNorm(out)
|
||||
5. ffn_out = MoE(norm2_out)
|
||||
6. return (ffn_out + out, kv_caches)
|
||||
|
||||
Теоретическая мотивация:
|
||||
------------------------
|
||||
- Использование MoE (см. https://arxiv.org/abs/1701.06538) позволяет кратно увеличивать capacity без роста затрат на ff-часть.
|
||||
- Grouped Query Attention эффективно масштабирует self-attention для больших моделей (см. Mistral, Llama 2/3).
|
||||
- RMSNorm (Root Mean Square LayerNorm) стабилизирует градиенты и память.
|
||||
- Является строительным блоком для стека декодеров в Mixtral-моделях (см. Mixtral, Mistral, LLaMA).
|
||||
|
||||
Аргументы конструктора:
|
||||
----------------------
|
||||
num_q_heads : int
|
||||
Число query-голов в attention.
|
||||
num_kv_heads : int
|
||||
Число key-value голов (группировка ключей/values).
|
||||
emb_size : int
|
||||
Скрытый размер эмбеддинга.
|
||||
head_size : int
|
||||
Размерность одной головы (emb_size // num_q_heads).
|
||||
max_seq_len : int
|
||||
Максимальная поддерживаемая длина последовательности.
|
||||
num_experts : int
|
||||
Количество «экспертов» (MoE).
|
||||
top_k_experts : int
|
||||
Сколько одновременно экспертов активируется для одного токена.
|
||||
window_size : int
|
||||
Размер окна внимания (используется для efficient attention).
|
||||
rope : RoPE
|
||||
Реализация позиционного кодирования RoPE.
|
||||
dropout : float
|
||||
Вероятность Dropout для регуляризации.
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> decoder = MixtralDecoder(... параметры ...)
|
||||
>>> x = torch.randn(batch, seq, emb_size)
|
||||
>>> out, cache = decoder(x, mask=None, use_cache=True)
|
||||
>>> out.shape
|
||||
|
||||
Литература и ссылки:
|
||||
--------------------
|
||||
- Mixtral 8x7B: https://mistral.ai/news/mixtral-of-experts/
|
||||
- Shazeer et al., “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer”, 2017. https://arxiv.org/abs/1701.06538
|
||||
- Mistral paper: https://arxiv.org/abs/2310.06825
|
||||
- GQA: https://arxiv.org/abs/2305.14236
|
||||
- RMSNorm: https://arxiv.org/abs/1910.07467
|
||||
|
||||
"""
|
||||
def __init__(self,
|
||||
num_q_heads: int,
|
||||
num_kv_heads: int,
|
||||
emb_size: int,
|
||||
head_size: int,
|
||||
max_seq_len: int,
|
||||
num_experts: int,
|
||||
top_k_experts: int,
|
||||
window_size: int,
|
||||
rope: RoPE,
|
||||
dropout: float = 0.1
|
||||
):
|
||||
"""
|
||||
Конструктор декодерного блока MixtralDecoder.
|
||||
|
||||
Осуществляет инициализацию всех под-компонентов слоя: Attention (Grouped Query Attention), MoE (Mixture-of-Experts, SwiGLU)
|
||||
и нормализации (RMSNorm). Позволяет гибко настраивать архитектуру под специфику задач и размеры LLM.
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
num_q_heads : int
|
||||
Количество голов внимания (queries) в механизме GroupedQueryAttention.
|
||||
Чем больше — тем тоньше дискретизация внимания по подпространствам признаков.
|
||||
num_kv_heads : int
|
||||
Количество групп ключей/значений (key-value heads) для GQA.
|
||||
Позволяет балансировать производительность и память.
|
||||
emb_size : int
|
||||
Размерность эмбеддингового пространства внутри слоя (hidden).
|
||||
head_size : int
|
||||
Размерность одной attention-головы. Обычно emb_size // num_q_heads.
|
||||
max_seq_len : int
|
||||
Максимально поддерживаемая длина токенизированной последовательности.
|
||||
num_experts : int
|
||||
Количество экспертов в слое MoE (размер пула SwiGLU-экспертов).
|
||||
top_k_experts : int
|
||||
Сколько экспертов по роутингу активируется на 1 токен (разреженность — эффективная экономия вычислений).
|
||||
window_size : int
|
||||
Размер окна для attention (может использоваться для ограничения receptive field, как в Mistral).
|
||||
rope : RoPE
|
||||
Объект позиционного кодирования RoPE (Rotary Positional Embedding), необходим для архитектуры внимания.
|
||||
dropout : float, по умолчанию 0.1
|
||||
Вероятность зануляции выходных значений для регуляризации и борьбы с переобучением.
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> decoder = MixtralDecoder(
|
||||
... num_q_heads=8,
|
||||
... num_kv_heads=2,
|
||||
... emb_size=256,
|
||||
... head_size=32,
|
||||
... max_seq_len=1024,
|
||||
... num_experts=4,
|
||||
... top_k_experts=2,
|
||||
... window_size=128,
|
||||
... rope=rope_module,
|
||||
... dropout=0.05
|
||||
... )
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self._heads = GroupedQueryAttention(
|
||||
num_q_heads=num_q_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
emb_size=emb_size,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len,
|
||||
window_size=window_size,
|
||||
rope=rope,
|
||||
dropout=dropout
|
||||
)
|
||||
self._ff = MoE(
|
||||
emb_size=emb_size,
|
||||
num_experts=num_experts,
|
||||
top_k_experts=top_k_experts,
|
||||
dropout=dropout
|
||||
)
|
||||
self._norm1 = RMSNorm(emb_size)
|
||||
self._norm2 = RMSNorm(emb_size)
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
|
||||
"""
|
||||
Прямой проход (forward) через декодерный блок MixtralDecoder.
|
||||
|
||||
Данный метод реализует последовательную обработку входных скрытых состояний (x) через:
|
||||
- нормализацию (RMSNorm),
|
||||
- attention-модуль (Grouped Query Attention) с опциональным применением маски и кэша ключей/значений для ускорения инференса,
|
||||
- остаточное сложение (residual connection),
|
||||
- повторную нормализацию,
|
||||
- feed-forward блок на основе Mixture-of-Experts (MoE),
|
||||
- финальное остаточное сложение.
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Входной скрытый тензор формы [batch_size, seq_len, emb_size] — результат эмбеддинга токенов либо предыдущего слоя.
|
||||
mask : torch.Tensor, optional
|
||||
(Необязательно) Маска внимания для ограничения области self-attention (например, для автоперемешивания или causal-LLM-моделей).
|
||||
use_cache : bool, по умолчанию True
|
||||
Если True — сохраняет кэш ключей/значений attention для ускорения авторегрессии (инференса).
|
||||
cache : list, optional
|
||||
(Необязательно) Предварительно вычисленный кеш attention (для ускорения генерации длинного текста).
|
||||
|
||||
Возвращает:
|
||||
-----------
|
||||
Tuple[torch.Tensor, Any]:
|
||||
- Первый элемент: скрытый тензор выхода слоя с той же формой, что вход (последовательный residual из attention и MoE-блока).
|
||||
- Второй элемент: обновлённый кэш attention (если use_cache=True), иначе None.
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> out, cache = decoder(x, mask=att_mask, use_cache=True, cache=old_cache)
|
||||
>>> out.shape # [batch_size, seq_len, emb_size]
|
||||
|
||||
Примечания:
|
||||
-----------
|
||||
- Для autoregressive-генерации (GPT-like режимов) следует передавать mask и использовать use_cache=True.
|
||||
- Реализация поддерживает произвольные батчи и длины последовательностей, в пределах max_seq_len слоя.
|
||||
- Модуль MixtralDecoder обычно используется в виде стека (несколько подряд) внутри крупной LLM.
|
||||
|
||||
"""
|
||||
norm1_out = self._norm1(x)
|
||||
attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
|
||||
out = attention + x
|
||||
|
||||
norm2_out = self._norm2(out)
|
||||
ffn_out = self._ff(norm2_out)
|
||||
|
||||
if use_cache is True:
|
||||
return (ffn_out + out, kv_caches)
|
||||
else:
|
||||
return (ffn_out + out, None)
|
||||
229
llm/src/llm/core/moe.py
Normal file
229
llm/src/llm/core/moe.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from llm.core.swi_glu import SwiGLU
|
||||
|
||||
class MoE(nn.Module):
|
||||
"""
|
||||
MoE (Mixture of Experts) — слой «смеси экспертов» для современных трансформерных архитектур с разреженной активацией.
|
||||
|
||||
Назначение:
|
||||
-----------
|
||||
Класс реализует слой разреженного условного вычисления для увеличения capacity трансформеров без роста вычислительных затрат.
|
||||
Для каждого токена из последовательности выбирается (с помощью роутера) наиболее подходящее подмножество экспертов (малых нейросетей).
|
||||
Итоговый выход формируется как взвешенная сумма откликов экспертов, выбранных для данного токена.
|
||||
|
||||
Архитектурная схема:
|
||||
---------------------
|
||||
- Для каждого входного токена `x` роутер (обычно один Linear-слой) предсказывает skor, насколько каждый из `num_experts` релевантен.
|
||||
- Для каждого токена выбираются top_k_experts с максимальными skor; только они обрабатывают этот токен.
|
||||
- Каждый эксперт здесь представлен отдельным экземпляром блока `SwiGLU` (может быть любая небольшая feed-forward сеть).
|
||||
- Выход каждого эксперта умножается на индивидуальный вес (softmax по skor) — агрегируется взвешенная сумма.
|
||||
- Dropout применяется к итоговому выходу.
|
||||
|
||||
Математика (коротко):
|
||||
---------------------
|
||||
Пусть X ∈ R^{BxSxD} — вход,
|
||||
E — число экспертов,
|
||||
K — число активируемых экспертов на токен.
|
||||
r(x) = softmax(W_r x) — роутинг-логиты, top-K берём индексы и веса.
|
||||
Для каждого токена:
|
||||
y_j = Expert_j(x)
|
||||
y = sum_j(w_j * y_j), где j пробегает по выбранным экспертам
|
||||
Output: Y ∈ R^{BxSxD}
|
||||
|
||||
Аргументы конструктора:
|
||||
----------------------
|
||||
emb_size : int
|
||||
Размерность входных/выходных векторов (обычно совпадает с embedding модели).
|
||||
num_experts : int
|
||||
Общее число экспертов внутри слоя MoE.
|
||||
top_k_experts : int
|
||||
Сколько экспертов активировать и агрегировать на каждом токене (обычно 2-8).
|
||||
dropout : float, по умолчанию 0.1
|
||||
Dropout к выходу агрегатора.
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> moe = MoE(emb_size=512, num_experts=8, top_k_experts=2, dropout=0.1)
|
||||
>>> x = torch.randn(4, 16, 512)
|
||||
>>> y = moe(x)
|
||||
>>> y.shape # torch.Size([4, 16, 512])
|
||||
|
||||
Литература:
|
||||
-----------
|
||||
- Shazeer, N. et al. “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer”, 2017. https://arxiv.org/abs/1701.06538
|
||||
- Fedus, W., Zoph, B., & Shazeer, N. “Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity”, 2021. https://arxiv.org/abs/2101.03961
|
||||
- Mistral/Mixtral: https://mistral.ai/news/mixtral-of-experts/
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
emb_size: int,
|
||||
num_experts: int,
|
||||
top_k_experts: int,
|
||||
dropout: float = 0.1,
|
||||
):
|
||||
"""
|
||||
Конструктор слоя MoE (Mixture of Experts).
|
||||
|
||||
Позволяет создать слой, состоящий из набора экспертов (например, отдельных небольших feedforward-нейросетей) и роутера,
|
||||
который будет для каждого токена определять наиболее релевантных экспертов.
|
||||
Часть экспертов (top_k_experts) активируется для каждого токена, остальные — пропускаются.
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
emb_size : int
|
||||
Размерность входных и выходных векторов (embedding size).
|
||||
Определяет, над каким пространством признаков будет работать роутер и эксперты.
|
||||
Например, если скрытый размер слоя трансформера 512, сюда нужно передать 512.
|
||||
|
||||
num_experts : int
|
||||
Общее количество экспертов в слое MoE.
|
||||
Чем больше экспертов — тем больше capacity у модели, но тем выше требования к RAM/VRAM при обучении.
|
||||
Пример: 8, 16, 32, 64.
|
||||
|
||||
top_k_experts : int
|
||||
Сколько экспертов одновременно будет обрабатывать каждый токен.
|
||||
Обычно 2–8. Меньшее значение — выше разреженность, больше экономия вычислений.
|
||||
|
||||
dropout : float, по умолчанию 0.1
|
||||
Вероятность зануления значений на выходе после агрегации откликов экспертов.
|
||||
Используется для регуляризации (борьбы с переобучением).
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> moe = MoE(emb_size=256, num_experts=8, top_k_experts=2, dropout=0.1)
|
||||
>>> print(moe)
|
||||
MoE( ... )
|
||||
|
||||
Теория:
|
||||
-------
|
||||
Слой строит:
|
||||
- Линейный роутер (Linear(emb_size, num_experts)): выдает «важность» каждого эксперта для токена.
|
||||
- Список из num_experts экспертов (в данной реализации — SwiGLU-блоки).
|
||||
|
||||
При каждом проходе для каждого токена выбираются top_k_experts наиболее релевантных экспертов,
|
||||
их ответы агрегируются взвешенной суммой (softmax по роутерным логитам).
|
||||
"""
|
||||
super().__init__()
|
||||
if top_k_experts > num_experts:
|
||||
raise ValueError(f"top_k_experts ({top_k_experts}) должен быть меньше или равен num_experts ({num_experts})!")
|
||||
self._num_experts = num_experts
|
||||
self._top_k_experts = top_k_experts
|
||||
|
||||
self._router = nn.Linear(emb_size, num_experts)
|
||||
self._experts = nn.ModuleList([SwiGLU(
|
||||
emb_size=emb_size,
|
||||
dropout=dropout,
|
||||
) for _ in range(num_experts)])
|
||||
self._dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
Прямой проход (forward) через слой MoE.
|
||||
|
||||
Для входной последовательности скрытых состояний (обычно из предыдущего слоя трансформера)
|
||||
данный метод динамически выбирает для каждого токена топ-k наиболее релевантных экспертов с помощью роутера,
|
||||
пропускает соответствующие токены через выбранных экспертов и агрегирует их результаты.
|
||||
|
||||
Математически:
|
||||
--------------
|
||||
1. Для каждого токена вычисляются логиты маршрутизатора (роутера):
|
||||
router_logits = Linear(x) ∈ ℝ^{batch, seq, num_experts}
|
||||
2. Выбираются top_k экспертов (topk_indices) и соответствующие им softmax-веса (topk_weights).
|
||||
3. Каждый эксперт обрабатывает только свой поднабор токенов.
|
||||
4. Результат агрегируется — отклик эксперта умножается на вес, ответы суммируются для каждого токена.
|
||||
5. На результат применяется dropout для регуляризации.
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Трёхмерный входной тензор формы [batch_size, seq_length, emb_size],
|
||||
где batch_size — размер батча, seq_length — длина последовательности, emb_size — размерность эмбеддинга.
|
||||
|
||||
Возвращает:
|
||||
-----------
|
||||
torch.Tensor :
|
||||
Тензор той же формы [batch_size, seq_length, emb_size] — результат комбинирования выходов выбранных экспертов
|
||||
с учетом softmax-весов маршрутизатора и dropout'а.
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> y = moe(x)
|
||||
>>> print(y.shape)
|
||||
torch.Size([batch_size, seq_length, emb_size])
|
||||
|
||||
Примечание:
|
||||
-----------
|
||||
- Каждый токен чаще всего активирует только подмножество экспертов.
|
||||
- Остальные эксперты вычислительно “спят”, что позволяет строить очень большие (по параметрам) модели с малым ростом затрат.
|
||||
- Работа с распределением топ-к экспертов и агрегирование с весами реализовано автоматически.
|
||||
|
||||
"""
|
||||
batch_size, seq_len, emb_size = x.shape
|
||||
|
||||
# 1. Пропускаем через роутер
|
||||
router_logits = self._router(x) # [batch_size, seq_len, num_experts]
|
||||
|
||||
# 2. Отбираем топ-k экспертов для каждого токена
|
||||
topk_logits, topk_indices = torch.topk(
|
||||
router_logits,
|
||||
k=self._top_k_experts,
|
||||
dim=-1
|
||||
) # topk_logits: [batch_size, seq_len, top_k]
|
||||
# topk_indices: [batch_size, seq_len, top_k]
|
||||
|
||||
# 3. Получаем веса через softmax и нормируем
|
||||
topk_weights = F.softmax(topk_logits, dim=-1) # [batch_size, seq_len, top_k]
|
||||
|
||||
# 4. Создаём нулевой тензор для результата
|
||||
output = torch.zeros_like(x) # [batch_size, seq_len, emb_size]
|
||||
|
||||
# 5. Проходим по всем экспертам
|
||||
for expert_id in range(self._num_experts):
|
||||
# Шаг 1: Создаём маску - где находится текущий эксперт в топ-k
|
||||
expert_mask = (topk_indices == expert_id) # [batch_size, seq_len, top_k]
|
||||
# Шаг 2: Проверяем, выбран ли эксперт хотя бы одним токеном
|
||||
if not expert_mask.any():
|
||||
continue # Эксперт никем не выбран, переходим к следующему
|
||||
|
||||
# Шаг 3: Находим токены, которые выбрали этого эксперта
|
||||
# (хотя бы в одной из top_k позиций)
|
||||
token_mask = expert_mask.any(dim=-1) # [batch_size, seq_len]
|
||||
|
||||
# Шаг 4: Отбираем токены из x
|
||||
# Отбираем токены для этого эксперта
|
||||
expert_input = x[token_mask]
|
||||
|
||||
# Пропускаем через эксперта
|
||||
# Добавляем batch dimension для SwiGLU и затем убираем
|
||||
expert_output = self._experts[expert_id](
|
||||
expert_input.unsqueeze(0)
|
||||
).squeeze(0)
|
||||
|
||||
# Получаем веса для этого эксперта
|
||||
# Для каждого токена может быть несколько весов (если эксперт в топ-k несколько раз)
|
||||
# Но на практике каждый эксперт появляется максимум 1 раз в топ-k
|
||||
# Находим веса: где expert_mask == True, берём соответствующий вес
|
||||
weights_for_expert = torch.zeros(
|
||||
batch_size, seq_len, device=x.device
|
||||
)
|
||||
|
||||
# Для каждой позиции в топ-k
|
||||
for k in range(self._top_k_experts):
|
||||
mask_k = topk_indices[:, :, k] == expert_id
|
||||
weights_for_expert[mask_k] = topk_weights[:, :, k][mask_k]
|
||||
|
||||
# Отбираем только веса для выбранных токенов
|
||||
selected_weights = weights_for_expert[token_mask] # [num_selected_tokens]
|
||||
|
||||
|
||||
# Перемножьте выход эксперта на веса текущего эксперта.
|
||||
weighted_output = selected_weights.unsqueeze(-1) * expert_output
|
||||
|
||||
# Помещаем результат на своё место в выходном тензоре
|
||||
output[token_mask] += weighted_output
|
||||
|
||||
out = self._dropout(output)
|
||||
|
||||
return out
|
||||
@@ -1,129 +1,263 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
from .head_attention import HeadAttention
|
||||
import torch.nn.functional as F
|
||||
from .rope import RoPE
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
"""
|
||||
Мультиголовый (многоголовый) механизм внимания — ключевой компонент любого Transformer.
|
||||
Multi-Head Attention (Многоголовое внимание)
|
||||
============================================
|
||||
|
||||
Научная суть:
|
||||
- Модель параллельно агрегирует информацию через несколько подпространств (головы),
|
||||
чтобы видеть разные связи в последовательности (разный контекст, локально/глобально).
|
||||
- Каждый attention блок работает независимо, выход конкатенируется.
|
||||
- Механизм предложен в статье "Attention is All You Need" (Vaswani et al., 2017).
|
||||
|
||||
Формула внимания для одной головы:
|
||||
Attention(Q, K, V) = softmax(QK^T/sqrt(d_k))·V
|
||||
Мультиголовый:
|
||||
MultiHead(Q, K, V) = Concat([head_i])*W^O
|
||||
Что такое Multi-Head Attention?
|
||||
-------------------------------
|
||||
Это ключевой компонент трансформеров, который позволяет "смотреть" на разные части предложения
|
||||
одновременно с нескольких независимых ракурсов ("голов"). Всё, что делает Single-Head Attention — только гораздо мощнее и глубже!
|
||||
|
||||
Args:
|
||||
num_heads (int): количество attention "голов"
|
||||
emb_size (int): размерности входа и выхода
|
||||
head_size (int): размер одной attention-головы (emb_size/num_heads)
|
||||
max_seq_len (int): максимальная длина последовательности
|
||||
rope (RoPE, optional): если задан, используется Rotary Positional Encoding
|
||||
dropout (float): вероятность регуляризации
|
||||
Зачем это нужно?
|
||||
----------------
|
||||
- Модель может учиться одновременно учитывать и локальные, и глобальные взаимосвязи между токенами.
|
||||
- Каждая attention head "ловит" свой собственный смысл/зависимости, и на выходе они объединяются.
|
||||
- Это значительно улучшает понимание сложных зависимостей в тексте, особенно на длинных последовательностях.
|
||||
|
||||
Как работает алгоритм? (основная схема)
|
||||
---------------------------------------
|
||||
1. Генерируются Q, K, V (query, key, value) — по отдельной проекции для каждой головы.
|
||||
2. Для каждой головы: attention(Q, K, V) = softmax(Q·K^T / sqrt(d)) · V
|
||||
3. Все головы "склеиваются" (concatenate) и прогоняются через общий финальный линейный слой.
|
||||
|
||||
Почему это работает?
|
||||
--------------------
|
||||
- Даёт трансформеру многомерное восприятие текста.
|
||||
- Позволяет эффективно обучаться на задачах, где порядок и "дальние" связи важнее, чем простое соседство.
|
||||
|
||||
Что принимается на вход:
|
||||
------------------------
|
||||
- x: shape [batch, seq_len, embed_dim] — обычный batched-embed тензор.
|
||||
- mask (опционально): shape [seq_len, seq_len] — маска для автогерерации или causal attention.
|
||||
|
||||
Какие параметры важны:
|
||||
----------------------
|
||||
- num_heads: сколько attention heads внутри (обычно 4, 8, 16...).
|
||||
- embed_dim: исходная размерность входного тензора.
|
||||
- head_size: размер одной attention-head (обычно embed_dim // num_heads).
|
||||
- max_seq_len: максимальная длина последовательности для маски.
|
||||
|
||||
Что возвращает:
|
||||
---------------
|
||||
- output: shape [batch, seq_len, embed_dim] — результат применения всех attention heads.
|
||||
- (опционально) cache: кэш для Q/K/V (нужно для генерации по одному токену).
|
||||
|
||||
Особенности реализации:
|
||||
-----------------------
|
||||
- Оптимизированно работает через матричные умножения (без python for циклов!).
|
||||
- Включена поддержка causal attention (маска, предотвращающая «заглядывание в будущее»).
|
||||
- Является ядром любого трансформера (и LLM!).
|
||||
|
||||
Пример использования:
|
||||
>>> mha = MultiHeadAttention(num_heads=8, emb_size=512, head_size=64, max_seq_len=1024)
|
||||
>>> x = torch.randn(2, 50, 512)
|
||||
>>> out, cache = mha(x)
|
||||
>>> print(out.shape)
|
||||
---------------------
|
||||
>>> attn = MultiHeadAttention(num_heads=8, embed_dim=256, head_size=32, max_seq_len=1024)
|
||||
>>> x = torch.randn(2, 128, 256) # [batch, seq_len, embed_dim]
|
||||
>>> context, _ = attn(x)
|
||||
>>> print(context.shape) # torch.Size([2, 128, 256])
|
||||
|
||||
Где прочитать подробнее:
|
||||
-------------------------
|
||||
- Attention is All You Need (Vaswani et al, 2017): https://arxiv.org/abs/1706.03762
|
||||
- Illustrated Transformer (blog): https://jalammar.github.io/illustrated-transformer/
|
||||
"""
|
||||
def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE = None, dropout: float = 0.1):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
emb_size: int,
|
||||
head_size: int,
|
||||
max_seq_len: int,
|
||||
rope: RoPE = None,
|
||||
dropout: float = 0.1,
|
||||
):
|
||||
"""
|
||||
Инициализация многоголового внимания.
|
||||
Конструктор многоголового внимания (MultiHeadAttention).
|
||||
|
||||
Параметры:
|
||||
num_heads (int): Количество голов внимания. Типичные значения: 4-16
|
||||
emb_size (int): Размерность входных и выходных эмбеддингов
|
||||
head_size (int): Размерность каждой головы внимания (обычно emb_size // num_heads)
|
||||
max_seq_len (int): Максимальная длина последовательности
|
||||
dropout (float): Вероятность dropout (по умолчанию 0.1)
|
||||
Здесь создаются все параметры и внутренние слои для эффективного параллельного внимания (attention) сразу из нескольких "голов".
|
||||
|
||||
Контрольные значения:
|
||||
- num_heads * head_size должно равняться emb_size
|
||||
- head_size обычно выбирают 32-128
|
||||
- max_seq_len зависит от задачи (512 для BERT, 2048 для GPT-3)
|
||||
Аргументы:
|
||||
----------
|
||||
num_heads : int
|
||||
Сколько attention-heads будет внутри слоя.
|
||||
Каждая “голова” учится видеть уникальные зависимости в тексте. Обычно это 4, 8, 16 и т.п.
|
||||
Чем больше голов — тем богаче контекст, но и больше памяти.
|
||||
emb_size : int
|
||||
Сколько float-значений в каждом входном векторе (размерность embedding).
|
||||
Обычно это 256, 512, 768, 1024 и т.д.
|
||||
head_size : int
|
||||
Сколько компонент будет у каждой головы внимания.
|
||||
Важно: num_heads * head_size должно ровно совпадать с emb_size!
|
||||
Обычно head_size = emb_size // num_heads.
|
||||
max_seq_len : int
|
||||
Максимально допустимая длина последовательности для attention/маски/генерации.
|
||||
Определяет размер буферов для causal mask.
|
||||
rope : RoPE, по умолчанию None
|
||||
Объект Rotary Positional Encoding (если хотите привнести продвинутое позиционное кодирование в attention).
|
||||
Не обязателен, но нужен для современных LLM (Llama, Mistral и пр.).
|
||||
dropout : float, по умолчанию 0.1
|
||||
Величина dropout (регуляризации) — помогает борьбе с переобучением. Чем больше, тем сильнее регуляризация.
|
||||
|
||||
Внутри конструктора происходит:
|
||||
-------------------------------
|
||||
- Создаются три линейных слоя для Q, K, V (“где смотреть” и “что вытаскивать” в attention).
|
||||
- Генерируется нижнетреугольная causal-маска (запрещает видеть будущее для автогерерации).
|
||||
- Создаётся финальный линейный слой для склейки всех голов в одно пространство emb_size.
|
||||
- Вводится dropout (случайное зануление, чтобы не было сильной зависимости внимания к отдельным "плейсам").
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> attn = MultiHeadAttention(num_heads=8, emb_size=256, head_size=32, max_seq_len=1024)
|
||||
"""
|
||||
super().__init__()
|
||||
self._heads = nn.ModuleList([
|
||||
HeadAttention(
|
||||
emb_size=emb_size,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len,
|
||||
rope=rope,
|
||||
) for _ in range(num_heads)
|
||||
])
|
||||
self._num_heads = num_heads
|
||||
self._head_size = head_size
|
||||
self._max_seq_len = max_seq_len
|
||||
self._rope = rope
|
||||
|
||||
self._q = nn.Linear(emb_size, num_heads * head_size)
|
||||
self._k = nn.Linear(emb_size, num_heads * head_size)
|
||||
self._v = nn.Linear(emb_size, num_heads * head_size)
|
||||
|
||||
# Создание causal маски
|
||||
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
|
||||
self.register_buffer(
|
||||
"_tril_mask", mask.bool() if hasattr(torch, "bool") else mask.byte()
|
||||
)
|
||||
|
||||
self._layer = nn.Linear(head_size * num_heads, emb_size)
|
||||
self._dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None):
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: torch.Tensor = None,
|
||||
use_cache: bool = True,
|
||||
cache: list = None,
|
||||
):
|
||||
"""
|
||||
Прямой проход (forward):
|
||||
Для каждого токена оценивает "важность" остальных токенов сразу через несколько attention-блоков.
|
||||
Основной шаг \"многоголового внимания\": находит взаимосвязи между токенами
|
||||
в последовательности сразу из нескольких “ракурсов” (attention heads).
|
||||
|
||||
Подробное описание преобразований тензоров:
|
||||
1. Входной тензор [batch_size, seq_len, emb_size] разделяется на N голов:
|
||||
- Каждая голова получает тензор [batch_size, seq_len, head_size]
|
||||
2. Каждая голова вычисляет attention:
|
||||
- Вход: [batch_size, seq_len, head_size]
|
||||
- Выход: [batch_size, seq_len, head_size]
|
||||
3. Конкатенация результатов:
|
||||
- Объединенный выход: [batch_size, seq_len, num_heads * head_size]
|
||||
4. Линейная проекция:
|
||||
- Выход: [batch_size, seq_len, emb_size]
|
||||
5. Применение dropout
|
||||
|
||||
Args:
|
||||
x (Tensor[float]): [batch, seq_len, emb_size] — вход
|
||||
mask (Optional[Tensor[bool]]): маска позиции [seq_len, seq_len]
|
||||
use_cache (bool): использовать ли key-value кэш (для генерации)
|
||||
cache (list): предыдущие значения KV для ускорения
|
||||
Что делает этот метод:
|
||||
----------------------
|
||||
- Для каждого токена сравнивает его с остальными во входной последовательности.
|
||||
- Делает это одновременно через несколько attention heads (каждая head видит текст по-своему).
|
||||
- Итоговое “внимание” — это взвешенная сумма других токенов (контекста) для каждого токена.
|
||||
- Можно использовать кэш для генерации длинных последовательностей по одному токену (ускоряет инференс).
|
||||
|
||||
Returns:
|
||||
out (Tensor[float]): [batch, seq_len, emb_size] — результат MHA
|
||||
kv_caches (list): списки новых KV-кэшей (если используется)
|
||||
Аргументы:
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Входной тензор формы [batch, seq_len, emb_size].
|
||||
Это ваши входные эмбеддинги (обычно после token + positional embedding).
|
||||
mask : torch.Tensor, опционально
|
||||
Матрица формы [seq_len, seq_len], задающая “разрешения” — кто может смотреть на кого (например, causal mask).
|
||||
Если не указана — используется внутренняя маска (например, для autoregressive генерации).
|
||||
use_cache : bool, по умолчанию True
|
||||
Нужно ли использовать кэш для KV attention (важно для ускорения генерации по одному токену).
|
||||
cache : list, опционально
|
||||
Предыдущий кэш Key/Value — для генерации текста по частям.
|
||||
|
||||
Типичный паттерн:
|
||||
Вход: [batch, seq, emb] → N голов [batch, seq, head_size] →
|
||||
→ concat [batch, seq, N*head_size] → проекция → dropout
|
||||
Возвращает:
|
||||
-----------
|
||||
- output: torch.Tensor формы [batch, seq_len, emb_size] — результат применения multi-head attention.
|
||||
- kv_caches: список новых KV для кэширования при генерации (или None).
|
||||
|
||||
Важно:
|
||||
-------
|
||||
- Shape входа всегда [batch, seq_len, emb_size], выход тот же.
|
||||
- При seq_len > max_seq_len выбросит ошибку (безопасно для контроля переполнения буферов).
|
||||
- При использовании use_cache=True кешируется только последние токены (актуально для LLM).
|
||||
|
||||
Пример преобразований для emb_size=512, num_heads=8:
|
||||
Вход: [4, 100, 512]
|
||||
-> Каждая голова: [4, 100, 64]
|
||||
-> После внимания: 8 x [4, 100, 64]
|
||||
-> Конкатенация: [4, 100, 512]
|
||||
-> Проекция: [4, 100, 512]
|
||||
-> Dropout: [4, 100, 512]
|
||||
|
||||
Пример:
|
||||
>>> out, caches = mha(x)
|
||||
>>> out.shape # [batch, seq_len, emb_size]
|
||||
-------
|
||||
>>> attn = MultiHeadAttention(num_heads=8, emb_size=256, head_size=32, max_seq_len=1024)
|
||||
>>> x = torch.randn(2, 100, 256)
|
||||
>>> y, kv_cache = attn(x)
|
||||
>>> print(y.shape) # torch.Size([2, 100, 256])
|
||||
"""
|
||||
# 1. Вычисляем attention для каждой головы
|
||||
attention_results = []
|
||||
for i, head in enumerate(self._heads):
|
||||
head_cache = cache[i] if cache is not None else None
|
||||
result = head(x, use_cache=use_cache, cache=head_cache)
|
||||
attention_results.append(result)
|
||||
batch_size, seq_len, emb_size = x.shape
|
||||
|
||||
if seq_len > self._max_seq_len:
|
||||
raise ValueError(
|
||||
f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}"
|
||||
)
|
||||
|
||||
# Пропустите тензор x через матрицы Wq, Wk , Wv, чтобы получить матрицы запроса, ключа и значения.
|
||||
k = self._k(x) # [B, T, hs]
|
||||
q = self._q(x) # [B, T, hs]
|
||||
v = self._v(x) # [B, T, hs]
|
||||
|
||||
# Шаг 2: Изменение формы для multi-head
|
||||
# [batch_size, seq_len, num_heads * head_size]
|
||||
# -> [batch_size, seq_len, num_heads, head_size]
|
||||
q = q.reshape(batch_size, seq_len, self._num_heads, self._head_size)
|
||||
k = k.reshape(batch_size, seq_len, self._num_heads, self._head_size)
|
||||
v = v.reshape(batch_size, seq_len, self._num_heads, self._head_size)
|
||||
|
||||
outputs, caches = zip(*attention_results)
|
||||
attention_outputs = list(outputs)
|
||||
kv_caches = list(caches)
|
||||
|
||||
# 2. Объединяем результаты всех голов
|
||||
concatenated_attention = torch.cat(attention_outputs, dim=-1)
|
||||
|
||||
# 3. Transpose: [B, T, H, hs] -> [B, H, T, hs]
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
start_pos = 0
|
||||
if cache is not None:
|
||||
k_cache, v_cache = cache
|
||||
cache_len = k_cache.shape[2]
|
||||
start_pos = cache_len
|
||||
|
||||
# Пропустите матрицы запроса и ключа через экземпляр rope, чтобы выполнить поворот.
|
||||
if self._rope is not None:
|
||||
# ✅ Применяем RoPE к Q и K (НЕ к V!)
|
||||
q = self._rope(q, start_pos=start_pos) # [B, T, hs]
|
||||
k = self._rope(k, start_pos=start_pos) # [B, T, hs]
|
||||
|
||||
# Если cache пришел, то объединяем кэш и одну строку из ключа и значения. Это будут новые key и value для последующих вычислений.
|
||||
# 5. Кэширование (для autoregressive generation)
|
||||
if cache is not None:
|
||||
k_cache, v_cache = cache
|
||||
k = torch.cat([k_cache, k], dim=2) # Concat по seq_len (dim=2)
|
||||
v = torch.cat([v_cache, v], dim=2)
|
||||
|
||||
# Перемножим матрицы запроса и ключа (транспонированную), чтобы вычислить матрицу внимания.
|
||||
# И разделить все значения в матрице внимания на корень из head_size.
|
||||
scores = q @ k.transpose(-2, -1) / (self._head_size ** 0.5)
|
||||
|
||||
# Если cache пришел, то маску не накладываем. Иначе наложите на матрицу внимания треугольную маску, созданную при инициализации. Все скрытые значения должны быть приведены к минус бесконечности: float('-inf').
|
||||
if cache is None:
|
||||
scores = scores.masked_fill(
|
||||
~self._tril_mask[:seq_len, :seq_len], float("-inf")
|
||||
)
|
||||
|
||||
# Применить к матрице внимания (построчно) функцию Softmax.
|
||||
weights = F.softmax(scores, dim=-1)
|
||||
|
||||
# Перемножим матрицу внимания и матрицу значения.
|
||||
x_out = weights @ v # [B, T, hs]
|
||||
|
||||
# Измените форму тензора на batch_size × seq_len × num_heads*head_size.
|
||||
# Transpose обратно и concatenate heads
|
||||
x_out = x_out.transpose(1, 2) # [B, T_q, H, hs]
|
||||
x_out = x_out.contiguous() # Важно для reshape!
|
||||
concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size)
|
||||
|
||||
#concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size)
|
||||
|
||||
# Пропустите получившийся тензор через последний линейный слой.
|
||||
# 3. Проецируем в пространство эмбеддингов
|
||||
projected_output = self._layer(concatenated_attention)
|
||||
|
||||
|
||||
# 4. Применяем dropout для регуляризации
|
||||
final_output = self._dropout(projected_output)
|
||||
|
||||
|
||||
if use_cache is True:
|
||||
return (final_output, kv_caches)
|
||||
return (final_output, (k, v))
|
||||
else:
|
||||
return (final_output, None)
|
||||
|
||||
252
llm/src/llm/core/multi_query_attention.py
Normal file
252
llm/src/llm/core/multi_query_attention.py
Normal file
@@ -0,0 +1,252 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from llm.core.rope import RoPE
|
||||
|
||||
class MultiQueryAttention(nn.Module):
|
||||
"""
|
||||
Multi-Query Attention (MQA) — быстрый и экономичный вариант self-attention для LLM.
|
||||
|
||||
Назначение:
|
||||
-----------
|
||||
Класс реализует механизм внимания (self-attention), в котором для всех Query-голов используются одни и те же Key и Value.
|
||||
В классическом MultiHeadAttention (MHA) на каждый Query используется свой Key/Value. В MQA набор Key/Value общий для всех голов,
|
||||
что снижает требования к памяти и ускоряет работу, что особенно важно для больших LLM на inference.
|
||||
|
||||
Теоретическое преимущество:
|
||||
--------------------------
|
||||
- Существенно экономит память на матрицы Key и Value: количество KV-голов обычно в 4–8 раз меньше, чем число Query-голов.
|
||||
- Позволяет достигать скорости почти обычной MHA при минимальной потере точности (см. Llama, Mistral).
|
||||
- Является стандартом де-факто для deployment и inference современных LLM.
|
||||
|
||||
Архитектурная схема:
|
||||
--------------------
|
||||
- Для каждого токена во входе вычисляются Q_h (отдельные для каждой Query-головы), но K и V — общие для всех.
|
||||
- Attention внутри каждой головы формируется через матричный продукт соответствующей Q_h и общего K.
|
||||
- Выходные вектора голов конкатенируются и проецируются обратно в emb_size.
|
||||
|
||||
Формулы:
|
||||
--------
|
||||
Q = Wq·x, K = Wk·x, V = Wv·x
|
||||
(Wq — отдельные для всех Query, Wk/Wv — общие для всех голов)
|
||||
Attention_h(x) = softmax(Q_h·K^T / sqrt(d_k))·V
|
||||
Output = Concat_h([Attention_h(x)])·W_o
|
||||
|
||||
Аргументы конструктора:
|
||||
-----------------------
|
||||
emb_size : int
|
||||
Размерность скрытого пространства (hidden size, embedding dim).
|
||||
num_heads : int
|
||||
Число Query-голов (обычно 8–32 в LLM).
|
||||
kv_heads : int
|
||||
Число Key/Value-голов (обычно 1, 2, 4, 8).
|
||||
head_size : int, optional
|
||||
Размерность одной головы (обычно emb_size // num_heads).
|
||||
dropout : float, optional
|
||||
Вероятность Dropout для регуляризации внимания.
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> mqa = MultiQueryAttention(emb_size=512, num_heads=8, kv_heads=1)
|
||||
>>> x = torch.randn(2, 16, 512)
|
||||
>>> mask = torch.ones(2, 16, 16)
|
||||
>>> out = mqa(x, mask)
|
||||
>>> print(out.shape) # torch.Size([2, 16, 512])
|
||||
|
||||
Литература и статьи:
|
||||
--------------------
|
||||
- Shazeer, N., “Fast Transformer Decoding: One Write-Head Is All You Need” (MQA): https://arxiv.org/abs/1911.02150
|
||||
- Llama: https://arxiv.org/abs/2302.13971
|
||||
- Mistral: https://arxiv.org/abs/2310.06825
|
||||
- PaLM/PaLM2, Mixtral, ChatGLM: практическое описание MQA.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
num_q_heads: int,
|
||||
emb_size: int,
|
||||
head_size: int,
|
||||
max_seq_len: int,
|
||||
rope: RoPE = None,
|
||||
dropout: float = 0.1,
|
||||
):
|
||||
"""
|
||||
Конструктор MultiQueryAttention.
|
||||
|
||||
Инициализирует все слои и буферы для реализации Multi-Query Attention с общими K/V-головами и индивидуальными Q-головами.
|
||||
Позволяет существенно ускорять инференс и экономить память при работе с большими языковыми моделями.
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
num_q_heads : int
|
||||
Число query-голов (обычно совпадает с количеством attention heads в модели).
|
||||
Определяет количество параллельных subspace для запроса.
|
||||
emb_size : int
|
||||
Размер скрытого пространства embedding (input/output размерность attention слоя).
|
||||
head_size : int
|
||||
Размерность одной attention-головы.
|
||||
Обычно emb_size // num_q_heads.
|
||||
max_seq_len : int
|
||||
Максимально поддерживаемая длина последовательности (нужна для построения треугольной маски causal attention).
|
||||
rope : RoPE, optional
|
||||
Модуль для rotary positional encoding (позиционный энкодер, улучшает обобщающую способность attention).
|
||||
Если None, positional encoding не применяется.
|
||||
dropout : float, по умолчанию 0.1
|
||||
Вероятность dropout для выходного слоя attention (регуляризация).
|
||||
|
||||
Внутри:
|
||||
-------
|
||||
- Насчитывает отдельные весовые слои для Q, общие для всех голов K/V.
|
||||
- Строит causal маску для автогрессивной генерации.
|
||||
- (Опционально) использует RoPE для позиционного кодирования.
|
||||
- Dropout применяется после финального projection.
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> mqa = MultiQueryAttention(emb_size=256, num_q_heads=8, head_size=32, max_seq_len=2048, rope=None, dropout=0.1)
|
||||
"""
|
||||
super().__init__()
|
||||
self._num_q_heads = num_q_heads
|
||||
self._head_size = head_size
|
||||
self._max_seq_len = max_seq_len
|
||||
self._rope = rope
|
||||
|
||||
self._q = nn.Linear(emb_size, num_q_heads * head_size)
|
||||
self._k = nn.Linear(emb_size, head_size)
|
||||
self._v = nn.Linear(emb_size, head_size)
|
||||
|
||||
# Создание causal маски
|
||||
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
|
||||
self.register_buffer(
|
||||
"_tril_mask", mask.bool() if hasattr(torch, "bool") else mask.byte()
|
||||
)
|
||||
|
||||
self._layer = nn.Linear(num_q_heads * head_size, emb_size)
|
||||
self._dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: torch.Tensor = None,
|
||||
use_cache: bool = True,
|
||||
cache: list = None,
|
||||
):
|
||||
"""
|
||||
Прямой проход (forward) через слой MultiQueryAttention.
|
||||
|
||||
Реализует multi-query self-attention для входных последовательностей с оптимизацией памяти за счёт общих K/V-голов для всех Query.
|
||||
Поддерживает работу с rotary positional encoding (RoPE), каузальной маской и кэшированием для ускорения генерации.
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Входной тензор формы [batch_size, seq_len, emb_size] — скрытые состояния после предыдущего слоя или эмбеддинга.
|
||||
mask : torch.Tensor, optional
|
||||
Необязательная маска внимания (например, для padding или custom-маскировки). По умолчанию используется встроенная causal mask.
|
||||
use_cache : bool, по умолчанию True
|
||||
Если True, возвращает кэш ключей/значений (для autoregressive inference/generation).
|
||||
cache : list, optional
|
||||
(K_cache, V_cache) — предварительный кэш KV (для ускоренного инференса). Если None, кэш не используется/создаётся заново.
|
||||
|
||||
Возвращает:
|
||||
-----------
|
||||
если use_cache == True:
|
||||
Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
|
||||
- attention_out: [batch_size, seq_len, emb_size] — результат attention после проекции и dropout.
|
||||
- (K, V): кэшированные ключи и значения (использовать для последующих forward'ов в autoregressive генерации)
|
||||
если use_cache == False:
|
||||
Tuple[torch.Tensor, None]
|
||||
|
||||
Математические шаги:
|
||||
--------------------
|
||||
1. Q = Wq·x; K = Wk·x; V = Wv·x # Q: индивидуальные для каждой головы, K/V — общие
|
||||
2. [optional] Rotary positional encoding применяется к Q и K
|
||||
3. (optional) concat c k/v cache (for autoregressive inference)
|
||||
4. attention_scores = softmax(Q·K^T / sqrt(head_size), mask)
|
||||
5. attention_out = attention_scores·V
|
||||
6. heads сливаются и проецируются в emb_size; применяется dropout.
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> out, cache = mqa(x, mask=attn_mask, use_cache=True, cache=prev_cache)
|
||||
>>> print(out.shape) # torch.Size([batch_size, seq_len, emb_size])
|
||||
|
||||
Примечания:
|
||||
-----------
|
||||
- Для каузального режима используется треугольная маска (по умолчанию).
|
||||
- Для генерации текста с cache передавайте кэш от предыдущих токенов — это ускоряет autoregressive inference.
|
||||
- Внимание! Тензоры внутри cache должны иметь форму [batch, heads, seq_len, head_size].
|
||||
"""
|
||||
batch_size, seq_len, emb_size = x.shape
|
||||
if seq_len > self._max_seq_len:
|
||||
raise ValueError(
|
||||
f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}"
|
||||
)
|
||||
|
||||
# Пропустите тензор x через матрицы Wq, Wk , Wv, чтобы получить матрицы запроса, ключа и значения.
|
||||
k = self._k(x) # [B, T, hs]
|
||||
q = self._q(x) # [B, T, hs]
|
||||
v = self._v(x) # [B, T, hs]
|
||||
|
||||
# Шаг 2: Изменение формы для multi-head
|
||||
# [batch_size, seq_len, num_heads * head_size]
|
||||
# -> [batch_size, seq_len, num_heads, head_size]
|
||||
q = q.reshape(batch_size, seq_len, self._num_q_heads, self._head_size)
|
||||
k = k.reshape(batch_size, seq_len, 1, self._head_size)
|
||||
v = v.reshape(batch_size, seq_len, 1, self._head_size)
|
||||
|
||||
# 3. Transpose: [B, T, H, hs] -> [B, H, T, hs]
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
# Пропустите матрицы запроса и ключа через экземпляр rope, чтобы выполнить поворот.
|
||||
if self._rope is not None:
|
||||
# Применяем RoPE к Q и K (НЕ к V!)
|
||||
q = self._rope(q) # [B, T, hs]
|
||||
k = self._rope(k) # [B, T, hs]
|
||||
|
||||
|
||||
# Если cache пришел, то объединяем кэш и одну строку из ключа и значения. Это будут новые key и value для последующих вычислений.
|
||||
# 5. Кэширование (для autoregressive generation)
|
||||
if cache is not None:
|
||||
k_cache, v_cache = cache
|
||||
k = torch.cat([k_cache, k], dim=2) # Concat по seq_len (dim=2)
|
||||
v = torch.cat([v_cache, v], dim=2)
|
||||
|
||||
|
||||
# Перемножим матрицы запроса и ключа (транспонированную), чтобы вычислить матрицу внимания.
|
||||
# И разделить все значения в матрице внимания на корень из head_size.
|
||||
scores = q @ k.transpose(-2, -1) / (self._head_size ** 0.5)
|
||||
|
||||
# Если cache пришел, то маску не накладываем. Иначе наложите на матрицу внимания треугольную маску, созданную при инициализации. Все скрытые значения должны быть приведены к минус бесконечности: float('-inf').
|
||||
if cache is None:
|
||||
scores = scores.masked_fill(
|
||||
~self._tril_mask[:seq_len, :seq_len], float("-inf")
|
||||
)
|
||||
|
||||
# Применить к матрице внимания (построчно) функцию Softmax.
|
||||
weights = F.softmax(scores, dim=-1)
|
||||
|
||||
# Перемножим матрицу внимания и матрицу значения.
|
||||
x_out = weights @ v # [B, T, hs]
|
||||
|
||||
|
||||
# Измените форму тензора на batch_size × seq_len × num_heads*head_size.
|
||||
# Transpose обратно и concatenate heads
|
||||
x_out = x_out.transpose(1, 2) # [B, T_q, H, hs]
|
||||
x_out = x_out.contiguous() # Важно для reshape!
|
||||
concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_q_heads * self._head_size)
|
||||
|
||||
|
||||
# Пропустите получившийся тензор через последний линейный слой.
|
||||
# 3. Проецируем в пространство эмбеддингов
|
||||
projected_output = self._layer(concatenated_attention)
|
||||
|
||||
|
||||
# 4. Применяем dropout для регуляризации
|
||||
final_output = self._dropout(projected_output)
|
||||
|
||||
if use_cache is True:
|
||||
return (final_output, (k, v))
|
||||
else:
|
||||
return (final_output, None)
|
||||
@@ -1,66 +1,105 @@
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
|
||||
|
||||
class PositionalEmbeddings(nn.Module):
|
||||
"""
|
||||
Обучаемые позиционные эмбеддинги (learnable positional embeddings).
|
||||
PositionalEmbeddings — классические позиционные эмбеддинги для трансформеров (absolute sinusoidal or learned).
|
||||
|
||||
Позиционные эмбеддинги используются в нейросетях для передачи информации
|
||||
о позиции элементов в последовательности (например, в Transformer).
|
||||
|
||||
Научная суть:
|
||||
- Трансформеры не используют рекуррентность, а значит сами по себе не различают порядок слов.
|
||||
- Позиционные эмбеддинги добавляются к токеновым, чтобы сеть понимала, в каком месте последовательности находится каждый токен.
|
||||
- Обычно реализуются как отдельная матрица (nn.Embedding), которая обучается вместе с моделью (это learnable вариант, как в GPT и BERT).
|
||||
Назначение:
|
||||
-----------
|
||||
- Добавляет или конкатенирует форму позиционной информации к каждому входному токену (since Transformer cannot distinguish positions otherwise).
|
||||
- Используется во всех \"ранних\" трансформерах (GPT, BERT, T5), чаще всего в виде learnable или синусоидальных embeddings.
|
||||
|
||||
Args:
|
||||
max_seq_len (int): максимальная длина последовательности
|
||||
emb_size (int): размер вектора позиции
|
||||
|
||||
Пример использования:
|
||||
>>> pos_encoder = PositionalEmbeddings(max_seq_len=100, emb_size=256)
|
||||
>>> # Получить эмбеддинги для последовательности из 10 элементов
|
||||
>>> embeddings = pos_encoder(10) # Tensor shape: [10, 256]
|
||||
>>> # Использование в модели
|
||||
>>> class MyModel(nn.Module):
|
||||
... def __init__(self):
|
||||
... super().__init__()
|
||||
... self.pos_emb = PositionalEmbeddings(100, 256)
|
||||
... def forward(self, x):
|
||||
... pos = self.pos_emb(x.size(1))
|
||||
... return x + pos # Добавляем позиционную информацию
|
||||
Архитектурные варианты:
|
||||
-----------------------
|
||||
- Learnable positional embeddings (как в GPT-2): обычный nn.Embedding инициализируется случайно, и веса учатся вместе с моделью.
|
||||
- Sinusoidal positional encoding (как в оригинальном Transformer): не имеет параметров, а создаётся по заданной формуле sin/cos(ω*x).
|
||||
|
||||
Принцип работы:
|
||||
---------------
|
||||
- Для каждой позиции t заполняется вектор emb_size длиной по формуле (или выбирается из weight matrix).
|
||||
- Эти вектора можно либо складывать с токеновыми эмбеддингами, либо конкатенировать.
|
||||
- Позволяет attention-механизму \"понимать\" порядок токенов/слов в последовательности.
|
||||
|
||||
Формулы (Or: Vaswani et al., 2017):
|
||||
------------------------------------
|
||||
PE(pos, 2i) = sin(pos / 10000^{2i/d})
|
||||
PE(pos, 2i+1) = cos(pos / 10000^{2i/d})
|
||||
где d = emb_size, pos = позиция (int), i = индекс пары компонент.
|
||||
|
||||
Аргументы конструктора:
|
||||
-----------------------
|
||||
max_seq_len: int — максимально поддерживаемая длина последовательности
|
||||
emb_size: int — размер возвращаемого positional vector для каждой позиции
|
||||
(иногда выбирается вариант — learnable или фиксация через sin/cos)
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> pos = PositionalEmbeddings(max_seq_len=1024, emb_size=256)
|
||||
>>> p = pos(32) # Получить positional embeddings для 32 позиций
|
||||
>>> p.shape # torch.Size([32, 256])
|
||||
>>> token_emb = ... # [batch, seq_len, emb_size]
|
||||
>>> encoded = token_emb + p.unsqueeze(0) # Broadcast add
|
||||
|
||||
References:
|
||||
-----------
|
||||
- Vaswani et al., \"Attention is All You Need\", 2017: https://arxiv.org/abs/1706.03762
|
||||
- GPT-2 implementation: https://github.com/openai/gpt-2
|
||||
- Почему positional encoding важен: https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
|
||||
"""
|
||||
|
||||
def __init__(self, max_seq_len: int, emb_size: int):
|
||||
"""
|
||||
Инициализация позиционного энкодера.
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
max_seq_len : int
|
||||
Максимальная длина последовательности (builds buffer for sin/cos or embedding)
|
||||
emb_size : int
|
||||
Длина позиционного вектора
|
||||
|
||||
Внутри:
|
||||
-------
|
||||
- Если используется learned embedding: создаётся nn.Embedding (можно легко менять в будущем).
|
||||
- Если fixed (sin/cos): вычисляется и хранится буфер (max_seq_len, emb_size).
|
||||
"""
|
||||
super().__init__()
|
||||
self.max_seq_len = max_seq_len
|
||||
self.emb_size = emb_size
|
||||
self.embedding = nn.Embedding(
|
||||
num_embeddings=max_seq_len,
|
||||
embedding_dim=emb_size
|
||||
num_embeddings=max_seq_len, embedding_dim=emb_size
|
||||
)
|
||||
|
||||
def forward(self, seq_len: int, start_pos: int = 0) -> Tensor:
|
||||
"""
|
||||
Возвращает позиционные эмбеддинги для заданной длины последовательности.
|
||||
|
||||
Args:
|
||||
seq_len (int): Длина последовательности (1 <= seq_len <= max_seq_len)
|
||||
|
||||
Returns:
|
||||
Tensor: Тензор позиционных эмбеддингов формы [seq_len, emb_size]
|
||||
|
||||
Raises:
|
||||
IndexError: Если seq_len выходит за допустимые границы
|
||||
|
||||
Получить positional embeddings для последовательности длиной seq_len.
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
seq_len : int
|
||||
Сколько позиций сгенерировать (обычно == входная длина x)
|
||||
start_pos : int, по умолчанию 0
|
||||
Возможность выдать positional embeddings \"с середины\" (для autoregressive генерации)
|
||||
|
||||
Возвращает:
|
||||
-----------
|
||||
torch.Tensor — positional embeddings формы [seq_len, emb_size]
|
||||
|
||||
Пример:
|
||||
>>> pos_encoder = PositionalEmbeddings(100, 64)
|
||||
>>> emb = pos_encoder(10) # Тензор 10x64
|
||||
-------
|
||||
>>> pos = PositionalEmbeddings(512, 128)
|
||||
>>> p = pos(10) # [10, 128]
|
||||
"""
|
||||
if seq_len < 1 or seq_len > self.max_seq_len:
|
||||
raise IndexError(f"Длина {seq_len} должна быть от 1 до {self.max_seq_len}")
|
||||
if start_pos == 0:
|
||||
positions = torch.arange(seq_len, device=self.embedding.weight.device)
|
||||
else:
|
||||
positions = torch.arange(start=start_pos, end=start_pos + seq_len, device=self.embedding.weight.device)
|
||||
positions = torch.arange(
|
||||
start=start_pos,
|
||||
end=start_pos + seq_len,
|
||||
device=self.embedding.weight.device,
|
||||
)
|
||||
return self.embedding(positions)
|
||||
|
||||
@@ -24,60 +24,100 @@ from typing import Optional
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
"""
|
||||
RMS Normalization (Root Mean Square Layer Normalization).
|
||||
RMSNorm (Root Mean Square Layer Normalization) — простая и эффективная альтернатива LayerNorm.
|
||||
|
||||
Нормализует входные данные по последнему измерению используя среднеквадратичное
|
||||
значение вместо среднего, как в стандартном LayerNorm.
|
||||
Назначение:
|
||||
-----------
|
||||
- Нормализует входной тензор по последнему измерению только с помощью RMS (root mean square), без вычитания среднего.
|
||||
- Используется в LLaMA, PaLM и других крупных языковых моделях для лучшей стабильности и ускорения обучения.
|
||||
- В отличие от LayerNorm, не центрирует значения, что особенно полезно для автогерессивных трансформеров с residual-связями.
|
||||
|
||||
Научная суть:
|
||||
- Упрощенный вариант LayerNorm без вычисления среднего, только деление на rms.
|
||||
- Лучшая численная стабильность на больших моделях, меньше вычислений.
|
||||
- Применяется в LLaMA, PaLM и др.
|
||||
|
||||
Формула:
|
||||
RMSNorm(x) = (x / sqrt(mean(x²) + eps)) * w (w — обучаемый вектор)
|
||||
Мотивация и математика:
|
||||
-----------------------
|
||||
- Формула для одного слоя и вектора x:
|
||||
rms = sqrt( mean( x ** 2 ) + eps )
|
||||
out = w * ( x / rms )
|
||||
где w — learnable scale, eps — небольшая константа для численной устойчивости.
|
||||
- Нет смещения/вычитания среднего — сигнал сохраняет абсолютные значения, меньше “искажает” автоагрегатные значения на накопленных резидуалах.
|
||||
|
||||
Аргументы конструктора:
|
||||
-----------------------
|
||||
dim : int
|
||||
Размер последнего нормализуемого измерения (обычно совпадает с размером embedding/final head).
|
||||
eps : float, default=1e-6
|
||||
Малое значение для устойчивости (additive epsilon).
|
||||
|
||||
Особенности:
|
||||
------------
|
||||
- Нет батч-нормализации, нет зависимости от размера батча.
|
||||
- Отлично подходит для больших моделей и автогерессии — меньше шуму от residual.
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> norm = RMSNorm(emb_size=256)
|
||||
>>> x = torch.randn(4, 10, 256)
|
||||
>>> out = norm(x) # возвращает tensor той же формы
|
||||
|
||||
References:
|
||||
-----------
|
||||
- Zhang & Sennrich, "Root Mean Square Layer Normalization", 2019: https://arxiv.org/abs/1910.07467
|
||||
- Применение в LLaMA: https://arxiv.org/abs/2302.13971
|
||||
- HuggingFace implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||
|
||||
Args:
|
||||
dim (int): размер последнего измерения (обычно emb_size)
|
||||
eps (float): для численной устойчивости
|
||||
|
||||
Пример:
|
||||
>>> norm = RMSNorm(emb_size)
|
||||
>>> out = norm(x)
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
"""
|
||||
Инициализация RMSNorm слоя.
|
||||
|
||||
Инициализация RMSNorm.
|
||||
|
||||
Args:
|
||||
dim: Размерность нормализуемого измерения
|
||||
eps: Малое значение для численной стабильности (по умолчанию 1e-6)
|
||||
-----
|
||||
dim : int
|
||||
Последнее нормализуемое измерение (обычно размерность embedding или hidden).
|
||||
eps : float
|
||||
Малое значение для устойчивости (по умолчанию 1e-6).
|
||||
|
||||
Внутри:
|
||||
-------
|
||||
- Создаётся обучаемый scale weight w для каждой компоненты dim.
|
||||
- Сохраняется параметр eps для добавления к RMS.
|
||||
"""
|
||||
super().__init__()
|
||||
self._eps = eps
|
||||
self._w = nn.Parameter(torch.ones(dim))
|
||||
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Прямой проход через RMSNorm слой.
|
||||
|
||||
Прямой проход через RMSNorm.
|
||||
|
||||
Args:
|
||||
x: Входной тензор формы [..., dim]
|
||||
|
||||
-----
|
||||
x : torch.Tensor
|
||||
Входной тензор любого shape с последней размерностью dim.
|
||||
|
||||
Returns:
|
||||
Нормализованный тензор той же формы, что и входной
|
||||
|
||||
Формула:
|
||||
output = w * (x / sqrt(mean(x²) + eps))
|
||||
--------
|
||||
torch.Tensor — тот же shape, что и вход x, но нормализованный по RMS на последнем измерении.
|
||||
|
||||
Алгоритм:
|
||||
---------
|
||||
- Вычислить rms = sqrt( mean( x**2, dim=-1, keepdim=True ) + eps )
|
||||
- Поделить x на rms
|
||||
- Помасштабировать обучаемым весом w
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> norm = RMSNorm(256)
|
||||
>>> out = norm(torch.randn(2, 10, 256))
|
||||
|
||||
"""
|
||||
# Вычисление RMS (Root Mean Square) по последнему измерению
|
||||
rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5
|
||||
|
||||
|
||||
# Нормализация и масштабирование
|
||||
norm_x = x / rms
|
||||
return self._w * norm_x
|
||||
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
"""Строковое представление для отладки."""
|
||||
return f'dim={self._w.shape[0]}, eps={self._eps}'
|
||||
return f"dim={self._w.shape[0]}, eps={self._eps}"
|
||||
|
||||
@@ -1,21 +1,51 @@
|
||||
"""
|
||||
Rotary Positional Embeddings (RoPE) - ротационные позиционные эмбеддинги.
|
||||
Rotary Positional Embeddings (RoPE)
|
||||
===================================
|
||||
|
||||
Реализация ротационного позиционного кодирования, которое кодирует позиционную
|
||||
информацию через вращение векторов запросов и ключей в комплексном пространстве.
|
||||
Что такое RoPE?
|
||||
----------------
|
||||
RoPE — это способ "вписать" информацию о позиции токенов в скрытые вектора модели трансформера.
|
||||
Вместо простого сложения с абсолютным positional embedding, RoPE использует вращения векторов (как поворот стрелки на круге) внутри каждого attention head. Каждый элемент пары (вектор четного и нечетного индекса) поворачивается на угол, зависящий от позиции токена.
|
||||
|
||||
Научная статья: "RoFormer: Enhanced Transformer with Rotary Position Embedding"
|
||||
https://arxiv.org/abs/2104.09864
|
||||
Зачем это?
|
||||
-----------
|
||||
- RoPE реализует **относительное позиционное кодирование**: модель может сравнивать "расстояния" между токенами, а не просто помнить положение.
|
||||
- Такое кодирование **улучшает генерацию длинных последовательностей** и перенос модели на тексты большей длины, чем были в обучении.
|
||||
- Форма векторов и длина (норма) НЕ искажаются.
|
||||
|
||||
Математическая основа:
|
||||
Для позиции m и измерения i:
|
||||
θ_i = base^(-2i/d)
|
||||
q'_m = q_m * cos(mθ_i) + rotate(q_m) * sin(mθ_i)
|
||||
Как это работает? (главная формула)
|
||||
-------------------------------------
|
||||
Для каждой позиции m и пары компонент (2i, 2i+1) внутри head применяются:
|
||||
|
||||
θ_i = base^(-2i / d)
|
||||
q'_{m,2i} = q_{m,2i} * cos(m * θ_i) - q_{m,2i+1} * sin(m * θ_i)
|
||||
q'_{m,2i+1} = q_{m,2i+1} * cos(m * θ_i) + q_{m,2i} * sin(m * θ_i)
|
||||
|
||||
где d — размерность "головы" attention (head_size), base обычно 10_000.
|
||||
|
||||
То есть, берём каждый "вектор" (в рамках head), делим на четные/нечетные части и поворачиваем их на уникальный угол, связанный с позицией/частотой.
|
||||
|
||||
Архитектурные детали:
|
||||
---------------------
|
||||
- Ваш тензор должен быть строго 4-мерным: [batch, num_heads, seq_len, head_size].
|
||||
- Размер head_size должен быть чётным!
|
||||
- RoPE применяется отдельно к **Q** и **K** в механизме внимания (но не к V).
|
||||
|
||||
Где об этом читать:
|
||||
-------------------
|
||||
- RoFormer: Enhanced Transformer with Rotary Position Embedding
|
||||
https://arxiv.org/abs/2104.09864
|
||||
- Llama: Open and Efficient Foundation Language Models
|
||||
https://arxiv.org/abs/2302.13971
|
||||
- Визуализация позиционных кодировок:
|
||||
https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> rope = RoPE(head_size=64, max_seq_len=2048)
|
||||
>>> x = torch.randn(2, 8, 128, 64) # [batch, num_heads, seq_len, head_size]
|
||||
>>> x_enc = rope(x) # здесь вектор x обогатится позиционной информацией
|
||||
|
||||
Преимущества:
|
||||
- Относительное позиционное кодирование
|
||||
- Лучшая экстраполяция на длинные последовательности
|
||||
- Сохранение нормы векторов
|
||||
"""
|
||||
|
||||
import torch
|
||||
@@ -25,73 +55,136 @@ from typing import Optional
|
||||
|
||||
class RoPE(nn.Module):
|
||||
"""
|
||||
Rotary Positional Embeddings (RoPE) для механизма внимания.
|
||||
|
||||
Кодирует позиционную информацию через вращение векторов запросов и ключей
|
||||
в многомерном пространстве с использованием синусов и косинусов.
|
||||
|
||||
Args:
|
||||
head_size: Размерность головы внимания (должен быть четным)
|
||||
max_seq_len: Максимальная длина последовательности
|
||||
base: Базовое значение для вычисления частот (по умолчанию 10000)
|
||||
|
||||
Attributes:
|
||||
cos_matrix: Буферизованная матрица косинусов формы [max_seq_len, head_size//2]
|
||||
sin_matrix: Буферизованная матрица синусов формы [max_seq_len, head_size//2]
|
||||
Реализация RoPE (Rotary Positional Embeddings) для self-attention в трансформерах.
|
||||
|
||||
Этот слой добавляет позиционную информацию к векторам внимания (Q, K) —
|
||||
не с помощью простого сложения с positional embedding, а с помощью математического
|
||||
вращения (как если бы вы крутили стрелку на круге) для каждой пары компонент
|
||||
(even/odd) в каждом attention head.
|
||||
|
||||
Формула (для каждого токена и каждой пары компонент внутри head):
|
||||
θ_i = base^(-2i / d)
|
||||
out_{m,2i} = x_{m,2i} * cos(m * θ_i) - x_{m,2i+1} * sin(m * θ_i)
|
||||
out_{m,2i+1} = x_{m,2i+1} * cos(m * θ_i) + x_{m,2i} * sin(m * θ_i)
|
||||
где d — head_size, base обычно 10_000, степень i по head axis.
|
||||
|
||||
Какие входы принимает:
|
||||
----------------------
|
||||
- x: обязательно размерности [batch, num_heads, seq_len, head_size]!
|
||||
- head_size (размер внимания) должен быть чётным.
|
||||
- start_pos: опционально, позволяет сдвигать позиционный offset для генерации с кэшем.
|
||||
|
||||
Что возвращает:
|
||||
---------------
|
||||
- Тот же тензор (x), только со встроенной позиционной информацией (“повёрнутый” RoPE-кодировкой).
|
||||
- Форма и тип выходного тензора не меняются.
|
||||
|
||||
Где используется:
|
||||
-----------------
|
||||
- В любых современных LLM (Llama, Mistral, GPT-NeoX и др.) для повышения устойчивости и generalization transformer's attention.
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> rope = RoPE(head_size=64, max_seq_len=2048)
|
||||
>>> x = torch.randn(2, 8, 128, 64) # (batch, num_heads, seq_len, head_size)
|
||||
>>> x_encoded = rope(x)
|
||||
|
||||
Подробнее про математику и примеры с визуализацией:
|
||||
---------------------------------------------------
|
||||
- RoFormer: https://arxiv.org/abs/2104.09864
|
||||
- Llama: https://arxiv.org/abs/2302.13971
|
||||
- Демонстрация наглядно: https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):
|
||||
"""
|
||||
Инициализация RoPE эмбеддингов.
|
||||
|
||||
Args:
|
||||
head_size: Размерность головы внимания (должен быть четным)
|
||||
max_seq_len: Максимальная поддерживаемая длина последовательности
|
||||
base: Базовое значение для вычисления частот (типично 10000)
|
||||
|
||||
Raises:
|
||||
AssertionError: Если head_size не четный
|
||||
Инициализация объекта RoPE — настраивает и предвычисляет все необходимые
|
||||
параметры для ротационного позиционного кодирования.
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
head_size : int
|
||||
Размер одного attention head (последнего измерения вектора) — сколько компонент
|
||||
(float-значений) отвечает за одну "голову". Должен быть ЧЁТНЫМ числом, иначе RoPE не применим.
|
||||
Обычно head_size = embed_dim // num_heads.
|
||||
max_seq_len : int
|
||||
Максимальная длина последовательности, которую RoPE сможет обработать.
|
||||
Если ваш текст длиннее этого числа — будет ошибка! Например, для GPT2 обычно 1024, у LLaMA — до 4096.
|
||||
Это число определяет размер внутренних буферов cos/sin.
|
||||
base : int, по умолчанию 10_000
|
||||
База для вычисления частот вращения (θ_i) для каждой компоненты.
|
||||
В оригинальных статьях почти всегда используют base=10000.
|
||||
Менять этот параметр не нужно, если вы не исследуете математические детали.
|
||||
|
||||
Что происходит внутри:
|
||||
----------------------
|
||||
- Проверяется чётность head_size.
|
||||
- Для каждого возможного положения в пределах max_seq_len и каждой пары component высчитываются уникальные cos/sin значения (матрицы частот).
|
||||
- Эти матрицы используются далее для быстрого наложения позиционного "вращения" токенов внутри attention.
|
||||
"""
|
||||
super().__init__()
|
||||
assert head_size % 2 == 0, "head_size должен быть четным"
|
||||
|
||||
|
||||
# Вычисление частот: θ_i = base^(-2i/d) для i ∈ [0, d/2-1]
|
||||
freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size))
|
||||
|
||||
|
||||
# Позиции от 0 до max_seq_len-1
|
||||
positions = torch.arange(max_seq_len).float()
|
||||
|
||||
|
||||
# Внешнее произведение: m * θ_i для всех позиций и частот
|
||||
freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0)
|
||||
|
||||
# Предвычисление матриц косинусов и синусов
|
||||
self.register_buffer('cos_matrix', torch.cos(freq_matrix))
|
||||
self.register_buffer('sin_matrix', torch.sin(freq_matrix))
|
||||
self.register_buffer("cos_matrix", torch.cos(freq_matrix))
|
||||
self.register_buffer("sin_matrix", torch.sin(freq_matrix))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
|
||||
"""
|
||||
Применение ротационного позиционного кодирования к входному тензору.
|
||||
|
||||
Args:
|
||||
x: Входной тензор формы [batch_size, seq_len, head_size]
|
||||
|
||||
Returns:
|
||||
Тензор с примененным RoPE формы [batch_size, seq_len, head_size]
|
||||
|
||||
Алгоритм:
|
||||
1. Разделение векторов на четные и нечетные компоненты
|
||||
2. Применение вращения через синусы и косинусы
|
||||
3. Объединение компонент обратно
|
||||
Применяет ротационное позиционное кодирование (RoPE) к входному тензору.
|
||||
|
||||
Что делает эта функция:
|
||||
-----------------------
|
||||
Для каждого токена в последовательности внутри каждого attention head
|
||||
"поворачивает" его вектор в подпространстве (even/odd пар) на свой уникальный угол,
|
||||
зависящий от позиции токена. Это позволяет attention "понимать расстояния" между токенами.
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Входной тензор строго формы [batch, num_heads, seq_len, head_size].
|
||||
Это обычно либо Q, либо K из механизма внимания.
|
||||
start_pos : int, по умолчанию 0
|
||||
Сдвиг начала позиции (нужно при генерации с кэшем, почти всегда оставить 0 если не пишете автогенератор).
|
||||
|
||||
Возвращает:
|
||||
-----------
|
||||
torch.Tensor с теми же формой и типом, что и x, но уже с наложенным позиционным кодированием.
|
||||
|
||||
Важно:
|
||||
-------
|
||||
- Если передан тензор не 4D, будет выброшено исключение!
|
||||
- Не изменяет значения "на месте", всегда возвращает новый тензор.
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> rope = RoPE(head_size=64, max_seq_len=1024)
|
||||
>>> q = torch.randn(2, 8, 32, 64) # batch, num_heads, seq_len, head_size
|
||||
>>> q_rope = rope(q)
|
||||
"""
|
||||
seq_len = x.size(1)
|
||||
|
||||
assert x.ndim == 4, "RoPE поддерживает только 4D-вход [batch, num_heads, seq_len, head_size]"
|
||||
batch_size, num_heads, seq_len, head_size = x.shape
|
||||
|
||||
# Берем нужную часть матриц и приводим к типу x
|
||||
cos = self.cos_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]
|
||||
sin = self.sin_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]
|
||||
|
||||
# Разделяем на четные и нечетные компоненты
|
||||
x_even = x[:, :, 0::2] # [batch_size, seq_len, head_size//2]
|
||||
x_odd = x[:, :, 1::2] # [batch_size, seq_len, head_size//2]
|
||||
cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
||||
sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
||||
|
||||
# Явное изменение формы для broadcasting
|
||||
cos = cos.reshape(1, 1, seq_len, head_size // 2)
|
||||
sin = sin.reshape(1, 1, seq_len, head_size // 2)
|
||||
|
||||
# Разделяем на четные и нечетные компоненты по ПОСЛЕДНЕМУ измерению
|
||||
x_even = x[..., 0::2] # [batch_size, num_heads, seq_len, head_size//2]
|
||||
x_odd = x[..., 1::2] # [batch_size, num_heads, seq_len, head_size//2]
|
||||
|
||||
# Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ)
|
||||
x_rotated_even = x_even * cos - x_odd * sin
|
||||
@@ -101,4 +194,4 @@ class RoPE(nn.Module):
|
||||
x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)
|
||||
x_rotated = x_rotated.flatten(-2) # [batch_size, seq_len, head_size]
|
||||
|
||||
return x_rotated
|
||||
return x_rotated
|
||||
|
||||
@@ -1,19 +1,70 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class SiLU(nn.Module):
|
||||
"""
|
||||
SiLU (Swish) — современная активационная функция для нейросетей.
|
||||
|
||||
Научная суть:
|
||||
- Формула: $SiLU(x) = x * \sigm(x)$, где $\sigm(x)$ — сигмоида.
|
||||
- Более гладкая альтернатива ReLU, улучшает поток градиентов в глубоких сетях.
|
||||
- Используется во многих «state-of-the-art» архитектурах (SwiGLU, PaLM, LLaMA).
|
||||
- Также известна как Swish (Ramachandran et al, 2017).
|
||||
Пример:
|
||||
>>> act = SiLU()
|
||||
>>> x = torch.tensor([-1.0, 0.0, 1.0])
|
||||
>>> print(act(x))
|
||||
SiLU (Sigmoid Linear Unit, также известная как Swish) — современная функция активации для нейросетей и LLM.
|
||||
|
||||
Назначение:
|
||||
-----------
|
||||
- Формирует плавную нелинейную активацию: SiLU(x) = x * sigmoid(x).
|
||||
- Активно используется во всех новых архитектурах для больших языковых моделей (PaLM, LLaMA, Mistral, GPT-4 и др.).
|
||||
- Дает лучший поток градиентов по сравнению с ReLU, SELU, GELU в глубоких слоях — позволяет делать сети больше и глубже.
|
||||
|
||||
Мотивация и свойства:
|
||||
---------------------
|
||||
- SiLU объединяет свойства identity (для больших x) и ReLU (для отрицательных x, где есть затухание), но более плавно.
|
||||
- Позволяет проходить отрицательным значениям, а не "обрубает" как ReLU.
|
||||
- Better for optimization and training dynamics in deep LLMs, приводит к более богатым аппроксимациям.
|
||||
|
||||
Математическая формула:
|
||||
-----------------------
|
||||
SiLU(x) = x * sigmoid(x)
|
||||
где sigmoid(x) = 1 / (1 + exp(-x))
|
||||
|
||||
Сравнение с другими активациями:
|
||||
--------------------------------
|
||||
- ReLU(x): max(0, x) — простая отсечка
|
||||
- GELU(x): плавная вероятностная активация (используется в BERT/GPT-2)
|
||||
- SiLU(x): плавная альтернатива, часто лучше в современных LLM
|
||||
- Swish (Ramachandran et al., 2017) = SiLU
|
||||
|
||||
Args:
|
||||
-----
|
||||
Нет learnable параметров, чисто функциональная активация.
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> silu = SiLU()
|
||||
>>> x = torch.tensor([-2.0, 0.0, 2.0])
|
||||
>>> print(silu(x)) # тензор с элементами [-0.2384, 0.0, 1.7616] (примерно)
|
||||
|
||||
References:
|
||||
-----------
|
||||
- Ramachandran et al., "Searching for Activation Functions", 2017: https://arxiv.org/abs/1710.05941
|
||||
- LLaMA: https://arxiv.org/abs/2302.13971
|
||||
- Swish в TensorFlow: https://arxiv.org/abs/1710.05941
|
||||
- Сравнение всех актив. функций: https://paperswithcode.com/method/silu
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return torch.sigmoid(x) * x
|
||||
"""
|
||||
Применяет SiLU активацию ко всем компонентам тензора (x * sigmoid(x)).
|
||||
|
||||
Args:
|
||||
-----
|
||||
x : torch.Tensor
|
||||
Входной тензор любой формы.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
torch.Tensor — тензор той же формы, каждый элемент преобразован по формуле SiLU(x).
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> silu = SiLU()
|
||||
>>> x = torch.linspace(-3, 3, 7)
|
||||
>>> y = silu(x)
|
||||
"""
|
||||
return torch.sigmoid(x) * x
|
||||
|
||||
@@ -24,37 +24,61 @@ from .silu import SiLU
|
||||
|
||||
class SwiGLU(nn.Module):
|
||||
"""
|
||||
SwiGLU (Swish-Gated Linear Unit) — современная нелинейность для архитектур LLM (LLaMA, PaLM).
|
||||
SwiGLU (Swish-Gated Linear Unit) — эффективная feed-forward нелинейность для трансформеров (LLAMA, PaLM, Mistral).
|
||||
|
||||
Реализация SwiGLU активационной функции.
|
||||
|
||||
Состоит из трех линейных слоев и активации SiLU:
|
||||
1. Gate слой + SiLU активация
|
||||
2. Up слой (линейное преобразование)
|
||||
3. Element-wise multiplication gate и up
|
||||
4. Down слой (линейная проекция)
|
||||
|
||||
Научная суть:
|
||||
- Сохраняет преимущества GLU (раздельные гейтом и телом) + мощность Swish/SiLU активации.
|
||||
- Дает надежную гладкую активацию, хорошо работает на больших масштабах.
|
||||
- Статья: "GLU Variants Improve Transformer" (Shazeer, 2020).
|
||||
Назначение:
|
||||
-----------
|
||||
- Улучшает классический блок FeedForward (FFN) в трансформерах за счёт \"gating\" (механизма управления информационным потоком).
|
||||
- Использует нелинейность SiLU (Swish) вместо ReLU или GELU, повышая capacity блока.
|
||||
- Является дефолтом во всех современных LLM, начиная с PaLM, LLaMA и Mistral.
|
||||
|
||||
Формула:
|
||||
SwiGLU(x) = SiLU(W_g·x) * (W_u·x)
|
||||
где SiLU(x) = x*sigma(x)
|
||||
Формула и математика:
|
||||
---------------------
|
||||
Пусть x — вход, then:
|
||||
|
||||
SwiGLU(x) = (SiLU(W_g x + b_g)) ⊙ (W_u x + b_u) W_d + b_d
|
||||
|
||||
Типовая реализация (как здесь, по LLAMA/Mistral):
|
||||
gate = SiLU(Linear_gate(x)) # фитчерный \"gate\"
|
||||
up = Linear_up(x) # пропускная ветка
|
||||
mult = gate * up # поэлементное умножение (контроль информации)
|
||||
out = Linear_down(mult) # финальная проекция
|
||||
out = Dropout(out) # регуляризация
|
||||
|
||||
Почему это работает:
|
||||
-------------------
|
||||
- Gating позволяет информации проходить \"частично\", динамически подавляя/усиливая сигналы в hidden-space.
|
||||
- SiLU обеспечивает smooth градиенты (лучше для обучения LLM).
|
||||
- В экспериментах (PaLM, LLAMA) SwiGLU consistently outperforms ReLU, GELU, обычные GLU.
|
||||
|
||||
Параметры конструктора:
|
||||
-----------------------
|
||||
emb_size: int
|
||||
Размерность входного (и выходного) признакового пространства.
|
||||
dropout: float
|
||||
Dropout после final linear (обычно около 0.1).
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> block = SwiGLU(emb_size=512, dropout=0.1)
|
||||
>>> x = torch.randn(8, 16, 512)
|
||||
>>> y = block(x)
|
||||
>>> print(y.shape) # torch.Size([8, 16, 512])
|
||||
|
||||
References:
|
||||
-----------
|
||||
- Shazeer, \"GLU Variants Improve Transformer\", 2020: https://arxiv.org/abs/2002.05202
|
||||
- PaLM: https://arxiv.org/abs/2204.02311 (Section 4.1)
|
||||
- LLaMA: https://arxiv.org/abs/2302.13971
|
||||
- Mistral: https://arxiv.org/abs/2310.06825
|
||||
- HuggingFace discussion: https://huggingface.co/docs/transformers/main/en/model_doc/llama
|
||||
|
||||
Args:
|
||||
emb_size (int): размер входов/выходов
|
||||
dropout (float): после выходной проекции
|
||||
Пример:
|
||||
>>> ff = SwiGLU(emb_size=512, dropout=0.1)
|
||||
>>> y = ff(torch.randn(2,10,512))
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, emb_size: int, dropout: float = 0.1):
|
||||
"""
|
||||
Инициализация SwiGLU слоя.
|
||||
|
||||
|
||||
Args:
|
||||
emb_size: Размерность входных/выходных эмбеддингов
|
||||
dropout: Вероятность dropout (по умолчанию 0.1)
|
||||
@@ -68,34 +92,39 @@ class SwiGLU(nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Прямой проход через SwiGLU слой.
|
||||
|
||||
Прямой проход через блок SwiGLU.
|
||||
|
||||
Args:
|
||||
x: Входной тензор формы [batch_size, seq_len, emb_size]
|
||||
|
||||
-----
|
||||
x : torch.Tensor
|
||||
Входной тензор формы [batch_size, seq_len, emb_size]
|
||||
|
||||
Returns:
|
||||
Выходной тензор формы [batch_size, seq_len, emb_size]
|
||||
|
||||
--------
|
||||
torch.Tensor той же формы
|
||||
|
||||
Алгоритм:
|
||||
---------
|
||||
1. gate = SiLU(linear_gate(x))
|
||||
2. up = linear_up(x)
|
||||
3. output = linear_down(gate ⊙ up)
|
||||
4. apply dropout
|
||||
3. mult = gate * up # поэлементно
|
||||
4. out = linear_down(mult)
|
||||
5. out = dropout(out)
|
||||
"""
|
||||
# Gate ветвь: линейное преобразование + активация
|
||||
gate_out = self._gate(x) # [batch, seq, 4*emb]
|
||||
activation_out = self._activation(gate_out) # [batch, seq, 4*emb]
|
||||
|
||||
gate_out = self._gate(x) # [batch, seq, 4*emb]
|
||||
activation_out = self._activation(gate_out) # [batch, seq, 4*emb]
|
||||
|
||||
# Up ветвь: линейное преобразование
|
||||
up_out = self._up(x) # [batch, seq, 4*emb]
|
||||
|
||||
up_out = self._up(x) # [batch, seq, 4*emb]
|
||||
|
||||
# Element-wise multiplication (gating mechanism)
|
||||
out = up_out * activation_out # поэлементное умножение!
|
||||
|
||||
out = up_out * activation_out # поэлементное умножение!
|
||||
|
||||
# Final projection and dropout
|
||||
out = self._down(out) # [batch, seq, emb]
|
||||
out = self._down(out) # [batch, seq, emb]
|
||||
return self._dropout(out)
|
||||
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
"""Строковое представление для отладки."""
|
||||
return f'emb_size={self._gate.in_features}, dropout={self._dropout.p}'
|
||||
return f"emb_size={self._gate.in_features}, dropout={self._dropout.p}"
|
||||
|
||||
@@ -2,68 +2,96 @@ import torch
|
||||
from torch import nn
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class TokenEmbeddings(nn.Module):
|
||||
"""
|
||||
Токеновые эмбеддинги — обучаемые векторные представления для каждого токена словаря.
|
||||
TokenEmbeddings — обучаемый слой эмбеддингов для токенов (слов, сабслов, байтов и т.д.) в трансформерах.
|
||||
|
||||
Преобразует целочисленные индексы токенов в обучаемые векторные представления фиксированного размера.
|
||||
Обычно используется как первый слой в нейронных сетях для задач NLP.
|
||||
|
||||
Научная суть:
|
||||
- Первый шаг для любого NLP-модуля: вместо индекса токена подаём его dense-вектор.
|
||||
- Эти вектора изучаются в процессе обучения и отражают скрытые взаимосвязи между токенами.
|
||||
- Позволяют обрабатывать тексты как матрицу чисел, а не как символы или индексы.
|
||||
- Аналог словарных эмбеддингов в word2vec, но обучаются энд-ту-энд с моделью.
|
||||
Назначение:
|
||||
-----------
|
||||
- Преобразует каждый целочисленный индекс-токен из словаря (vocab) в обучаемый dense-вектор фиксированной длины.
|
||||
- Это "входной слой" для любой нейросетевой языковой модели: позволяет работать с текстом как с матрицей чисел, а не с индексами/категориальными значениями.
|
||||
- Обеспечивает возможность end-to-end обучения embedding-матрицы совместно с целью модели.
|
||||
|
||||
Мотивация и особенности:
|
||||
------------------------
|
||||
- Каждый токен (индекс) получает свой learnable embedding (float-вектор).
|
||||
- Размерность слоя: [vocab_size, emb_size] (матрица эмбеддингов).
|
||||
- Веса эмбеддингов инициализируются случайно и обучаются вместе с остальной моделью.
|
||||
- Аналог таблицы эмбеддингов в word2vec/fastText, но управляется end-to-end.
|
||||
- Могут использоваться с любым токенизатором (BPE, SentencePiece, WordPiece и др.).
|
||||
|
||||
Формула:
|
||||
--------
|
||||
emb(x) = W[x], где W — матрица размера [vocab_size, emb_dim], x — индексы shape [batch, seq_len]
|
||||
На выходе: тензор [batch, seq_len, emb_dim]
|
||||
|
||||
Args:
|
||||
vocab_size (int): размер словаря (количество уникальных токенов)
|
||||
emb_size (int): размерность эмбеддинга (длина вектора)
|
||||
-----
|
||||
vocab_size: int — размер словаря/алфавита (количество уникальных токенов)
|
||||
emb_size: int — размерность (длина) эмбеддинговых векторов (обычно 256/512/1024...)
|
||||
|
||||
Примечание:
|
||||
- Индексы должны быть в диапазоне [0, vocab_size-1]
|
||||
- Эмбеддинги инициализируются случайно и обучаются в процессе тренировки модели
|
||||
|
||||
Пример:
|
||||
>>> emb = TokenEmbeddings(vocab_size=10000, emb_size=256)
|
||||
>>> tokens = torch.tensor([[1, 2, 3]])
|
||||
>>> vecs = emb(tokens)
|
||||
>>> vecs.shape # torch.Size([1, 3, 256])
|
||||
-------
|
||||
>>> embedding = TokenEmbeddings(vocab_size=5000, emb_size=256)
|
||||
>>> tokens = torch.tensor([[12, 47, 301], [6, 88, 413]])
|
||||
>>> vecs = embedding(tokens)
|
||||
>>> print(vecs.shape) # torch.Size([2, 3, 256])
|
||||
|
||||
References:
|
||||
-----------
|
||||
- Mikolov et al., "Efficient Estimation of Word Representations in Vector Space (word2vec)", 2013
|
||||
- Vaswani et al., "Attention is All You Need", 2017: https://arxiv.org/abs/1706.03762
|
||||
- BPE, SentencePiece overviews: https://huggingface.co/docs/transformers/tokenizer_summary
|
||||
"""
|
||||
|
||||
def __init__(self, vocab_size: int, emb_size: int):
|
||||
"""
|
||||
Инициализация слоя эмбеддингов.
|
||||
|
||||
Args:
|
||||
-----
|
||||
vocab_size: int
|
||||
Размер словаря (уникальных токенов/индексов).
|
||||
emb_size: int
|
||||
Длина эмбеддингового вектора для каждого токена.
|
||||
|
||||
Внутри:
|
||||
-------
|
||||
- Создаёт nn.Embedding с [vocab_size, emb_size] learnable весами.
|
||||
"""
|
||||
super().__init__()
|
||||
self._embedding = nn.Embedding(
|
||||
num_embeddings=vocab_size,
|
||||
embedding_dim=emb_size
|
||||
num_embeddings=vocab_size, embedding_dim=emb_size
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""
|
||||
Получить эмбеддинги для входных токенов.
|
||||
|
||||
Args:
|
||||
-----
|
||||
x : torch.Tensor
|
||||
Тензор shape [...], содержащий индексы токенов (каждое значение от 0 до vocab_size-1).
|
||||
|
||||
Returns:
|
||||
--------
|
||||
torch.Tensor — тензор обычной формы [..., emb_size] (на каждую позицию — свой embedding-вектор).
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> embedding = TokenEmbeddings(vocab_size=100, emb_size=64)
|
||||
>>> tokens = torch.tensor([[0, 99, 5]])
|
||||
>>> vecs = embedding(tokens) # [1, 3, 64]
|
||||
"""
|
||||
return self._embedding(x)
|
||||
|
||||
@property
|
||||
def num_embeddings(self) -> int:
|
||||
"""Возвращает размер словаря"""
|
||||
"""Возвращает размер словаря (количество уникальных токенов)."""
|
||||
return self._embedding.num_embeddings
|
||||
|
||||
@property
|
||||
def embedding_dim(self) -> int:
|
||||
"""Возвращает размерность эмбеддингов"""
|
||||
"""Возвращает размерность эмбеддингов (длина вектора каждого токена)."""
|
||||
return self._embedding.embedding_dim
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Пример использования
|
||||
embedding = TokenEmbeddings(vocab_size=100, emb_size=128)
|
||||
|
||||
# Создаем тензор с индексами в пределах vocab_size (0-99)
|
||||
tensor = torch.tensor([
|
||||
[11, 45, 76, 34],
|
||||
[34, 67, 45, 54]
|
||||
])
|
||||
|
||||
# Проверяем индексы
|
||||
if (tensor >= 100).any():
|
||||
raise ValueError("Some indices are out of vocabulary range (vocab_size=100)")
|
||||
|
||||
output = embedding(tensor)
|
||||
print("Embeddings shape:", output.shape)
|
||||
print(f"{output.shape} | {output.mean().item():.11f}") # Формат как в ТЗ
|
||||
0
llm/src/llm/datasets/__init__.py
Normal file
0
llm/src/llm/datasets/__init__.py
Normal file
120
llm/src/llm/datasets/streaming_text_dataset.py
Normal file
120
llm/src/llm/datasets/streaming_text_dataset.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from typing import List, Any
|
||||
|
||||
|
||||
class StreamingTextDataset(Dataset):
|
||||
"""
|
||||
StreamingTextDataset — потоковый датасет для LLM/NLP на базе списка строк.
|
||||
|
||||
Назначение:
|
||||
-----------
|
||||
- Позволяет эффективно обрабатывать большие текстовые выборки, итерируя по заранее подготовленному списку строк.
|
||||
- При итерации строки токенизируются на лету, превращаются в примеры фиксированной длины block_size (padding/truncation внутри класса).
|
||||
- Поддерживает стандартный DataLoader PyTorch.
|
||||
|
||||
Ключевые особенности:
|
||||
---------------------
|
||||
- Не требует загрузки всей коллекции токенов в RAM: поддерживает работу с любым размером датасета, если список строк заранее подготовлен.
|
||||
- Каждый пример (sample) формируется при обращении; не хранит массив батчей, не использует файлы внутри.
|
||||
- Поддерживает любой токенизатор с методом encode (например, BPE, SentencePiece, HF Tokenizer).
|
||||
- batch_size и параллелизм (num_workers) контролируются через DataLoader.
|
||||
|
||||
Аргументы конструктора:
|
||||
-----------------------
|
||||
texts: List[str] — список строк (предварительно загруженных обучающих примеров).
|
||||
tokenizer: BaseTokenizer/Any — объект с методом encode(str, **kwargs) -> List[int].
|
||||
block_size: int — длина одного выходного примера в токенах (padding/truncation если нужно).
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> texts = open("wiki_sample.txt", encoding="utf-8").read().splitlines()
|
||||
>>> ds = StreamingTextDataset(texts, tokenizer=tokenizer, block_size=512)
|
||||
>>> loader = torch.utils.data.DataLoader(ds, batch_size=8)
|
||||
>>> for batch in loader:
|
||||
... print(batch['input_ids'].shape) # torch.Size([8, 512])
|
||||
|
||||
Особенности:
|
||||
------------
|
||||
- Проектирован для бесконечного стриминга текстовых данных из больших коллекций.
|
||||
- При batch_size > 1 каждый batch формируется DataLoader-ом из yield'ов этого датасета.
|
||||
- Не работает с файлами напрямую, только со строками (списком).
|
||||
- Подходит для обучения LLM, тестирования, дообучения, оценки на больших потоковых данных.
|
||||
|
||||
References:
|
||||
-----------
|
||||
- PyTorch IterableDataset: https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset
|
||||
- HuggingFace streaming datasets: https://huggingface.co/docs/datasets/stream
|
||||
- Практика масштабного обучения LLM: https://github.com/karpathy/nanoGPT/issues/182
|
||||
"""
|
||||
|
||||
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128):
|
||||
"""
|
||||
Инициализация StreamingTextDataset из списка строк.
|
||||
|
||||
Аргументы:
|
||||
texts (List[str]): Список строк — текстовые обучающие примеры; весь датасет должен помещаться в этот список.
|
||||
tokenizer (Any): Токенизатор с методом encode(text, **kwargs) -> List[int].
|
||||
block_size (int, по умолчанию 128): Желаемая длина токенизированного примера (padding/truncation внутри класса).
|
||||
|
||||
Особенности:
|
||||
- Поддерживает итеративную загрузку, эффективен для больших текстовых выборок.
|
||||
- Каждый пример автоматически дополняется или усекается до block_size.
|
||||
- Не читает данные из файла/буфера, а только из заранее подготовленного списка строк.
|
||||
|
||||
Пример:
|
||||
>>> ds = StreamingTextDataset(texts=all_lines, tokenizer=tokenizer, block_size=256)
|
||||
>>> for ex in ds:
|
||||
... print(ex['input_ids'].shape) # torch.Size([256])
|
||||
"""
|
||||
self.texts = texts
|
||||
self.tokenizer = tokenizer
|
||||
self.block_size = block_size
|
||||
|
||||
# Получаем pad_token_id из токенизатора
|
||||
self.pad_token_id = getattr(tokenizer, "pad_token_id", 0)
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Возвращает количество доступных примеров в датасете.
|
||||
|
||||
Returns:
|
||||
int: Число примеров (равно длине исходного списка строк).
|
||||
"""
|
||||
return len(self.texts)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""
|
||||
Получить обработанный пример по индексу из потокового датасета.
|
||||
|
||||
Аргументы:
|
||||
idx (int): Индекс примера в исходном списке строк.
|
||||
|
||||
Возвращает:
|
||||
dict: Словарь с тензорами для обучения LLM:
|
||||
- 'input_ids': torch.Tensor формы [block_size] — индексы токенов (padding/truncation выполнены)
|
||||
- 'labels': torch.Tensor формы [block_size] — целевые метки (обычно совпадают с input_ids)
|
||||
|
||||
Пример:
|
||||
>>> item = dataset[10]
|
||||
>>> assert isinstance(item, dict)
|
||||
>>> assert item['input_ids'].shape == (block_size,)
|
||||
>>> assert 'labels' in item
|
||||
"""
|
||||
text = self.texts[idx]
|
||||
|
||||
# Токенизация на лету
|
||||
input_ids = self.tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
# Обрезаем или дополняем до нужной длины
|
||||
if len(input_ids) > self.block_size:
|
||||
input_ids = input_ids[: self.block_size]
|
||||
else:
|
||||
input_ids = input_ids + [self.pad_token_id] * (
|
||||
self.block_size - len(input_ids)
|
||||
)
|
||||
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
||||
labels = input_ids.clone()
|
||||
|
||||
return {"input_ids": input_ids, "labels": labels}
|
||||
112
llm/src/llm/datasets/text_dataset.py
Normal file
112
llm/src/llm/datasets/text_dataset.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from typing import List, Any
|
||||
|
||||
|
||||
class TextDataset(Dataset):
|
||||
"""
|
||||
TextDataset — простой датасет для подачи обучающих токенов LLM (batch-режим или по одному примеру).
|
||||
|
||||
Назначение:
|
||||
-----------
|
||||
- Хранит последовательности текста (каждую строку или пример) в виде списка строк.
|
||||
- При обращении сам токенизирует строку в последовательность индексов с помощью заданного токенизатора.
|
||||
- Каждый пример автоматически усекётся или будет дополнен до фиксированной длины block_size (padding — zeros).
|
||||
|
||||
Формат и аргументы конструктора:
|
||||
-------------------------------
|
||||
texts: List[str]
|
||||
Список строк, каждая из которых рассматривается как отдельный обучающий пример.
|
||||
tokenizer: любой объект с методом encode(str, **kwargs) → List[int]
|
||||
Обеспечивает сопоставление строки списку токенов (например, BPE, HuggingFace, SentencePiece и др.).
|
||||
block_size: int, по умолчанию 128
|
||||
Желаемая длина выходной последовательности (padding/truncation внутри класса).
|
||||
|
||||
Особенности:
|
||||
------------
|
||||
- Класс не работает с файлами напрямую: данные передаются готовым списком строк.
|
||||
- При недостаточной длине пример дополняется паддингом (нулём или другим токеном, зависит от реализации).
|
||||
- Может возвращать dict с input_ids, labels и прочими ключами (см. реализацию в функции __getitem__).
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> with open("dataset.txt", encoding="utf-8") as f:
|
||||
... texts = f.read().splitlines()
|
||||
>>> dataset = TextDataset(texts, tokenizer, block_size=256)
|
||||
>>> from torch.utils.data import DataLoader
|
||||
>>> loader = DataLoader(dataset, batch_size=4)
|
||||
>>> for item in loader:
|
||||
... # item['input_ids'] для обучения LLM
|
||||
|
||||
References:
|
||||
-----------
|
||||
- Torch Dataset: https://pytorch.org/docs/stable/data.html
|
||||
- Примеры LLM датасетов в open-source: https://github.com/karpathy/nanoGPT/blob/master/data/shakespeare_char/tokenize.py
|
||||
"""
|
||||
|
||||
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128):
|
||||
"""
|
||||
Инициализация датасета из списка строк.
|
||||
|
||||
Аргументы:
|
||||
texts (List[str]): Список строк — каждый элемент отдельный обучающий пример.
|
||||
tokenizer (Any): Токенизатор с методом encode(str, **kwargs) -> List[int].
|
||||
block_size (int, по умолчанию 128): Желаемая длина результата —
|
||||
длинные последовательности будут усечены, короткие — дополнены паддингом (pad_token_id или 0).
|
||||
|
||||
Особенности:
|
||||
- Строки не фильтруются и не изменяются внутри датасета.
|
||||
- Для PAD используется pad_token_id из токенизатора (если есть) либо 0.
|
||||
- Dict, возвращаемый __getitem__, содержит 'input_ids' и 'labels'.
|
||||
|
||||
Пример:
|
||||
>>> dataset = TextDataset(["hello world", "test string"], tokenizer, block_size=16)
|
||||
"""
|
||||
self.examples = []
|
||||
self.tokenizer = tokenizer
|
||||
self.block_size = block_size
|
||||
|
||||
for text in texts:
|
||||
# Кодируем текст в токены
|
||||
input_ids = tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
# Обрезаем или дополняем до нужной длины
|
||||
if len(input_ids) > block_size:
|
||||
input_ids = input_ids[:block_size]
|
||||
else:
|
||||
# Дополняем pad_token_id
|
||||
pad_token_id = getattr(tokenizer, "pad_token_id", 0)
|
||||
input_ids = input_ids + [pad_token_id] * (block_size - len(input_ids))
|
||||
|
||||
self.examples.append(input_ids)
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Возвращает количество примеров в датасете (длина списка текстов).
|
||||
|
||||
Returns:
|
||||
int: Число примеров в датасете.
|
||||
"""
|
||||
return len(self.examples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""
|
||||
Получить пример из датасета по индексу.
|
||||
|
||||
Аргументы:
|
||||
idx (int): Индекс примера.
|
||||
|
||||
Возвращает:
|
||||
dict: Словарь с тензорами токенов для модели:
|
||||
- 'input_ids': torch.Tensor shape [block_size], индексы токенов для входа.
|
||||
- 'labels': torch.Tensor shape [block_size], метки для LM задачи (обычно совпадают с input_ids).
|
||||
|
||||
Пример:
|
||||
>>> item = dataset[7]
|
||||
>>> assert isinstance(item, dict)
|
||||
>>> assert item['input_ids'].shape == (block_size,)
|
||||
>>> assert 'labels' in item
|
||||
"""
|
||||
input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
|
||||
labels = input_ids.clone()
|
||||
return {"input_ids": input_ids, "labels": labels}
|
||||
124
llm/src/llm/datasets/text_with_special_tokens_dataset.py
Normal file
124
llm/src/llm/datasets/text_with_special_tokens_dataset.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from typing import List, Any
|
||||
from llm.datasets.text_dataset import TextDataset
|
||||
|
||||
|
||||
class TextWithSpecialTokensDataset(TextDataset):
|
||||
"""
|
||||
TextWithSpecialTokensDataset — датасет для языковых моделей с поддержкой специальных токенов (BOS, EOS, PAD).
|
||||
|
||||
Назначение:
|
||||
-----------
|
||||
- Работает с уже готовым списком строк (не с файлом!).
|
||||
- Токенизирует строки с помощью заданного токенизатора, вручную вставляет специальные токены (BOS/ EOS/ PAD).
|
||||
- Обрезает или дополняет каждую последовательность до длины block_size.
|
||||
|
||||
Аргументы конструктора:
|
||||
-----------------------
|
||||
texts (List[str]): Список обучающих строк (примеров).
|
||||
tokenizer (Any): Любой токенизатор с методом encode(text, **kwargs).
|
||||
block_size (int, default=128): Желаемая длина примера (padding/truncation).
|
||||
add_bos (bool, default=False): Если True, добавляет BOS-токен в начало каждой последовательности.
|
||||
add_eos (bool, default=False): Если True, добавляет EOS-токен в конец.
|
||||
|
||||
Особенности:
|
||||
------------
|
||||
- Если pad_token_id не задан — по умолчанию паддит нулями.
|
||||
- Все returned примеры — dict с 'input_ids' и 'labels' (shape == block_size).
|
||||
- Обрезание/дополнение учётное: BOS/EOS не "выдавливаются" обрезкой.
|
||||
- Пример вызова:
|
||||
>>> texts = ["пример текста", "ещё текст"]
|
||||
>>> ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=16, add_bos=True, add_eos=True)
|
||||
>>> out = ds[0]
|
||||
>>> assert out['input_ids'].shape == (16,)
|
||||
|
||||
References:
|
||||
-----------
|
||||
- OpenAI GPT-2 data loader: https://github.com/openai/gpt-2/blob/master/src/encode.py
|
||||
- HuggingFace data docs: https://huggingface.co/docs/transformers/pad_truncation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
texts: List[str],
|
||||
tokenizer: Any,
|
||||
block_size: int = 128,
|
||||
add_bos: bool = False,
|
||||
add_eos: bool = False,
|
||||
):
|
||||
"""
|
||||
Инициализация датасета с поддержкой специальных токенов.
|
||||
|
||||
Args:
|
||||
texts (List[str]): Список строк (все ваши обучающие примеры).
|
||||
tokenizer (Any): Токенизатор с методом encode(text, **kwargs).
|
||||
block_size (int): Длина выходного примера.
|
||||
add_bos (bool): Добавлять ли BOS токен в начало.
|
||||
add_eos (bool): Добавлять ли EOS токен в конец.
|
||||
"""
|
||||
self.examples = []
|
||||
self.tokenizer = tokenizer
|
||||
self.block_size = block_size
|
||||
self.add_bos = add_bos
|
||||
self.add_eos = add_eos
|
||||
|
||||
for text in texts:
|
||||
# Кодируем с специальными токенами
|
||||
input_ids = tokenizer.encode(
|
||||
text, add_special_tokens=True, add_bos_token=add_bos, add_eos_token=add_eos
|
||||
)
|
||||
|
||||
# Учитываем специальные токены при обрезке/дополнении
|
||||
effective_block_size = block_size
|
||||
if add_bos:
|
||||
effective_block_size -= 1
|
||||
if add_eos:
|
||||
effective_block_size -= 1
|
||||
|
||||
if len(input_ids) > effective_block_size:
|
||||
input_ids = input_ids[:effective_block_size]
|
||||
|
||||
# Добавляем специальные токены если нужно
|
||||
if (
|
||||
add_bos
|
||||
and hasattr(tokenizer, "bos_token_id")
|
||||
and tokenizer.bos_token_id is not None
|
||||
):
|
||||
input_ids = [tokenizer.bos_token_id] + input_ids
|
||||
if (
|
||||
add_eos
|
||||
and hasattr(tokenizer, "eos_token_id")
|
||||
and tokenizer.eos_token_id is not None
|
||||
):
|
||||
input_ids = input_ids + [tokenizer.eos_token_id]
|
||||
|
||||
# Дополняем до полной длины
|
||||
pad_token_id = getattr(tokenizer, "pad_token_id", 0)
|
||||
if len(input_ids) < block_size:
|
||||
input_ids = input_ids + [pad_token_id] * (block_size - len(input_ids))
|
||||
|
||||
self.examples.append(input_ids)
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Возвращает количество примеров в датасете.
|
||||
|
||||
Returns:
|
||||
int: Размер (len(self.examples)).
|
||||
"""
|
||||
return len(self.examples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""
|
||||
Получить пример с учётом специальных токенов и паддинга.
|
||||
|
||||
Args:
|
||||
idx (int): Индекс в dataset.
|
||||
|
||||
Returns:
|
||||
dict: {'input_ids': torch.Tensor [block_size], 'labels': torch.Tensor [block_size]}
|
||||
"""
|
||||
input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
|
||||
labels = input_ids.clone()
|
||||
return {"input_ids": input_ids, "labels": labels}
|
||||
3
llm/src/llm/models/gemma/__init__.py
Normal file
3
llm/src/llm/models/gemma/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .gemma import Gemma
|
||||
|
||||
__all__ = ["Gemma"]
|
||||
346
llm/src/llm/models/gemma/gemma.py
Normal file
346
llm/src/llm/models/gemma/gemma.py
Normal file
@@ -0,0 +1,346 @@
|
||||
import torch
|
||||
import math
|
||||
from torch import nn
|
||||
from torch import Tensor
|
||||
import torch.nn.functional as F
|
||||
from math import sqrt
|
||||
from llm.core.base_model import BaseModel
|
||||
from llm.core.token_embeddings import TokenEmbeddings
|
||||
from llm.core.rope import RoPE
|
||||
from llm.core.rms_norm import RMSNorm
|
||||
from llm.core.gemma_decoder import GemmaDecoder
|
||||
|
||||
|
||||
class Gemma(BaseModel):
|
||||
"""
|
||||
Gemma — языковая трансформер-модель от Google, с архитектурой, оптимизированной для open-source и research-комьюнити.
|
||||
|
||||
Назначение:
|
||||
-----------
|
||||
Модель Gemma реализует стек современных декодерных блоков (GemmaDecoder), поддерживает rotary-позиционирование, multi-query self-attention,
|
||||
эффективный режим генерации (KV-cache), dropout, compact residual connections, базируется на best-practice LLM-инженерии последних лет.
|
||||
Поддерживает batched-тренировку и inference, генерацию с различными стратегиями выборки (greedy, top-k, top-p), автосохранение.
|
||||
|
||||
Архитектурные особенности:
|
||||
--------------------------
|
||||
- Stack из N слоёв GemmaDecoder (attention с Multi-Query либо Grouped heads, FFN с GeGLU/SwiGLU)
|
||||
- RMSNorm или LayerNorm для стабилизации
|
||||
- Dropout для регуляризации
|
||||
- Rotary Position Embedding (RoPE) для позиционных кодов
|
||||
- Выходная проекция (linear → logits) к словарю токенов
|
||||
- Полная поддержка cache для ускорения autoregressive генерации
|
||||
|
||||
Конфиг/Параметры конструктора:
|
||||
------------------------------
|
||||
config : dict
|
||||
Словарь c параметрами модели:
|
||||
- vocab_size : int — размер словаря
|
||||
- embed_dim : int — размер скрытого (hidden) пространства
|
||||
- max_position_embeddings : int — максимальная длина последовательности
|
||||
- num_layers : int — количество декодерных блоков
|
||||
- num_q_heads : int — количество attention голов (Queries)
|
||||
- num_kv_heads : int — количество ключевых/значенческих attention голов
|
||||
- dropout : float — Dropout率
|
||||
- ... (доп. гиперпараметры, требуемые GemmaDecoder'ами)
|
||||
|
||||
Основные методы:
|
||||
----------------
|
||||
- forward(x, use_cache=True, cache=None): выдает батч логитов по токенам, возвращает при необходимости обновленный cache.
|
||||
- generate(...): автотекстогенерация с greedy, temperature, top-k/p sampling, поддержкой кэша (ускорение inference).
|
||||
- save(path)/load(path, device): сохранение и загрузка предобученных весов, параметров и состояния.
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> config = {...} # словарь с параметрами
|
||||
>>> model = Gemma(config)
|
||||
>>> x = torch.randint(0, config["vocab_size"], (4, 64))
|
||||
>>> logits, cache = model(x, use_cache=True)
|
||||
>>> print(logits.shape) # [4, 64, vocab_size]
|
||||
>>> out = model.generate(x, max_new_tokens=20, do_sample=True, top_k=10, temperature=0.8)
|
||||
|
||||
Литература и ссылки:
|
||||
--------------------
|
||||
- Gemma: https://ai.google.dev/gemma (официальная страница)
|
||||
- Разработка и архитектура: https://arxiv.org/abs/2403.07794
|
||||
- Rotary Embedding: https://arxiv.org/abs/2104.09864
|
||||
- Multi-Query Attention: https://arxiv.org/abs/1911.02150
|
||||
- Llama: https://arxiv.org/abs/2302.13971
|
||||
"""
|
||||
def __init__(self, config):
|
||||
"""
|
||||
Конструктор класса Gemma.
|
||||
|
||||
Позволяет создать объект языковой модели с архитектурой Gemma и
|
||||
произвольной конфигурацией (гибкая поддержка разных масштабов, ширин, глубин).
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
config : dict
|
||||
Словарь со всеми необходимыми гиперпараметрами и архитектурными детальями модели Gemma.
|
||||
Ожидаемые ключи (группы параметров):
|
||||
- vocab_size : int — размер словаря токенов (размерность входа/выхода)
|
||||
- embed_dim : int — скрытый размер эмбеддинга (hidden dim)
|
||||
- max_position_embeddings : int — максимальная длина последовательности
|
||||
- num_layers : int — количество декодерных блоков (глубина стека)
|
||||
- num_q_heads : int — число attention голов (Query heads)
|
||||
- num_kv_heads : int — число голов для Key/Value (MultiQuery Attention)
|
||||
- dropout : float — Dropout для регуляризации
|
||||
- остальные специфичные для GemmaDecoder'ов параметры
|
||||
|
||||
Внутри:
|
||||
-------
|
||||
- Инициализируются модули эмбеддинга токенов, позиционного кодирования (RoPE) и Dropout,
|
||||
стек декодеров (GemmaDecoder(...)), слой финальной нормализации и выходная проекция (linear).
|
||||
- Все архитектурные параметры напрямую берутся из config.
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> config = {
|
||||
... "vocab_size": 32000,
|
||||
... "embed_dim": 512,
|
||||
... "max_position_embeddings": 2048,
|
||||
... "num_layers": 24,
|
||||
... "num_q_heads": 8,
|
||||
... "num_kv_heads": 4,
|
||||
... "dropout": 0.1,
|
||||
... }
|
||||
>>> model = Gemma(config)
|
||||
|
||||
Примечание:
|
||||
-----------
|
||||
- Внимание: значения config должны быть согласованы друг с другом! Например, embed_dim должен быть кратным num_q_heads и т.д.
|
||||
- Поддерживается дальнейшая кастомизация стека декодеров через ключи в config.
|
||||
"""
|
||||
super().__init__(config)
|
||||
|
||||
self._max_seq_len = config["max_position_embeddings"]
|
||||
|
||||
# Инициализация слоев
|
||||
self._token_embeddings = TokenEmbeddings(
|
||||
vocab_size=config["vocab_size"],
|
||||
emb_size=config["embed_dim"]
|
||||
)
|
||||
self._position_embeddings = RoPE(
|
||||
head_size=config["embed_dim"] // config["num_q_heads"],
|
||||
max_seq_len=config["max_position_embeddings"]
|
||||
)
|
||||
#self._position_embeddings = PositionalEmbeddings(
|
||||
# max_seq_len=max_seq_len,
|
||||
# emb_size=emb_size
|
||||
#)
|
||||
self._dropout = nn.Dropout(config["dropout"])
|
||||
self._decoders = nn.ModuleList([GemmaDecoder(
|
||||
num_q_heads=config["num_q_heads"],
|
||||
emb_size=config["embed_dim"],
|
||||
head_size=config["embed_dim"] // config["num_q_heads"],
|
||||
max_seq_len=config["max_position_embeddings"],
|
||||
rope=self._position_embeddings,
|
||||
dropout=config["dropout"]
|
||||
) for _ in range(config["num_layers"])])
|
||||
self._norm = RMSNorm(config["embed_dim"])
|
||||
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
|
||||
|
||||
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
|
||||
"""
|
||||
Прямой проход (forward) через полную модель Gemma.
|
||||
|
||||
Трансформирует входную последовательность токенов через стек из декодерных блоков GemmaDecoder.
|
||||
Возвращает логиты по всем токенам и (при необходимости) кэш attention для быстрой autoregressive-генерации.
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Входной тензор shape [batch_size, seq_len], содержащий токен-IDs.
|
||||
use_cache : bool, по умолчанию True
|
||||
Если True — сохраняет и возвращает KV-кэш attention (ускоряет автогенерацию).
|
||||
Если False — кэш не используется.
|
||||
cache : list, optional
|
||||
(Необязательно) Список/None: с кэшами KV-матриц для каждого слоя (для режима генерации статей/диalogов).
|
||||
|
||||
Возвращает:
|
||||
-----------
|
||||
tuple:
|
||||
- logits : torch.Tensor shape [batch_size, seq_len, vocab_size]
|
||||
Логиты по словарю для каждого токена (input + сколь угодно новых).
|
||||
- new_cache : list или None
|
||||
Обновлённый cache (если use_cache=True).
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> logits, new_cache = model(x, use_cache=True, cache=None)
|
||||
>>> logits.shape # [batch_size, seq_len, vocab_size]
|
||||
|
||||
Примечания:
|
||||
-----------
|
||||
- Используется при обучении и инференсе.
|
||||
- Если нужно только инференс last-token — используйте logits[:, -1, :].
|
||||
- При превышении x.shape[1] > max_seq_len выдаёт ValueError.
|
||||
"""
|
||||
# Проверка длины последовательности (только при отсутствии кэша)
|
||||
if cache is None and x.size(1) > self._max_seq_len:
|
||||
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
|
||||
|
||||
# Эмбеддинги токенов и позиций
|
||||
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
|
||||
#pos_out = self._position_embeddings(x) # [batch, seq_len, emb_size]
|
||||
|
||||
# Комбинирование
|
||||
out = self._dropout(tok_out) # [batch, seq_len, emb_size]
|
||||
|
||||
# Стек декодеров с передачей кэша
|
||||
new_cache = []
|
||||
for i, decoder in enumerate(self._decoders):
|
||||
decoder_cache = cache[i] if cache is not None else None
|
||||
decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)
|
||||
|
||||
# Извлекаем результат из кортежа
|
||||
if use_cache:
|
||||
out, decoder_new_cache = decoder_result
|
||||
new_cache.append(decoder_new_cache)
|
||||
else:
|
||||
out = decoder_result[0]
|
||||
|
||||
out = self._norm(out)
|
||||
logits = self._linear(out)
|
||||
|
||||
# Возвращаем результат с учетом use_cache
|
||||
if use_cache:
|
||||
return (logits, new_cache)
|
||||
else:
|
||||
return (logits, None)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
max_new_tokens: int,
|
||||
do_sample: bool,
|
||||
temperature: float = 1.0,
|
||||
top_k: int = None,
|
||||
top_p: float = None,
|
||||
use_cache: bool = True,
|
||||
attention_mask: torch.Tensor = None,
|
||||
**kwargs
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Авторегрессивная генерация токенов с использованием greedy, temperature, top-k и top-p sampling.
|
||||
Реализует generation-loop с обновлением attention-кэша для ускорения инференса.
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Входной тензор с последовательностью токенов (shape [batch_size, seq_len]), который необходимо продолжить.
|
||||
max_new_tokens : int
|
||||
Сколько новых токенов сгенерировать (максимум).
|
||||
do_sample : bool
|
||||
Если True — сэмплирует следующий токен согласно распределению вероятностей (stochastic), иначе выбирает токен с максимальной вероятностью (greedy).
|
||||
temperature : float, default=1.0
|
||||
Параметр для шкалирования распределения вероятностей логитов. Больше 1.0 — больше случайности, меньше 1.0 — более детерминированный (жёсткий) выбор.
|
||||
top_k : int, optional
|
||||
Если задано — для сэмплирования учитываются только top_k наиболее вероятных токенов.
|
||||
top_p : float, optional
|
||||
Если задано — работают nucleus sampling: учитываются токены, суммарная вероятность которых не превышает top_p.
|
||||
use_cache : bool, default=True
|
||||
Если True — для ускорения использует и обновляет attention-кэши (KV-cache).
|
||||
|
||||
Возвращает:
|
||||
-----------
|
||||
torch.Tensor
|
||||
Тензор shape [batch_size, seq_len + max_new_tokens] с исходными и сгенерированными токенами (token IDs).
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> out = model.generate(
|
||||
... x, max_new_tokens=20, do_sample=True, temperature=0.8, top_k=50
|
||||
... )
|
||||
>>> print(out.shape) # [batch_size, seq_len+20]
|
||||
|
||||
Примечания:
|
||||
-----------
|
||||
- Нельзя указывать одновременно top_k и top_p (будет выброшено исключение).
|
||||
- temperature <= 0 некорректно (будет выброшено исключение).
|
||||
- Поддержка cache (use_cache=True) значительно ускоряет генерацию длинных последовательностей и позволяет использовать beam search/decoding.
|
||||
- Для воспроизводимых результатов установите torch.manual_seed перед генерацией.
|
||||
- Метод возвращает только token_ids, если нужны logits — используйте .forward напрямую.
|
||||
|
||||
Литература:
|
||||
-----------
|
||||
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus/top-p sampling): https://arxiv.org/abs/1904.09751
|
||||
- Gemma: https://arxiv.org/abs/2403.07794
|
||||
"""
|
||||
|
||||
cache = None
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
if use_cache and cache is not None:
|
||||
# Используем кэш - передаем только последний токен
|
||||
x_input = x[:, -1:] # [batch_size, 1]
|
||||
else:
|
||||
# Первая итерация или кэш отключен - передаем всю последовательность
|
||||
x_input = x
|
||||
|
||||
# Прямой проход с кэшем
|
||||
logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)
|
||||
|
||||
# Обновляем кэш для следующей итерации
|
||||
if use_cache:
|
||||
cache = new_cache
|
||||
|
||||
last_logits = logits[:, -1, :] # [batch_size, vocab_size]
|
||||
|
||||
# Масштабируем логиты температурой
|
||||
if temperature > 0:
|
||||
logits_scaled = last_logits / temperature
|
||||
else:
|
||||
logits_scaled = last_logits
|
||||
|
||||
if do_sample == True and top_k != None:
|
||||
_, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)
|
||||
|
||||
# # Заменим все НЕ top-k логиты на -inf
|
||||
masked_logits = logits_scaled.clone()
|
||||
vocab_size = logits_scaled.size(-1)
|
||||
|
||||
# создаём маску: 1, если токен НЕ в topk_indices
|
||||
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||
mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы
|
||||
masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf')
|
||||
|
||||
logits_scaled = masked_logits
|
||||
|
||||
if do_sample == True and top_p != None:
|
||||
# 1. Применим softmax, чтобы получить вероятности:
|
||||
probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]
|
||||
# 2. Отсортируем токены по убыванию вероятностей:
|
||||
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
|
||||
# 3. Посчитаем кумулятивную сумму вероятностей:
|
||||
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
|
||||
# 4. Определим маску: оставить токены, пока сумма < top_p
|
||||
sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size]
|
||||
# Гарантируем, что хотя бы первый токен останется
|
||||
sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1
|
||||
# 5. Преобразуем маску обратно в оригинальный порядок:
|
||||
# Создаём полную маску из 0
|
||||
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||
# Устанавливаем 1 в местах нужных токенов
|
||||
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
||||
# 6. Зануляем логиты токенов вне топ-p:
|
||||
logits_scaled[~mask] = float('-inf')
|
||||
|
||||
# 4. Применяем Softmax
|
||||
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
|
||||
|
||||
|
||||
if do_sample == True:
|
||||
# 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial
|
||||
next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]
|
||||
else:
|
||||
# 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью
|
||||
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
|
||||
|
||||
# 6. Добавляем его к последовательности
|
||||
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
|
||||
return x
|
||||
|
||||
@property
|
||||
def max_seq_len(self) -> int:
|
||||
return self._max_seq_len
|
||||
@@ -26,212 +26,272 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, Dict
|
||||
from llm.core.base_model import BaseModel
|
||||
from llm.core.decoder import Decoder
|
||||
from llm.core.gpt_decoder import GptDecoder
|
||||
from llm.core.token_embeddings import TokenEmbeddings
|
||||
from llm.core.positional_embeddings import PositionalEmbeddings
|
||||
|
||||
|
||||
class GPT(BaseModel):
|
||||
"""
|
||||
Original GPT (Generative Pre-trained Transformer) модель.
|
||||
|
||||
Первая версия трансформерной архитектуры от OpenAI, предназначенная
|
||||
для генеративного предобучения на текстовых данных.
|
||||
|
||||
Args:
|
||||
config: Словарь конфигурации с параметрами:
|
||||
- vocab_size: Размер словаря токенов
|
||||
- embed_dim: Размерность векторных представлений
|
||||
- num_heads: Количество голов внимания
|
||||
- num_layers: Количество декодерных слоев
|
||||
- max_position_embeddings: Максимальная длина последовательности
|
||||
- dropout: Вероятность dropout
|
||||
|
||||
Attributes:
|
||||
_token_embeddings: Слой векторных представлений токенов
|
||||
_position_embeddings: Слой позиционных эмбеддингов
|
||||
_decoders: Список декодерных слоев
|
||||
_norm: Финальный слой нормализации
|
||||
_linear: Выходной линейный слой
|
||||
GPT (Generative Pretrained Transformer) — автогерессивная языковая модель по мотивам оригинального GPT/GPT-2 architecture.
|
||||
|
||||
Назначение:
|
||||
-----------
|
||||
- Позволяет предсказывать и генерировать последовательности текста, обучаясь на задаче language modeling (предсказывать следующий токен).
|
||||
- Класс реализует архитектуру classic Transformer Decoder Stack с masked multi-head attention и token/positional embeddings.
|
||||
- Используется как базовая модель для генерации, zero-/few-shot, задач обучения с подкреплением и пр.
|
||||
|
||||
Архитектурные особенности:
|
||||
--------------------------
|
||||
- Embedding-слои для токенов (token_embeddings) и позиций (position_embeddings).
|
||||
- Stack из N декодер-блоков (MultiHeadAttention + FeedForward + residual + LayerNorm).
|
||||
- Masked self-attention — каждый токен видит только свои и предыдущие, обеспечивая автогерессию.
|
||||
- LayerNorm до проекции на словарь (pre-LN).
|
||||
- Поддержка efficient KV кэша — ускоряет autoregressive inference/generation.
|
||||
|
||||
Основные параметры:
|
||||
-------------------
|
||||
config: dict в формате {
|
||||
vocab_size, # размер словаря токенов
|
||||
embed_dim, # размерность эмбеддинга
|
||||
num_heads, # количество attention heads
|
||||
num_layers, # глубина модели (число блоков)
|
||||
max_position_embeddings,
|
||||
dropout
|
||||
}
|
||||
|
||||
Формула и поток данных:
|
||||
-----------------------
|
||||
x -> token_embeddings -> + position_embeddings -> dropout ->
|
||||
-> stack([DecoderBlock]) ->
|
||||
-> LayerNorm ->
|
||||
-> Linear(out_dim=vocab_size) -> output_logits
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> gpt = GPT({...})
|
||||
>>> tokens = torch.tensor([[12, 123, 44]])
|
||||
>>> logits = gpt(tokens)
|
||||
>>> generated = gpt.generate(tokens, max_new_tokens=10)
|
||||
|
||||
References:
|
||||
-----------
|
||||
- Radford et al., "Improving Language Understanding by Generative Pre-Training" (GPT-1, 2018)
|
||||
https://cdn.openai.com/research-covers/languageunsupervised/language_understanding_paper.pdf
|
||||
- Original BPE Tokenizer code: https://github.com/openai/gpt-2/blob/master/src/encoder.py
|
||||
- Формула masked self-attention: Vaswani et al., "Attention is All You Need", 2017
|
||||
https://arxiv.org/abs/1706.03762
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
"""
|
||||
Инициализация модели GPT.
|
||||
|
||||
Args:
|
||||
-----
|
||||
config: dict
|
||||
Параметры архитектуры:
|
||||
vocab_size: int — размер словаря токенов
|
||||
embed_dim: int — размерность эмбеддинга
|
||||
num_heads: int — количество attention-heads
|
||||
num_layers: int — число Transformer блоков
|
||||
max_position_embeddings: int — макс. длина последовательности
|
||||
dropout: float — dropout
|
||||
|
||||
Внутри:
|
||||
-------
|
||||
- Создаёт слой эмбеддингов, позиционку, стек декодеров, нормализацию, линейную проекцию.
|
||||
"""
|
||||
super().__init__(config)
|
||||
|
||||
# Инициализация слоев
|
||||
self._max_seq_len = config["max_position_embeddings"]
|
||||
self._token_embeddings = TokenEmbeddings(
|
||||
vocab_size=config["vocab_size"],
|
||||
emb_size=config["embed_dim"]
|
||||
vocab_size=config["vocab_size"], emb_size=config["embed_dim"]
|
||||
)
|
||||
self._position_embeddings = PositionalEmbeddings(
|
||||
max_seq_len=config["max_position_embeddings"],
|
||||
emb_size=config["embed_dim"]
|
||||
max_seq_len=config["max_position_embeddings"], emb_size=config["embed_dim"]
|
||||
)
|
||||
self._dropout = nn.Dropout(config["dropout"])
|
||||
# head_size = emb_size // num_heads
|
||||
self._decoders = nn.ModuleList([Decoder(
|
||||
num_heads=config["num_heads"],
|
||||
emb_size=config["embed_dim"],
|
||||
head_size=config["embed_dim"] // config["num_heads"],
|
||||
max_seq_len=config["max_position_embeddings"],
|
||||
dropout=config["dropout"]
|
||||
) for _ in range(config["num_layers"])])
|
||||
self._decoders = nn.ModuleList(
|
||||
[
|
||||
GptDecoder(
|
||||
num_heads=config["num_heads"],
|
||||
emb_size=config["embed_dim"],
|
||||
head_size=config["embed_dim"] // config["num_heads"],
|
||||
max_seq_len=config["max_position_embeddings"],
|
||||
dropout=config["dropout"],
|
||||
)
|
||||
for _ in range(config["num_layers"])
|
||||
]
|
||||
)
|
||||
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
|
||||
|
||||
|
||||
@property
|
||||
def max_seq_len(self):
|
||||
"""Возвращает максимальную длину последовательности."""
|
||||
return self._max_seq_len
|
||||
|
||||
def forward(self, x: torch.Tensor, attention_mask=None) -> torch.Tensor:
|
||||
"""Прямой проход через GPT
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, attention_mask=None, use_cache: bool = True, cache: list = None
|
||||
) -> tuple:
|
||||
"""
|
||||
Прямой проход для получения логитов по последовательности токенов.
|
||||
|
||||
Args:
|
||||
x: Входной тензор [batch_size, seq_len]
|
||||
|
||||
-----
|
||||
x : torch.Tensor [batch, seq_len]
|
||||
Индексы входных токенов.
|
||||
use_cache : bool, optional
|
||||
Использовать ли кэш attention (ускоряет инференс, важно для генерации)
|
||||
cache : list, optional
|
||||
Список старых KV (key/value)-кэшей
|
||||
|
||||
Returns:
|
||||
Тензор логитов [batch_size, seq_len, vocab_size]
|
||||
--------
|
||||
logits: [batch, seq_len, vocab_size] (логиты для softmax по словарю)
|
||||
new_cache: кэш KV после прохода
|
||||
"""
|
||||
# Проверка длины последовательности
|
||||
if x.size(1) > self._max_seq_len:
|
||||
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self._max_seq_len}")
|
||||
|
||||
raise ValueError(
|
||||
f"Длина последовательности {x.size(1)} превышает максимальную {self._max_seq_len}"
|
||||
)
|
||||
|
||||
# Вычисление start_pos из кэша (если кэш передан)
|
||||
if cache is not None:
|
||||
seq_len = 1
|
||||
# Безопасно извлекаем key_cache для вычисления start_pos
|
||||
if (
|
||||
isinstance(cache, (list, tuple))
|
||||
and len(cache) > 0
|
||||
and cache[0] is not None
|
||||
and isinstance(cache[0], (list, tuple))
|
||||
and len(cache[0]) > 0
|
||||
and cache[0][0] is not None
|
||||
and isinstance(cache[0][0], (tuple, list))
|
||||
and len(cache[0][0]) > 0
|
||||
):
|
||||
key_cache, _ = cache[0][0]
|
||||
start_pos = key_cache.size(1)
|
||||
else:
|
||||
start_pos = 0
|
||||
else:
|
||||
# Без кэша работаем как раньше
|
||||
start_pos = 0
|
||||
seq_len = x.size(1)
|
||||
|
||||
# Эмбеддинги токенов и позиций
|
||||
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
|
||||
pos_out = self._position_embeddings(x.size(1)) # [seq_len, emb_size]
|
||||
|
||||
pos_out = self._position_embeddings(
|
||||
seq_len, start_pos=start_pos
|
||||
) # [seq_len, emb_size]
|
||||
|
||||
# Комбинирование
|
||||
out = self._dropout(tok_out + pos_out.unsqueeze(0)) # [batch, seq_len, emb_size]
|
||||
|
||||
# Стек декодеров
|
||||
for decoder in self._decoders:
|
||||
out = decoder(out)
|
||||
|
||||
return self._linear(out) # [batch, seq_len, vocab_size]
|
||||
out = self._dropout(
|
||||
tok_out + pos_out.unsqueeze(0)
|
||||
) # [batch, seq_len, emb_size]
|
||||
|
||||
# Стек декодеров с передачей кэша
|
||||
new_cache = []
|
||||
for i, decoder in enumerate(self._decoders):
|
||||
decoder_cache = cache[i] if cache is not None else None
|
||||
decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)
|
||||
|
||||
# def forward(self, input_ids, attention_mask=None):
|
||||
# B, T = input_ids.size()
|
||||
# pos = torch.arange(0, T, device=input_ids.device).unsqueeze(0)
|
||||
#
|
||||
# x = self.token_emb(input_ids) + self.pos_emb(pos)
|
||||
#
|
||||
# for block in self.blocks:
|
||||
# x = block(x, attention_mask)
|
||||
#
|
||||
# x = self.ln_f(x)
|
||||
# logits = self.head(x)
|
||||
# return logits
|
||||
# Извлекаем результат из кортежа
|
||||
if use_cache:
|
||||
out, decoder_new_cache = decoder_result
|
||||
new_cache.append(decoder_new_cache)
|
||||
else:
|
||||
out = decoder_result[0]
|
||||
|
||||
logits = self._linear(out) # [batch, seq_len, vocab_size]
|
||||
|
||||
def generate(self,
|
||||
x: torch.Tensor,
|
||||
max_new_tokens: int,
|
||||
# Возвращаем результат с учетом use_cache
|
||||
if use_cache:
|
||||
return (logits, new_cache)
|
||||
else:
|
||||
return (logits, None)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
max_new_tokens: int,
|
||||
do_sample: bool,
|
||||
temperature: float = 1.0,
|
||||
top_k: int = None,
|
||||
top_p: float = None,
|
||||
attention_mask: torch.Tensor = None, # Добавляем для совместимости с HF
|
||||
**kwargs # Игнорируем остальные параметры
|
||||
use_cache: bool = True,
|
||||
attention_mask: torch.Tensor = None,
|
||||
**kwargs
|
||||
) -> torch.Tensor:
|
||||
"""Авторегрессивная генерация текста.
|
||||
|
||||
Параметры:
|
||||
x: Входной тензор с индексами токенов формы [batch_size, seq_len],
|
||||
где batch_size - размер батча, seq_len - длина последовательности.
|
||||
max_new_tokens: Максимальное количество новых токенов для генерации.
|
||||
do_sample: Флаг выбора режима генерации:
|
||||
- True: вероятностное сэмплирование
|
||||
- False: жадный поиск (argmax)
|
||||
temperature: Параметр температуры для сэмплирования:
|
||||
- >1.0 - более случайные результаты
|
||||
- 1.0 - нейтральное значение
|
||||
- <1.0 - более предсказуемые результаты
|
||||
Должна быть > 0 (по умолчанию: 1.0)
|
||||
top_k: Если задан (и do_sample=True), используется top-k сэмплирование:
|
||||
- Выбираются только top_k самых вероятных токенов
|
||||
- Остальным токенам устанавливается вероятность 0
|
||||
- None: отключено (по умолчанию)
|
||||
top_p: Если задан (и do_sample=True), используется nucleus (top-p) сэмплирование:
|
||||
- Выбираются токены с кумулятивной вероятностью ≤ top_p
|
||||
- Гарантируется, что хотя бы один токен остаётся (даже если его вероятность > top_p)
|
||||
- None: отключено (по умолчанию)
|
||||
- Должен быть в диапазоне (0, 1]
|
||||
|
||||
Возвращает:
|
||||
torch.Tensor: Тензор с расширенной последовательностью токенов формы
|
||||
[batch_size, seq_len + max_new_tokens]
|
||||
|
||||
Исключения:
|
||||
ValueError: Если входная последовательность длиннее max_seq_len
|
||||
ValueError: Если temperature <= 0
|
||||
ValueError: Если одновременно заданы top_k и top_p
|
||||
ValueError: Если top_k задан и ≤ 0
|
||||
ValueError: Если top_p задан и не в диапазоне (0, 1]
|
||||
|
||||
Примеры:
|
||||
>>> # Жадная генерация
|
||||
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=False)
|
||||
>>>
|
||||
>>> # Вероятностная генерация с top-k
|
||||
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=True, top_k=50)
|
||||
>>>
|
||||
>>> # Nucleus sampling (top-p)
|
||||
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=True, top_p=0.9)
|
||||
>>>
|
||||
>>> # Комбинация температуры и top-k
|
||||
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=True,
|
||||
... temperature=0.7, top_k=50)
|
||||
|
||||
Примечания:
|
||||
1. Для детерминированных результатов в режиме сэмплирования
|
||||
зафиксируйте random seed (torch.manual_seed).
|
||||
2. Температура влияет только на режим сэмплирования (do_sample=True).
|
||||
3. Одновременное использование top_k и top_p запрещено.
|
||||
4. При do_sample=False параметры top_k, top_p и temperature игнорируются.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Входной тензор с индексами токенов формы [batch_size, seq_len],
|
||||
где batch_size - размер батча, seq_len - длина последовательности.
|
||||
max_new_tokens (int): Максимальное количество новых токенов для генерации.
|
||||
do_sample (bool): Флаг выбора режима генерации:
|
||||
- True: вероятностное сэмплирование
|
||||
- False: жадный поиск (argmax)
|
||||
temperature (float): Параметр температуры для сэмплирования:
|
||||
- >1.0 - более случайные результаты
|
||||
- 1.0 - нейтральное значение
|
||||
- <1.0 - более предсказуемые результаты
|
||||
Должна быть > 0 (по умолчанию: 1.0)
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Тензор с расширенной последовательностью токенов формы
|
||||
[batch_size, seq_len + max_new_tokens]
|
||||
|
||||
Raises:
|
||||
ValueError: Если входная последовательность длиннее max_seq_len
|
||||
ValueError: Если temperature <= 0
|
||||
|
||||
Examples:
|
||||
>>> # Жадная генерация
|
||||
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=False)
|
||||
>>>
|
||||
>>> # Вероятностная генерация с температурой
|
||||
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=True, temperature=0.7)
|
||||
>>>
|
||||
>>> # Более случайная генерация
|
||||
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=True, temperature=1.5)
|
||||
|
||||
Note:
|
||||
Для детерминированных результатов в режиме сэмплирования
|
||||
зафиксируйте random seed (torch.manual_seed).
|
||||
Температура влияет только на режим сэмплирования (do_sample=True).
|
||||
"""
|
||||
Авторегрессивная генерация текста с поддержкой жадного поиска (greedy), вероятностного сэмплирования с температурой,
|
||||
top-k и nucleus (top-p) sampling.
|
||||
|
||||
Аргументы:
|
||||
x (torch.Tensor): Входной тензор с индексами токенов, форма [batch_size, seq_len].
|
||||
max_new_tokens (int): Максимальное количество новых токенов для генерации.
|
||||
do_sample (bool): Если True — вероятностное сэмплирование; если False — жадная генерация (argmax).
|
||||
temperature (float): Температура для управления случайностью (>0, влияет только если do_sample=True).
|
||||
>1.0 — более случайно, <1.0 — более детерминированно.
|
||||
top_k (int, опц.): При do_sample=True ограничивает выбор top_k самых вероятных токенов (top-k sampling).
|
||||
top_p (float, опц.): При do_sample=True включает top-p (nucleus) sampling: кумулятивная вероятность ≤ top_p.
|
||||
Должно быть в (0, 1].
|
||||
attention_mask (torch.Tensor, опц.): Внешняя маска внимания (для совместимости с HuggingFace).
|
||||
**kwargs: Игнорируются.
|
||||
|
||||
Возвращает:
|
||||
torch.Tensor: Последовательность токенов [batch_size, seq_len + max_new_tokens].
|
||||
|
||||
Исключения:
|
||||
ValueError: Если x длиннее max_seq_len модели.
|
||||
ValueError: Если temperature ≤ 0.
|
||||
ValueError: Если одновременно заданы top_k и top_p.
|
||||
ValueError: Если top_k ≤ 0.
|
||||
ValueError: Если top_p вне диапазона (0, 1].
|
||||
|
||||
Примеры:
|
||||
>>> # Жадная (детерминированная) генерация
|
||||
>>> output = model.generate(input_ids, max_new_tokens=12, do_sample=False)
|
||||
>>> # Вероятностная генерация с температурой
|
||||
>>> output = model.generate(input_ids, max_new_tokens=12, do_sample=True, temperature=0.8)
|
||||
>>> # Top-k сэмплирование
|
||||
>>> output = model.generate(input_ids, max_new_tokens=12, do_sample=True, top_k=50)
|
||||
>>> # Top-p (nucleus) sampling
|
||||
>>> output = model.generate(input_ids, max_new_tokens=12, do_sample=True, top_p=0.92)
|
||||
>>> # Комбинация температуры и top-k
|
||||
>>> output = model.generate(input_ids, max_new_tokens=12, do_sample=True, temperature=1.0, top_k=100)
|
||||
|
||||
Примечания:
|
||||
- Для детерминированных выборок зафиксируйте random seed через torch.manual_seed.
|
||||
- Параметры temperature, top_k, top_p применимы только если do_sample=True.
|
||||
- Одновременное использование top_k и top_p не допускается.
|
||||
- Модель всегда возвращает тензор индексов токенов; для получения логитов используйте прямой вызов forward.
|
||||
|
||||
Ссылки:
|
||||
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus sampling): https://arxiv.org/abs/1904.09751
|
||||
- Оригинальный GPT-2: https://cdn.openai.com/better-language-models/language-models.pdf
|
||||
"""
|
||||
cache = None
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
# 1. Обрезаем вход, если последовательность слишком длинная
|
||||
x_cond = x[:, -self._max_seq_len:]
|
||||
if use_cache and cache is not None:
|
||||
# Используем кэш - передаем только последний токен
|
||||
x_input = x[:, -1:] # [batch_size, 1]
|
||||
else:
|
||||
# Первая итерация или кэш отключен - передаем всю последовательность
|
||||
x_input = x
|
||||
|
||||
# 2. Передаем последовательность в метод forward класса GPT и полуаем логиты.
|
||||
logits = self.forward(x_cond)
|
||||
# Прямой проход с кэшем
|
||||
logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)
|
||||
|
||||
# Обновляем кэш для следующей итерации
|
||||
if use_cache:
|
||||
cache = new_cache
|
||||
|
||||
# 3. Берем логиты для последнего токена
|
||||
last_logits = logits[:, -1, :] # [batch_size, vocab_size]
|
||||
@@ -250,9 +310,14 @@ class GPT(BaseModel):
|
||||
vocab_size = logits_scaled.size(-1)
|
||||
|
||||
# создаём маску: True, если токен НЕ в topk_indices
|
||||
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, 'bool') else torch.uint8)
|
||||
mask.scatter_(1, topk_indices, False if hasattr(torch, 'bool') else 0) # False там, где top-k индексы
|
||||
masked_logits[mask] = float('-inf')
|
||||
mask = torch.ones_like(
|
||||
logits_scaled,
|
||||
dtype=torch.bool if hasattr(torch, "bool") else torch.uint8,
|
||||
)
|
||||
mask.scatter_(
|
||||
1, topk_indices, False if hasattr(torch, "bool") else 0
|
||||
) # False там, где top-k индексы
|
||||
masked_logits[mask] = float("-inf")
|
||||
|
||||
logits_scaled = masked_logits
|
||||
|
||||
@@ -260,36 +325,42 @@ class GPT(BaseModel):
|
||||
# 1. Применим softmax, чтобы получить вероятности:
|
||||
probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]
|
||||
# 2. Отсортируем токены по убыванию вероятностей:
|
||||
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
|
||||
sorted_probs, sorted_indices = torch.sort(
|
||||
probs, descending=True, dim=-1
|
||||
)
|
||||
# 3. Посчитаем кумулятивную сумму вероятностей:
|
||||
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
|
||||
# 4. Определим маску: оставить токены, пока сумма < top_p
|
||||
sorted_mask = (cum_probs <= top_p) # [B, vocab_size]
|
||||
sorted_mask = cum_probs <= top_p # [B, vocab_size]
|
||||
# Гарантируем, что хотя бы первый токен останется
|
||||
sorted_mask[:, 0] = True
|
||||
# 5. Преобразуем маску обратно в оригинальный порядок:
|
||||
# Создаём полную маску из False
|
||||
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, 'bool') else torch.uint8)
|
||||
mask = torch.zeros_like(
|
||||
probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8
|
||||
)
|
||||
# Устанавливаем True в местах нужных токенов
|
||||
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
||||
# 6. Зануляем логиты токенов вне топ-p:
|
||||
logits_scaled[~mask] = float('-inf')
|
||||
logits_scaled[~mask] = float("-inf")
|
||||
|
||||
# 4. Применяем Softmax
|
||||
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
|
||||
|
||||
|
||||
if do_sample == True:
|
||||
# 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial
|
||||
next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]
|
||||
else:
|
||||
# 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью
|
||||
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
|
||||
|
||||
next_token = torch.argmax(
|
||||
probs, dim=-1, keepdim=True
|
||||
) # [batch_size, 1]
|
||||
|
||||
# 6. Добавляем его к последовательности
|
||||
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
|
||||
return x
|
||||
|
||||
|
||||
# def generate(self, input_ids, max_length=50):
|
||||
# for _ in range(max_length):
|
||||
# logits = self.forward(input_ids)
|
||||
|
||||
@@ -27,81 +27,144 @@ from llm.core.positional_embeddings import PositionalEmbeddings
|
||||
from llm.core.cached_decoder import CachedDecoder
|
||||
from llm.core.feed_forward import FeedForward
|
||||
|
||||
|
||||
class GPT2(BaseModel):
|
||||
"""
|
||||
GPT2 — автогерессивная языковая модель, архитектура Transformer, предложенная OpenAI.
|
||||
GPT-2 — масштабируемый автогерессивный языковой трансформер второго поколения от OpenAI (2019).
|
||||
|
||||
Научная суть:
|
||||
- Масштабируемый автогерессивный трансформер для предсказания токенов слева направо.
|
||||
- Главное отличие от классической GPT: порядок layer normalization ПЕРЕД attention и FFN.
|
||||
- Используется GELU, efficient KV-cache, несет наследие классической GPT, но делает архитектуру глубже/шире.
|
||||
|
||||
Args:
|
||||
config (dict): параметры архитектуры (vocab_size, embed_dim, num_heads, num_layers, max_position_embeddings, dropout)
|
||||
Назначение:
|
||||
-----------
|
||||
- Позволяет предсказывать и порождать последовательности текста по одному токену, будучи обученным на задаче language modeling.
|
||||
- Модель реализует архитектуру decoder-only Transformer с Pre-LN (LayerNorm перед attention и FFN).
|
||||
- Используется для генерации, обучения с подкреплением для RLHF, zero/few-shot inference, чат-ботов и др.
|
||||
|
||||
Архитектурные особенности:
|
||||
--------------------------
|
||||
- Token и positional embeddings (learnable, как в GPT-2 оригинале).
|
||||
- Stack из N блоков Decoder (MultiHeadAttention с causal mask, Residual, Pre-LayerNorm, GELU FFN).
|
||||
- KV attention-кэш (ускоряет autoregressive generation, критически важно для LLM).
|
||||
- Использует GELU как функцию активации.
|
||||
- Поддержка dropout на каждом этапе.
|
||||
|
||||
Основные параметры:
|
||||
-------------------
|
||||
config: dict — параметры модели:
|
||||
vocab_size, # размер словаря токенов
|
||||
embed_dim, # размерность эмбеддинга
|
||||
num_heads, # количество attention голов
|
||||
num_layers, # глубина модели (число блоков)
|
||||
max_position_embeddings,
|
||||
dropout
|
||||
|
||||
Процессинг:
|
||||
-----------
|
||||
x (индексы токенов) → token_embeddings + position_embeddings → dropout
|
||||
→ stack Decoder blocks (masked attention, pre-LN)
|
||||
→ LayerNorm
|
||||
→ Linear(out_dim=vocab_size) → выходные логиты
|
||||
|
||||
Пример использования:
|
||||
>>> model = GPT2({"vocab_size": 50257, ...})
|
||||
>>> logits = model(input_ids)
|
||||
>>> out = model.generate(input_ids, max_length=20)
|
||||
---------------------
|
||||
>>> gpt2 = GPT2({...})
|
||||
>>> logits = gpt2(input_ids)
|
||||
>>> output = gpt2.generate(input_ids, max_new_tokens=20, do_sample=True)
|
||||
|
||||
References:
|
||||
-----------
|
||||
- Radford et al., "Language Models are Unsupervised Multitask Learners" (GPT-2, 2019): https://cdn.openai.com/better-language-models/language-models.pdf
|
||||
- HuggingFace GPT-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
|
||||
- Репликация в NanoGPT: https://github.com/karpathy/nanoGPT
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
"""
|
||||
Инициализация GPT-2.
|
||||
|
||||
Args:
|
||||
config (dict): Параметры архитектуры:
|
||||
vocab_size: int — размер словаря
|
||||
embed_dim: int — размерность эмбеддинга
|
||||
num_heads: int — количество attention-голов
|
||||
num_layers: int — количество декодер-блоков
|
||||
max_position_embeddings: максимальная длина последовательности
|
||||
dropout: float — dropout
|
||||
|
||||
Внутри:
|
||||
-------
|
||||
- Создаёт токеновые и позиционные эмбеддинги, стек декодеров, финальный LayerNorm и линейную проекцию в словарь.
|
||||
"""
|
||||
super().__init__(config)
|
||||
|
||||
# Инициализация слоев
|
||||
self._max_seq_len = config["max_position_embeddings"]
|
||||
self._token_embeddings = TokenEmbeddings(
|
||||
vocab_size=config["vocab_size"],
|
||||
emb_size=config["embed_dim"]
|
||||
vocab_size=config["vocab_size"], emb_size=config["embed_dim"]
|
||||
)
|
||||
self._position_embeddings = PositionalEmbeddings(
|
||||
max_seq_len=config["max_position_embeddings"],
|
||||
emb_size=config["embed_dim"]
|
||||
max_seq_len=config["max_position_embeddings"], emb_size=config["embed_dim"]
|
||||
)
|
||||
self._dropout = nn.Dropout(config["dropout"])
|
||||
# head_size = emb_size // num_heads
|
||||
self._decoders = nn.ModuleList([CachedDecoder(
|
||||
num_heads=config["num_heads"],
|
||||
emb_size=config["embed_dim"],
|
||||
head_size=config["embed_dim"] // config["num_heads"],
|
||||
feed_forward_layer=FeedForward(
|
||||
emb_size=config["embed_dim"],
|
||||
dropout=config["dropout"],
|
||||
activation="gelu"
|
||||
),
|
||||
max_seq_len=config["max_position_embeddings"],
|
||||
dropout=config["dropout"]
|
||||
) for _ in range(config["num_layers"])])
|
||||
self._decoders = nn.ModuleList(
|
||||
[
|
||||
CachedDecoder(
|
||||
num_heads=config["num_heads"],
|
||||
emb_size=config["embed_dim"],
|
||||
head_size=config["embed_dim"] // config["num_heads"],
|
||||
feed_forward_layer=FeedForward(
|
||||
emb_size=config["embed_dim"],
|
||||
dropout=config["dropout"],
|
||||
activation="gelu",
|
||||
),
|
||||
max_seq_len=config["max_position_embeddings"],
|
||||
dropout=config["dropout"],
|
||||
)
|
||||
for _ in range(config["num_layers"])
|
||||
]
|
||||
)
|
||||
self._norm = nn.LayerNorm(config["embed_dim"])
|
||||
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
|
||||
|
||||
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
|
||||
def forward(
|
||||
self, x: torch.Tensor, use_cache: bool = True, cache: list = None
|
||||
) -> tuple:
|
||||
"""
|
||||
Прямой проход GPT2:
|
||||
- Все слои работают как autoregressive transformer (masked self-attention).
|
||||
- При use_cache=True возвращает также новый кэш KV attention (ускоряет генерацию).
|
||||
Прямой проход для batch of sequences (получение логитов по токенам).
|
||||
|
||||
Args:
|
||||
x (Tensor): Входные индексы токенов [batch, seq_len]
|
||||
use_cache (bool): Кэшировать KV attention для ускорения autoregressive генерации
|
||||
cache (list|None): Список KV-кэшей от предыдущих шагов (или None)
|
||||
x (torch.Tensor): Входной тензор с токенами [batch, seq_len]
|
||||
use_cache (bool): Использовать/возвращать кэш KV attention (ускоряет генерацию)
|
||||
cache (list / None): Внешний кэш KV attention (передаётся при генерации)
|
||||
|
||||
Returns:
|
||||
logits (Tensor): [batch, seq_len, vocab_size]
|
||||
cache (list): новый кэш если use_cache=True, иначе None
|
||||
logits: torch.Tensor [batch, seq_len, vocab_size]
|
||||
new_cache: новый кэш KV attention (или None)
|
||||
|
||||
Пример:
|
||||
>>> logits, cache = model.forward(x, use_cache=True)
|
||||
>>> logits, cache = gpt2(x, use_cache=True)
|
||||
"""
|
||||
# Проверка длины последовательности (только при отсутствии кэша)
|
||||
if cache is None and x.size(1) > self._max_seq_len:
|
||||
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
|
||||
|
||||
|
||||
raise ValueError(
|
||||
f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}"
|
||||
)
|
||||
|
||||
# Вычисление start_pos из кэша (если кэш передан)
|
||||
if cache is not None:
|
||||
# При кэше обрабатываем только один токен (последний)
|
||||
seq_len = 1
|
||||
# Вычисляем start_pos из самого нижнего уровня кэша
|
||||
if cache and cache[0] and cache[0][0]:
|
||||
key_cache, _ = cache[0][0] # Первый декодер, первая голова
|
||||
start_pos = key_cache.size(1) # cache_len
|
||||
# Безопасно извлекаем key_cache для вычисления start_pos
|
||||
if (
|
||||
isinstance(cache, (list, tuple))
|
||||
and len(cache) > 0
|
||||
and cache[0] is not None
|
||||
and isinstance(cache[0], (list, tuple))
|
||||
and len(cache[0]) > 0
|
||||
and cache[0][0] is not None
|
||||
and isinstance(cache[0][0], (tuple, list))
|
||||
and len(cache[0][0]) > 0
|
||||
):
|
||||
key_cache, _ = cache[0][0]
|
||||
start_pos = key_cache.size(1)
|
||||
else:
|
||||
start_pos = 0
|
||||
else:
|
||||
@@ -111,11 +174,15 @@ class GPT2(BaseModel):
|
||||
|
||||
# Эмбеддинги токенов и позиций
|
||||
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
|
||||
pos_out = self._position_embeddings(seq_len, start_pos=start_pos) # [seq_len, emb_size]
|
||||
|
||||
pos_out = self._position_embeddings(
|
||||
seq_len, start_pos=start_pos
|
||||
) # [seq_len, emb_size]
|
||||
|
||||
# Комбинирование
|
||||
out = self._dropout(tok_out + pos_out.unsqueeze(0)) # [batch, seq_len, emb_size]
|
||||
|
||||
out = self._dropout(
|
||||
tok_out + pos_out.unsqueeze(0)
|
||||
) # [batch, seq_len, emb_size]
|
||||
|
||||
# Стек декодеров с передачей кэша
|
||||
new_cache = []
|
||||
for i, decoder in enumerate(self._decoders):
|
||||
@@ -131,39 +198,76 @@ class GPT2(BaseModel):
|
||||
|
||||
out = self._norm(out)
|
||||
logits = self._linear(out)
|
||||
|
||||
|
||||
# Возвращаем результат с учетом use_cache
|
||||
if use_cache:
|
||||
return (logits, new_cache)
|
||||
else:
|
||||
return (logits, None)
|
||||
|
||||
def generate(self,
|
||||
x: torch.Tensor,
|
||||
max_new_tokens: int,
|
||||
def generate(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
max_new_tokens: int,
|
||||
do_sample: bool,
|
||||
temperature: float = 1.0,
|
||||
top_k: int = None,
|
||||
top_p: float = None,
|
||||
use_cache: bool = True
|
||||
use_cache: bool = True,
|
||||
attention_mask: torch.Tensor = None,
|
||||
**kwargs
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Генерация текста с использованием autoregressive трансформера (GPT2).
|
||||
Поддерживаются greedy, sampling, top-k/top-p (nucleus sampling) режимы.
|
||||
Args:
|
||||
x (Tensor[int]): начальная последовательность [batch, seq_len]
|
||||
max_new_tokens (int): сколько токенов сгенерировать
|
||||
do_sample (bool): использовать стохастическое сэмплирование вместо жадного выбора
|
||||
temperature (float): коэффициент сглаживания логитов (низкое — более консервативно)
|
||||
top_k (int|None): ограничить выбор top-k наиболее вероятных токенов
|
||||
top_p (float|None): ограничить суммарную вероятность (nucleus sampling)
|
||||
use_cache (bool): ускорять autoregressive инференс
|
||||
Returns:
|
||||
output (Tensor[int]): сгенерированный тензор токенов [batch, seq_len + max_new_tokens]
|
||||
Пример:
|
||||
>>> prompt = tokenizer.encode('Привет', return_tensors="pt")
|
||||
>>> output = model.generate(prompt, max_new_tokens=20, do_sample=True)
|
||||
>>> print(tokenizer.decode(output[0]))
|
||||
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k, top-p sampling и KV-кэша.
|
||||
|
||||
Аргументы:
|
||||
x (torch.Tensor): Входной тензор с индексами токенов [batch_size, seq_len].
|
||||
max_new_tokens (int): Максимальное количество новых токенов для генерации.
|
||||
do_sample (bool): Режим генерации:
|
||||
- True: вероятностное сэмплирование (random sampling)
|
||||
- False: жадный (greedy) поиск (выбор argmax на каждом шаге)
|
||||
temperature (float): Температура распределения (>0, по умолчанию 1.0).
|
||||
- >1.0 — генерация более "творческая"/приподнятая вероятность "редких" токенов;
|
||||
- <1.0 — более предсказуемый и суженный выбор.
|
||||
top_k (int, опционально): Если задан, sampling только из top_k самых вероятных токенов (top-k sampling).
|
||||
top_p (float, опционально): Если задан, sampling только из токенов, кумулятивная вероятность которых ≤ top_p (nucleus/top-p sampling, см. Holtzman et al., 2019).
|
||||
use_cache (bool, по умолчанию True): Использовать кэш attention KV для ускорения авторегрессии.
|
||||
|
||||
Возвращает:
|
||||
torch.Tensor: Тензор индексов токенов [batch_size, seq_len + max_new_tokens].
|
||||
|
||||
Исключения:
|
||||
ValueError: Если x длиннее максимальной длины (max_seq_len).
|
||||
ValueError: Если temperature ≤ 0.
|
||||
ValueError: Если одновременно заданы top_k и top_p.
|
||||
ValueError: Если top_k ≤ 0.
|
||||
ValueError: Если top_p не в диапазоне (0, 1].
|
||||
|
||||
Примеры использования:
|
||||
>>> # Жадная генерация
|
||||
>>> output = model.generate(input_ids, max_new_tokens=20, do_sample=False)
|
||||
|
||||
>>> # Сэмплирование с температурой
|
||||
>>> output = model.generate(input_ids, max_new_tokens=20, do_sample=True, temperature=0.8)
|
||||
|
||||
>>> # Top-k sampling
|
||||
>>> output = model.generate(input_ids, max_new_tokens=20, do_sample=True, top_k=50)
|
||||
|
||||
>>> # Top-p (nucleus) sampling
|
||||
>>> output = model.generate(input_ids, max_new_tokens=20, do_sample=True, top_p=0.92)
|
||||
|
||||
>>> # Комбинация температуры и top-k
|
||||
>>> output = model.generate(input_ids, max_new_tokens=20, do_sample=True, temperature=0.7, top_k=40)
|
||||
|
||||
Примечания:
|
||||
- Для детерминированных результатов используйте torch.manual_seed.
|
||||
- temperature, top_k, top_p работают только при do_sample=True.
|
||||
- Только один из top_k/top_p может быть задан одновременно.
|
||||
- Метод всегда возвращает индексы токенов (ids); для получения логитов используйте forward.
|
||||
|
||||
Ссылки:
|
||||
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus sampling): https://arxiv.org/abs/1904.09751
|
||||
- Оригинальная статья GPT-2: https://cdn.openai.com/better-language-models/language-models.pdf
|
||||
"""
|
||||
cache = None
|
||||
|
||||
@@ -174,10 +278,10 @@ class GPT2(BaseModel):
|
||||
else:
|
||||
# Первая итерация или кэш отключен - передаем всю последовательность
|
||||
x_input = x
|
||||
|
||||
|
||||
# Прямой проход с кэшем
|
||||
logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)
|
||||
|
||||
|
||||
# Обновляем кэш для следующей итерации
|
||||
if use_cache:
|
||||
cache = new_cache
|
||||
@@ -198,26 +302,27 @@ class GPT2(BaseModel):
|
||||
vocab_size = logits_scaled.size(-1)
|
||||
|
||||
# создаём маску: 1, если токен НЕ в topk_indices
|
||||
mask = torch.ones_like(logits_scaled, dtype=torch.uint8)
|
||||
mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы
|
||||
masked_logits[mask.byte()] = float('-inf')
|
||||
|
||||
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||
mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы
|
||||
masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf')
|
||||
logits_scaled = masked_logits
|
||||
|
||||
if do_sample == True and top_p != None:
|
||||
# 1. Применим softmax, чтобы получить вероятности:
|
||||
probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]
|
||||
# 2. Отсортируем токены по убыванию вероятностей:
|
||||
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
|
||||
sorted_probs, sorted_indices = torch.sort(
|
||||
probs, descending=True, dim=-1
|
||||
)
|
||||
# 3. Посчитаем кумулятивную сумму вероятностей:
|
||||
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
|
||||
# 4. Определим маску: оставить токены, пока сумма < top_p
|
||||
sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size]
|
||||
sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size]
|
||||
# Гарантируем, что хотя бы первый токен останется
|
||||
sorted_mask[:, 0] = 1
|
||||
sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1
|
||||
# 5. Преобразуем маску обратно в оригинальный порядок:
|
||||
# Создаём полную маску из 0
|
||||
mask = torch.zeros_like(probs, dtype=torch.uint8)
|
||||
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||
# Устанавливаем 1 в местах нужных токенов
|
||||
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
||||
# 6. Зануляем логиты токенов вне топ-p:
|
||||
@@ -226,18 +331,19 @@ class GPT2(BaseModel):
|
||||
# 4. Применяем Softmax
|
||||
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
|
||||
|
||||
|
||||
if do_sample == True:
|
||||
# 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial
|
||||
next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]
|
||||
else:
|
||||
# 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью
|
||||
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
|
||||
|
||||
next_token = torch.argmax(
|
||||
probs, dim=-1, keepdim=True
|
||||
) # [batch_size, 1]
|
||||
|
||||
# 6. Добавляем его к последовательности
|
||||
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
|
||||
return x
|
||||
|
||||
@property
|
||||
def max_seq_len(self) -> int:
|
||||
return self._max_seq_len
|
||||
return self._max_seq_len
|
||||
|
||||
@@ -10,77 +10,121 @@ from llm.core.rope import RoPE
|
||||
from llm.core.cached_decoder import CachedDecoder
|
||||
|
||||
|
||||
|
||||
class Llama(BaseModel):
|
||||
"""
|
||||
LLaMA (Large Language Model Meta AI) — высокоэффективная масштабируемая языковая модель, разработанная Meta AI Research.
|
||||
LLaMA — автогерессивная большая языковая модель (Large Language Model from Meta, 2023).
|
||||
|
||||
Ключевые идеи:
|
||||
- Rotary Positional Encoding (RoPE) вместо стандартных позиционных эмбеддингов
|
||||
- RMSNorm (Root Mean Square LayerNorm) вместо LayerNorm
|
||||
- SwiGLU как нелинейность вместо ReLU/GELU (больше экспрессивности)
|
||||
- Глубокая оптимизация inference (большая экономия памяти и FLOPs)
|
||||
Подробнее: https://arxiv.org/abs/2302.13971
|
||||
Назначение:
|
||||
-----------
|
||||
- Модель реализует архитектуру decoder-only Transformer с современными "индустриальными" трюками (RMSNorm, SwiGLU, RoPE, GQA).
|
||||
- Предназначена для генерации текста, чат-ботов, zero-/few-shot вывода, fine-tune в стиле RLHF, transfer learning и исследований в LLM.
|
||||
|
||||
Архитектурные особенности:
|
||||
--------------------------
|
||||
- Токеновые эмбеддинги и позиционное кодирование с помощью Rotary Position Embedding (RoPE, https://arxiv.org/abs/2104.09864).
|
||||
- Stack из num_layers современных декодеров с Grouped Query Attention (GQA: num_q_heads > num_kv_heads) для эффективной генерации.
|
||||
- FeedForward блоки с SwiGLU (см. https://arxiv.org/abs/2002.05202).
|
||||
- Нормализация RMSNorm перед каждым sub-layer (вот почему "Pre-RMSNorm").
|
||||
- Кэширование attention (KV cache) для быстрой autoregressive генерации.
|
||||
- Нет bias в Linear слоях, нет Dropout внутри attention.
|
||||
|
||||
Аргументы конструктора:
|
||||
-----------------------
|
||||
config: dict с требуемыми ключами:
|
||||
vocab_size: int — размер словаря токенов
|
||||
embed_dim: int — размерность эмбеддингов
|
||||
num_q_heads: int — количество query-голов в attention (обычно больше num_kv_heads)
|
||||
num_kv_heads: int — количество key/value-голов
|
||||
num_layers: int — число слоёв-декодеров
|
||||
max_position_embeddings: int — максимальная длина последовательности
|
||||
window_size: int (optional) — размер sliding window для attention
|
||||
dropout: float (обычно 0.0 или очень мал)
|
||||
...
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> llama = LLaMA({...})
|
||||
>>> tokens = torch.tensor([[100, 56, 8]])
|
||||
>>> logits = llama(tokens)
|
||||
>>> out = llama.generate(tokens, max_new_tokens=10, do_sample=True, top_k=50)
|
||||
|
||||
References:
|
||||
-----------
|
||||
- "LLaMA: Open and Efficient Foundation Language Models" (Touvron et al., 2023): https://arxiv.org/abs/2302.13971
|
||||
- "Grouped-Query Attention": https://arxiv.org/abs/2307.09288
|
||||
- "RoFormer: Enhanced Transformer with Rotary Position Embedding": https://arxiv.org/abs/2104.09864
|
||||
- Discussion of efficient LLMs: https://huggingface.co/blog/mistral
|
||||
|
||||
Args:
|
||||
config (dict): параметры архитектуры (vocab_size, embed_dim, num_heads, num_layers, max_position_embeddings, dropout)
|
||||
Пример:
|
||||
>>> model = Llama({...})
|
||||
>>> logits, cache = model(input_ids, use_cache=True)
|
||||
>>> out = model.generate(input_ids, max_new_tokens=20)
|
||||
"""
|
||||
def __init__(self,config):
|
||||
|
||||
def __init__(self, config):
|
||||
"""
|
||||
Инициализация LLaMA.
|
||||
|
||||
Args:
|
||||
config (dict): Параметры архитектуры, см. docstring класса.
|
||||
Внутри:
|
||||
-------
|
||||
- Создаёт Embedding-слой, Rotary Position Embeddings (RoPE), стек слоёв с GQA, RMSNorm, SwiGLU.
|
||||
- Финальный слой нормализации и проекции на vocabulary.
|
||||
"""
|
||||
super().__init__(config)
|
||||
|
||||
# Инициализация слоев
|
||||
self._max_seq_len = config["max_position_embeddings"]
|
||||
self._token_embeddings = TokenEmbeddings(
|
||||
vocab_size=config["vocab_size"],
|
||||
emb_size=config["embed_dim"]
|
||||
vocab_size=config["vocab_size"], emb_size=config["embed_dim"]
|
||||
)
|
||||
self._position_embeddings = RoPE(
|
||||
head_size=config["embed_dim"] // config["num_heads"],
|
||||
max_seq_len=config["max_position_embeddings"]
|
||||
max_seq_len=config["max_position_embeddings"],
|
||||
)
|
||||
|
||||
self._dropout = nn.Dropout(config["dropout"])
|
||||
self._decoders = nn.ModuleList([CachedDecoder(
|
||||
norm_layer=RMSNorm,
|
||||
num_heads=config["num_heads"],
|
||||
emb_size=config["embed_dim"],
|
||||
head_size=config["embed_dim"] // config["num_heads"],
|
||||
feed_forward_layer=SwiGLU(
|
||||
emb_size=config["embed_dim"],
|
||||
dropout=config["dropout"],
|
||||
),
|
||||
max_seq_len=config["max_position_embeddings"],
|
||||
rope=self._position_embeddings,
|
||||
dropout=config["dropout"],
|
||||
) for _ in range(config["num_layers"])])
|
||||
self._decoders = nn.ModuleList(
|
||||
[
|
||||
CachedDecoder(
|
||||
norm_layer=RMSNorm,
|
||||
num_heads=config["num_heads"],
|
||||
emb_size=config["embed_dim"],
|
||||
head_size=config["embed_dim"] // config["num_heads"],
|
||||
feed_forward_layer=SwiGLU(
|
||||
emb_size=config["embed_dim"],
|
||||
dropout=config["dropout"],
|
||||
),
|
||||
max_seq_len=config["max_position_embeddings"],
|
||||
rope=self._position_embeddings,
|
||||
dropout=config["dropout"],
|
||||
)
|
||||
for _ in range(config["num_layers"])
|
||||
]
|
||||
)
|
||||
self._norm = RMSNorm(config["embed_dim"])
|
||||
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
|
||||
|
||||
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
|
||||
def forward(
|
||||
self, x: torch.Tensor, use_cache: bool = True, cache: list = None
|
||||
) -> tuple:
|
||||
"""
|
||||
Прямой проход через LLaMA (inference/train): авторегрессионное предсказание токенов.
|
||||
Прямой проход: возвращает logits (и возможно обновлённый cache) по входным токенам.
|
||||
|
||||
Args:
|
||||
x (Tensor[int]): входные токены [batch, seq_len]
|
||||
use_cache (bool): использовать ли кэш (ускоряет генерацию)
|
||||
cache (list|None): ключи и значения attention для autoregressive режима
|
||||
x (torch.Tensor): [batch, seq_len] — индексы токенов, shape [batch, seq_len]
|
||||
use_cache (bool): использовать механизм KV cache (ускоряет autoregressive generation)
|
||||
cache (list or None): предыдущий кэш, если нужен
|
||||
|
||||
Returns:
|
||||
logits (Tensor): [batch, seq_len, vocab_size]
|
||||
new_cache (list|None): новый кэш attention (если use_cache)
|
||||
Пример:
|
||||
>>> logits, cache = model.forward(x, use_cache=True)
|
||||
logits: torch.Tensor [batch, seq_len, vocab_size]
|
||||
new_cache: новый кэш attention (или None)
|
||||
"""
|
||||
# Проверка длины последовательности (только при отсутствии кэша)
|
||||
if cache is None and x.size(1) > self._max_seq_len:
|
||||
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
|
||||
|
||||
|
||||
raise ValueError(
|
||||
f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}"
|
||||
)
|
||||
|
||||
# Вычисление start_pos из кэша (если кэш передан)
|
||||
#if cache is not None:
|
||||
# if cache is not None:
|
||||
# # При кэше обрабатываем только один токен (последний)
|
||||
# seq_len = 1
|
||||
# # Вычисляем start_pos из самого нижнего уровня кэша
|
||||
@@ -89,18 +133,18 @@ class Llama(BaseModel):
|
||||
# start_pos = key_cache.size(1) # cache_len
|
||||
# else:
|
||||
# start_pos = 0
|
||||
#else:
|
||||
# else:
|
||||
# # Без кэша работаем как раньше
|
||||
# start_pos = 0
|
||||
# seq_len = x.size(1)
|
||||
|
||||
# Эмбеддинги токенов и позиций
|
||||
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
|
||||
#pos_out = self._position_embeddings(x) # [batch, seq_len, emb_size]
|
||||
|
||||
# pos_out = self._position_embeddings(x) # [batch, seq_len, emb_size]
|
||||
|
||||
# Комбинирование
|
||||
out = self._dropout(tok_out) # [batch, seq_len, emb_size]
|
||||
|
||||
|
||||
# Стек декодеров с передачей кэша
|
||||
new_cache = []
|
||||
for i, decoder in enumerate(self._decoders):
|
||||
@@ -116,42 +160,70 @@ class Llama(BaseModel):
|
||||
|
||||
out = self._norm(out)
|
||||
logits = self._linear(out)
|
||||
|
||||
|
||||
# Возвращаем результат с учетом use_cache
|
||||
if use_cache:
|
||||
return (logits, new_cache)
|
||||
else:
|
||||
return (logits, None)
|
||||
|
||||
def generate(self,
|
||||
x: torch.Tensor,
|
||||
max_new_tokens: int,
|
||||
def generate(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
max_new_tokens: int,
|
||||
do_sample: bool,
|
||||
temperature: float = 1.0,
|
||||
top_k: int = None,
|
||||
top_p: float = None,
|
||||
use_cache: bool = True
|
||||
use_cache: bool = True,
|
||||
attention_mask: torch.Tensor = None,
|
||||
**kwargs
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Генерация текста c помощью LLaMA (autoregressive Transformer).
|
||||
Поддерживается:
|
||||
- greedy и вероятностное сэмплирование (top-k, top-p, temperature)
|
||||
- кэш attention для ускорения генерации длинных последовательностей
|
||||
|
||||
Args:
|
||||
x (Tensor[int]): начальная последовательность [batch, seq_len]
|
||||
max_new_tokens (int): сколько новых токенов сгенерировать
|
||||
do_sample (bool): использовать стохастику (True) или жадный выбор (False)
|
||||
temperature (float): масштаб для softmax (важно для sampling)
|
||||
top_k (int|None): ограничение на количество кандидатов (top-k sampling)
|
||||
top_p (float|None): nucleus sampling
|
||||
use_cache (bool): ускоряет autoregressive при длинной генерации
|
||||
Returns:
|
||||
output (Tensor[int]): [batch, seq_len + max_new_tokens]
|
||||
Пример:
|
||||
>>> prompt = tokenizer.encode('Meta AI', return_tensors="pt")
|
||||
>>> generated = model.generate(prompt, max_new_tokens=30, do_sample=True)
|
||||
>>> print(tokenizer.decode(generated[0]))
|
||||
Авторегрессивная генерация последовательностей на основе LLaMA (greedy, temperature, top-k, top-p/nucleus, поддержка KV-кэша).
|
||||
|
||||
Аргументы:
|
||||
x (torch.Tensor): Входной тензор с токенами shape [batch_size, seq_len].
|
||||
max_new_tokens (int): Максимальное количество новых токенов для генерации.
|
||||
do_sample (bool): Использовать вероятностное сэмплирование (True) или жадный режим (False, argmax).
|
||||
temperature (float): Температура (сглаживание распределения вероятностей, >0; по умолчанию 1.0).
|
||||
>1.0 — менее предсказуемые, более разнообразные выборки.
|
||||
<1.0 — более строгие, консервативные выборки.
|
||||
top_k (int, опционально): Top-k сэмплирование (ограничение выбора k самыми вероятными токенами).
|
||||
top_p (float, опционально): Nucleus (top-p) sampling (срез по кумулятивной вероятности ≤ top_p, см. Holtzman et al., 2019).
|
||||
use_cache (bool, по умолчанию True): Использовать KV-кэш для ускорения генерации.
|
||||
|
||||
Возвращает:
|
||||
torch.Tensor: Последовательность токенов shape [batch_size, seq_len + max_new_tokens].
|
||||
|
||||
Исключения:
|
||||
ValueError: Если x длиннее максимально допустимой длины (max_seq_len модели).
|
||||
ValueError: Если temperature ≤ 0.
|
||||
ValueError: Если одновременно заданы top_k и top_p.
|
||||
ValueError: Если top_k ≤ 0.
|
||||
ValueError: Если top_p не в диапазоне (0, 1].
|
||||
|
||||
Примеры:
|
||||
>>> # Строго жадная генерация
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=False)
|
||||
>>> # Вероятностная генерация с температурой
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=0.7)
|
||||
>>> # Top-k sampling
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_k=50)
|
||||
>>> # Top-p (nucleus)
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_p=0.92)
|
||||
>>> # Комбинация температуры и top-k
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=1.0, top_k=100)
|
||||
|
||||
Примечания:
|
||||
- temperature, top_k, top_p применяются только если do_sample=True.
|
||||
- Одновременное использование top_k и top_p запрещено.
|
||||
- Для воспроизводимых результатов зафиксируйте seed через torch.manual_seed.
|
||||
- Возвращается только индексы токенов; для получения вероятностей используйте forward.
|
||||
|
||||
Ссылки:
|
||||
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus/top-p): https://arxiv.org/abs/1904.09751
|
||||
- LLaMA: https://arxiv.org/abs/2302.13971
|
||||
"""
|
||||
cache = None
|
||||
|
||||
@@ -162,10 +234,10 @@ class Llama(BaseModel):
|
||||
else:
|
||||
# Первая итерация или кэш отключен - передаем всю последовательность
|
||||
x_input = x
|
||||
|
||||
|
||||
# Прямой проход с кэшем
|
||||
logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)
|
||||
|
||||
|
||||
# Обновляем кэш для следующей итерации
|
||||
if use_cache:
|
||||
cache = new_cache
|
||||
@@ -186,26 +258,27 @@ class Llama(BaseModel):
|
||||
vocab_size = logits_scaled.size(-1)
|
||||
|
||||
# создаём маску: 1, если токен НЕ в topk_indices
|
||||
mask = torch.ones_like(logits_scaled, dtype=torch.uint8)
|
||||
mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы
|
||||
masked_logits[mask.byte()] = float('-inf')
|
||||
|
||||
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||
mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы
|
||||
masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf')
|
||||
logits_scaled = masked_logits
|
||||
|
||||
if do_sample == True and top_p != None:
|
||||
# 1. Применим softmax, чтобы получить вероятности:
|
||||
probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]
|
||||
# 2. Отсортируем токены по убыванию вероятностей:
|
||||
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
|
||||
sorted_probs, sorted_indices = torch.sort(
|
||||
probs, descending=True, dim=-1
|
||||
)
|
||||
# 3. Посчитаем кумулятивную сумму вероятностей:
|
||||
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
|
||||
# 4. Определим маску: оставить токены, пока сумма < top_p
|
||||
sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size]
|
||||
sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size]
|
||||
# Гарантируем, что хотя бы первый токен останется
|
||||
sorted_mask[:, 0] = 1
|
||||
sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1
|
||||
# 5. Преобразуем маску обратно в оригинальный порядок:
|
||||
# Создаём полную маску из 0
|
||||
mask = torch.zeros_like(probs, dtype=torch.uint8)
|
||||
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||
# Устанавливаем 1 в местах нужных токенов
|
||||
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
||||
# 6. Зануляем логиты токенов вне топ-p:
|
||||
@@ -214,20 +287,19 @@ class Llama(BaseModel):
|
||||
# 4. Применяем Softmax
|
||||
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
|
||||
|
||||
|
||||
if do_sample == True:
|
||||
# 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial
|
||||
next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]
|
||||
else:
|
||||
# 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью
|
||||
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
|
||||
|
||||
next_token = torch.argmax(
|
||||
probs, dim=-1, keepdim=True
|
||||
) # [batch_size, 1]
|
||||
|
||||
# 6. Добавляем его к последовательности
|
||||
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@property
|
||||
def max_seq_len(self) -> int:
|
||||
return self._max_seq_len
|
||||
return self._max_seq_len
|
||||
|
||||
3
llm/src/llm/models/mistral/__init__.py
Normal file
3
llm/src/llm/models/mistral/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .mistral import Mistral
|
||||
|
||||
__all__ = ["Mistral"]
|
||||
276
llm/src/llm/models/mistral/mistral.py
Normal file
276
llm/src/llm/models/mistral/mistral.py
Normal file
@@ -0,0 +1,276 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch import Tensor
|
||||
import torch.nn.functional as F
|
||||
from math import sqrt
|
||||
from llm.core.base_model import BaseModel
|
||||
from llm.core.token_embeddings import TokenEmbeddings
|
||||
from llm.core.rms_norm import RMSNorm
|
||||
from llm.core.rope import RoPE
|
||||
from llm.core.mistral_decoder import MistralDecoder
|
||||
|
||||
|
||||
class Mistral(BaseModel):
|
||||
"""
|
||||
Mistral — автогерессивная языковая LLM-архитектура (2023, Mistral AI) для быстрого и качественного моделирования текста.
|
||||
|
||||
Назначение:
|
||||
-----------
|
||||
- Модель построена на базе decoder-only Transformer с важными оптимизациями: GQA (Grouped Query Attention), RoPE, SwiGLU, RMSNorm, sliding window attention.
|
||||
- Поддерживает autoregressive generation (step-by-step текст), обучение и inference на длинных последовательностях.
|
||||
- Используется в современных open-source LLM: Mistral-7B, Mixtral-8x7B и др.
|
||||
|
||||
Архитектурные особенности:
|
||||
--------------------------
|
||||
- Токеновые эмбеддинги (TokenEmbeddings) и позиционное кодирование через RoPE (rotary position embedding).
|
||||
- Stack из num_layers декодеров с Grouped Query Attention (раздельное число query/key heads для оптимизации памяти).
|
||||
- Sliding Window Attention Mask — позволяет ускорять обработку длинных текстов, ограничивая область внимания для каждого токена (как в оригинальном Mistral).
|
||||
- SwiGLU FeedForward-блоки и RMSNorm.
|
||||
- Dropout (регуляризация).
|
||||
- Кэширование attention (KV cache) для быстрой генерации токенов по одному.
|
||||
|
||||
Аргументы конструктора:
|
||||
-----------------------
|
||||
config (dict): параметры модели (см. документацию Mistral):
|
||||
vocab_size: int — размер словаря токенов
|
||||
embed_dim: int — размерность эмбеддингов
|
||||
num_q_heads: int — количество query-голов (обычно больше num_kv_heads)
|
||||
num_kv_heads: int — количество key/value attention-голов
|
||||
num_layers: int — число слоёв-декодеров
|
||||
max_position_embeddings: int — максимальная длина последовательности
|
||||
window_size: int — размер sliding window attention
|
||||
dropout: float — dropout (обычно очень мал или 0)
|
||||
...
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> model = Mistral({...})
|
||||
>>> tokens = torch.tensor([[100, 56, 8]])
|
||||
>>> logits = model(tokens)
|
||||
>>> generated = model.generate(tokens, max_new_tokens=16, do_sample=True, top_k=50)
|
||||
|
||||
References:
|
||||
-----------
|
||||
- "Mistral: Fast and Efficient Dense and Mixture of Experts Transformer Models" (2023): https://arxiv.org/abs/2310.06825
|
||||
- LLaMA v2 & Grouped-Query Attention: https://arxiv.org/abs/2307.09288
|
||||
- Оригинальное обсуждение архитектуры: https://huggingface.co/blog/mistral
|
||||
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self._max_seq_len = config["max_position_embeddings"]
|
||||
# Инициализация слоев
|
||||
self._token_embeddings = TokenEmbeddings(
|
||||
vocab_size=config["vocab_size"],
|
||||
emb_size=config["embed_dim"]
|
||||
)
|
||||
self._position_embeddings = RoPE(
|
||||
head_size=config["embed_dim"] // config["num_q_heads"],
|
||||
max_seq_len=config["max_position_embeddings"]
|
||||
)
|
||||
#self._position_embeddings = PositionalEmbeddings(
|
||||
# max_seq_len=max_seq_len,
|
||||
# emb_size=emb_size
|
||||
#)
|
||||
self._dropout = nn.Dropout(config["dropout"])
|
||||
self._decoders = nn.ModuleList([MistralDecoder(
|
||||
num_q_heads=config["num_q_heads"],
|
||||
num_kv_heads=config["num_kv_heads"],
|
||||
emb_size=config["embed_dim"],
|
||||
head_size=config["embed_dim"] // config["num_q_heads"],
|
||||
max_seq_len=config["max_position_embeddings"],
|
||||
window_size=config["window_size"],
|
||||
rope=self._position_embeddings,
|
||||
dropout=config["dropout"]
|
||||
) for _ in range(config["num_layers"])])
|
||||
self._norm = RMSNorm(config["embed_dim"])
|
||||
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
|
||||
|
||||
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
|
||||
"""
|
||||
Прямой проход (forward) через всю модель Mistral: возвращает логиты для токенов и (опционально) кэш attention для ускорения autoregressive генерации.
|
||||
|
||||
Аргументы:
|
||||
x (torch.Tensor): Входной тензор с токенами (shape [batch_size, seq_len]), где значения — индексы токенов.
|
||||
use_cache (bool, по умолчанию True): Возвращать ли новый KV attention-кэш для последующей генерации.
|
||||
cache (list or None): Предыдущий кэш attention (или None для полного прохода без накопления кэша).
|
||||
|
||||
Возвращает:
|
||||
logits (torch.Tensor): Тензор логитов shape [batch_size, seq_len, vocab_size] — вероятностное распределение по словарю для каждого токена.
|
||||
new_cache (list or None): Новый кэш KV attention-слоев (или None, если use_cache=False).
|
||||
|
||||
Исключения:
|
||||
ValueError: Если длина последовательности превышает максимальную (max_seq_len), когда не используется кэш.
|
||||
|
||||
Пример:
|
||||
>>> logits, cache = model.forward(input_ids, use_cache=True)
|
||||
>>> probabilities = torch.softmax(logits, dim=-1)
|
||||
"""
|
||||
# Проверка длины последовательности (только при отсутствии кэша)
|
||||
if cache is None and x.size(1) > self._max_seq_len:
|
||||
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
|
||||
|
||||
# Эмбеддинги токенов и позиций
|
||||
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
|
||||
#pos_out = self._position_embeddings(x) # [batch, seq_len, emb_size]
|
||||
|
||||
# Комбинирование
|
||||
out = self._dropout(tok_out) # [batch, seq_len, emb_size]
|
||||
|
||||
# Стек декодеров с передачей кэша
|
||||
new_cache = []
|
||||
for i, decoder in enumerate(self._decoders):
|
||||
decoder_cache = cache[i] if cache is not None else None
|
||||
decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)
|
||||
|
||||
# Извлекаем результат из кортежа
|
||||
if use_cache:
|
||||
out, decoder_new_cache = decoder_result
|
||||
new_cache.append(decoder_new_cache)
|
||||
else:
|
||||
out = decoder_result[0]
|
||||
|
||||
out = self._norm(out)
|
||||
logits = self._linear(out)
|
||||
|
||||
# Возвращаем результат с учетом use_cache
|
||||
if use_cache:
|
||||
return (logits, new_cache)
|
||||
else:
|
||||
return (logits, None)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
max_new_tokens: int,
|
||||
do_sample: bool,
|
||||
temperature: float = 1.0,
|
||||
top_k: int = None,
|
||||
top_p: float = None,
|
||||
use_cache: bool = True,
|
||||
attention_mask: torch.Tensor = None,
|
||||
**kwargs
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k/top-p sampling
|
||||
и ускорением через attention-кэш (KV-cache, важно для inference на длинных текстах).
|
||||
|
||||
Аргументы:
|
||||
x (torch.Tensor): Входной тензор с токенами shape [batch_size, seq_len].
|
||||
max_new_tokens (int): Максимальное количество новых токенов для генерации.
|
||||
do_sample (bool): Если True — вероятность/случайность (random sampling); если False — жадная генерация (argmax).
|
||||
temperature (float): Температура (>0, по умолчанию 1.0); >1.0 — более случайные выборы, <1.0 — более строгие.
|
||||
top_k (int, optional): top-k sampling; при сэмплировании выбираются только top_k наиболее вероятных токенов.
|
||||
top_p (float, optional): nucleus (top-p) sampling; выбираются токены с накопленной вероятностью ≤ top_p.
|
||||
use_cache (bool, по умолчанию True): Использовать ускорение через KV attention cache для autoregressive режима.
|
||||
|
||||
Возвращает:
|
||||
torch.Tensor: Последовательность индексов токенов shape [batch_size, seq_len + max_new_tokens].
|
||||
|
||||
Исключения:
|
||||
ValueError: Если x длиннее max_seq_len модели.
|
||||
ValueError: Если temperature ≤ 0.
|
||||
ValueError: Если одновременно заданы top_k и top_p.
|
||||
ValueError: Если top_k ≤ 0.
|
||||
ValueError: Если top_p не в диапазоне (0, 1].
|
||||
|
||||
Примеры:
|
||||
>>> # Жадная генерация
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=False)
|
||||
>>> # Сэмплирование с температурой
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=0.8)
|
||||
>>> # Top-k sampling
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_k=50)
|
||||
>>> # Top-p (nucleus) sampling
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_p=0.92)
|
||||
>>> # Температура + top-k
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=1.0, top_k=100)
|
||||
|
||||
Примечания:
|
||||
- Одновременно использовать top_k и top_p нельзя.
|
||||
- Параметры temperature, top_k, top_p работают только при do_sample=True.
|
||||
- Для полного воспроизведения результата зафиксируйте seed через torch.manual_seed.
|
||||
- Метод всегда возвращает только индексы токенов; для получения логитов используйте forward.
|
||||
|
||||
Ссылки:
|
||||
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus/top-p sampling): https://arxiv.org/abs/1904.09751
|
||||
- Mistral: https://arxiv.org/abs/2310.06825
|
||||
"""
|
||||
cache = None
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
if use_cache and cache is not None:
|
||||
# Используем кэш - передаем только последний токен
|
||||
x_input = x[:, -1:] # [batch_size, 1]
|
||||
else:
|
||||
# Первая итерация или кэш отключен - передаем всю последовательность
|
||||
x_input = x
|
||||
|
||||
# Прямой проход с кэшем
|
||||
logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)
|
||||
|
||||
# Обновляем кэш для следующей итерации
|
||||
if use_cache:
|
||||
cache = new_cache
|
||||
|
||||
last_logits = logits[:, -1, :] # [batch_size, vocab_size]
|
||||
|
||||
# Масштабируем логиты температурой
|
||||
if temperature > 0:
|
||||
logits_scaled = last_logits / temperature
|
||||
else:
|
||||
logits_scaled = last_logits
|
||||
|
||||
if do_sample == True and top_k != None:
|
||||
_, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)
|
||||
|
||||
# # Заменим все НЕ top-k логиты на -inf
|
||||
masked_logits = logits_scaled.clone()
|
||||
vocab_size = logits_scaled.size(-1)
|
||||
|
||||
# создаём маску: 1, если токен НЕ в topk_indices
|
||||
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||
mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы
|
||||
masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf')
|
||||
|
||||
logits_scaled = masked_logits
|
||||
|
||||
if do_sample == True and top_p != None:
|
||||
# 1. Применим softmax, чтобы получить вероятности:
|
||||
probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]
|
||||
# 2. Отсортируем токены по убыванию вероятностей:
|
||||
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
|
||||
# 3. Посчитаем кумулятивную сумму вероятностей:
|
||||
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
|
||||
# 4. Определим маску: оставить токены, пока сумма < top_p
|
||||
sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size]
|
||||
# Гарантируем, что хотя бы первый токен останется
|
||||
sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1
|
||||
# 5. Преобразуем маску обратно в оригинальный порядок:
|
||||
# Создаём полную маску из 0
|
||||
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||
# Устанавливаем 1 в местах нужных токенов
|
||||
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
||||
# 6. Зануляем логиты токенов вне топ-p:
|
||||
logits_scaled[~mask] = float('-inf')
|
||||
|
||||
# 4. Применяем Softmax
|
||||
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
|
||||
|
||||
|
||||
if do_sample == True:
|
||||
# 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial
|
||||
next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]
|
||||
else:
|
||||
# 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью
|
||||
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
|
||||
|
||||
# 6. Добавляем его к последовательности
|
||||
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
|
||||
return x
|
||||
|
||||
|
||||
@property
|
||||
def max_seq_len(self) -> int:
|
||||
return self._max_seq_len
|
||||
3
llm/src/llm/models/mixtral/__init__.py
Normal file
3
llm/src/llm/models/mixtral/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .mixtral import Mixtral
|
||||
|
||||
__all__ = ["Mixtral"]
|
||||
361
llm/src/llm/models/mixtral/mixtral.py
Normal file
361
llm/src/llm/models/mixtral/mixtral.py
Normal file
@@ -0,0 +1,361 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch import Tensor
|
||||
import torch.nn.functional as F
|
||||
from math import sqrt
|
||||
from llm.core.base_model import BaseModel
|
||||
from llm.core.token_embeddings import TokenEmbeddings
|
||||
from llm.core.rope import RoPE
|
||||
from llm.core.rms_norm import RMSNorm
|
||||
from llm.core.mixtral_decoder import MixtralDecoder
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class Mixtral(BaseModel):
|
||||
"""
|
||||
Mixtral — языковая модель с архитектурой Mixture-of-Experts на основе современных трансформеров (см. Mixtral 8x7B).
|
||||
|
||||
Описание:
|
||||
---------
|
||||
Данный класс реализует полностью функциональную LLM с блоками MixtralDecoder, которые используют разреженные Feed-Forward сети MoE (Mixture-of-Experts)
|
||||
и Grouped Query Attention (GQA). Позволяет масштабировать количество параметров без экспоненциального роста вычислительных затрат благодаря активации лишь части экспертов на каждый токен.
|
||||
Mixtral поддерживает автотекстогенерацию с caching, position encoding через RoPE и всё необходимое для работы и тренировки современных LLM.
|
||||
|
||||
Архитектурные особенности:
|
||||
--------------------------
|
||||
- Stack из N слоёв MixtralDecoder (каждый — MoE-блок + attention + RMSNorm).
|
||||
- Dropout для регуляризации на уровне эмбеддингов и слоёв.
|
||||
- Позиционные эмбеддинги реализованы через RoPE (Rotary Positional Embeddings).
|
||||
- Финальная RMSNorm плюс Linear-проекция к словарю токенов.
|
||||
- Поддержка автогенерации с sampling (greedy, top-k, top-p), temperature и KV-cache.
|
||||
|
||||
Аргументы конструктора:
|
||||
----------------------
|
||||
config : dict
|
||||
Словарь-конфиг с основными гиперпараметрами модели:
|
||||
- vocab_size : int — размер словаря токенов
|
||||
- embed_dim : int — размер скрытого пространства
|
||||
- max_position_embeddings : int — макс. длина последовательности
|
||||
- num_layers : int — количество декодерных блоков в стеке
|
||||
- num_q_heads : int — число query-голов в attention
|
||||
- num_kv_heads : int — число kv-голов в attention
|
||||
- num_experts : int — число MoE-экспертов
|
||||
- top_k_experts : int — сколько экспертов активировать на токен
|
||||
- dropout : float — вероятность Dropout
|
||||
- window_size : int — размер окна внимания
|
||||
|
||||
Основные методы:
|
||||
----------------
|
||||
- forward(x, use_cache=True, cache=None) — прямой проход, поддерживает batched вход, caching.
|
||||
- generate(...) — авторегрессивная генерация с разными стратегиями sampling и ускорением через cache.
|
||||
- save(path)/load(path, device) — сохранение и восстановление обученной модели.
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> config = {...} # dict с параметрами
|
||||
>>> model = Mixtral(config)
|
||||
>>> x = torch.randint(0, config["vocab_size"], (2, 16))
|
||||
>>> logits, cache = model(x, use_cache=True)
|
||||
>>> print(logits.shape) # [2, 16, vocab_size]
|
||||
|
||||
>>> # Генерация
|
||||
>>> out = model.generate(x, max_new_tokens=20, do_sample=True, top_k=10, temperature=0.9)
|
||||
|
||||
Литература:
|
||||
-----------
|
||||
- Mixtral 8x7B: https://mistral.ai/news/mixtral-of-experts/
|
||||
- Switch Transformer: https://arxiv.org/abs/2101.03961
|
||||
- GShard: https://arxiv.org/abs/2006.16668
|
||||
- RoPE: https://arxiv.org/abs/2104.09864
|
||||
- Grouped Query Attention: https://arxiv.org/abs/2305.14236
|
||||
- RMSNorm: https://arxiv.org/abs/1910.07467
|
||||
"""
|
||||
def __init__(self, config):
|
||||
"""
|
||||
Конструктор класса Mixtral.
|
||||
|
||||
Осуществляет инициализацию всех модулей и внутренних параметров большой языковой модели с архитектурой Mixtral/MoE.
|
||||
Использует параметры из конфиг-словаря `config` для гибкой настройки модели.
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
config : dict
|
||||
Словарь с основными гиперпараметрами архитектуры. Должен содержать ключи:
|
||||
vocab_size (int): Размер словаря токенов.
|
||||
embed_dim (int): Размер скрытого пространства (эмбеддингов).
|
||||
max_position_embeddings (int): Максимальная длина токенной последовательности.
|
||||
num_layers (int): Количество декодерных блоков (слоёв) в модели.
|
||||
num_q_heads (int): Число query-голов (attention heads).
|
||||
num_kv_heads (int): Число key-value голов (attention heads).
|
||||
num_experts (int): Количество экспертов в каждом MoE-блоке.
|
||||
top_k_experts (int): Сколько экспертов активируется для одного токена.
|
||||
dropout (float): Dropout для регуляризации.
|
||||
window_size (int): Размер окна внимания (Attention Window).
|
||||
|
||||
Внутри:
|
||||
-------
|
||||
- Инициализируются эмбеддинги токенов, позиционные эмбеддинги RoPE, Dropout.
|
||||
- Строится стек из num_layers модулей MixtralDecoder с заданным количеством attention heads и экспертов.
|
||||
- Финальный слой нормализации и проекция к логитам словаря (linear layer).
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> config = {
|
||||
... "vocab_size": 32000,
|
||||
... "embed_dim": 512,
|
||||
... "max_position_embeddings": 2048,
|
||||
... "num_layers": 24,
|
||||
... "num_q_heads": 8,
|
||||
... "num_kv_heads": 8,
|
||||
... "num_experts": 8,
|
||||
... "top_k_experts": 2,
|
||||
... "dropout": 0.1,
|
||||
... "window_size": 256,
|
||||
... }
|
||||
>>> model = Mixtral(config)
|
||||
|
||||
Примечания:
|
||||
-----------
|
||||
- Конфиг модели должен быть согласован: размеры должны делиться на число голов, число экспертов и top_k_experts корректно выбраны.
|
||||
- Все параметры, необходимые для построения MixtralDecoder, attention и MoE, берутся из config.
|
||||
"""
|
||||
super().__init__(config)
|
||||
|
||||
self._max_seq_len = config["max_position_embeddings"]
|
||||
|
||||
# Инициализация слоев
|
||||
self._token_embeddings = TokenEmbeddings(
|
||||
vocab_size=config["vocab_size"],
|
||||
emb_size=config["embed_dim"]
|
||||
)
|
||||
self._position_embeddings = RoPE(
|
||||
head_size=config["embed_dim"] // config["num_q_heads"],
|
||||
max_seq_len=config["max_position_embeddings"]
|
||||
)
|
||||
#self._position_embeddings = PositionalEmbeddings(
|
||||
# max_seq_len=max_seq_len,
|
||||
# emb_size=emb_size
|
||||
#)
|
||||
self._dropout = nn.Dropout(config["dropout"])
|
||||
self._decoders = nn.ModuleList([MixtralDecoder(
|
||||
num_q_heads=config["num_q_heads"],
|
||||
num_kv_heads=config["num_kv_heads"],
|
||||
emb_size=config["embed_dim"],
|
||||
head_size=config["embed_dim"] // config["num_q_heads"],
|
||||
max_seq_len=config["max_position_embeddings"],
|
||||
num_experts=config["num_experts"],
|
||||
top_k_experts=config["top_k_experts"],
|
||||
window_size=config["window_size"],
|
||||
rope=self._position_embeddings,
|
||||
dropout=config["dropout"]
|
||||
) for _ in range(config["num_layers"])])
|
||||
self._norm = RMSNorm(config["embed_dim"])
|
||||
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
|
||||
|
||||
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
|
||||
"""
|
||||
Прямой проход (forward) через всю модель Mixtral.
|
||||
|
||||
Данный метод реализует трансформацию входной последовательности токенов в логиты (предсказания вероятностей токенов словаря)
|
||||
с поддержкой эффективного инференса с использованием cache (KV-кэш attention для автогенерации).
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Двумерный входной тензор shape [batch_size, seq_len], где каждое значение — ID токена.
|
||||
use_cache : bool, по умолчанию True
|
||||
Если True — в режиме генерации модель возвращает обновлённый список кэшей attention для ускорения последовательного инференса.
|
||||
Если False — attention cache не используется.
|
||||
cache : list, optional
|
||||
(Необязательно) Список (или None) с кэшем KV attention для каждого слоя. Используется для автогенерации текста.
|
||||
|
||||
Возвращает:
|
||||
-----------
|
||||
tuple:
|
||||
- logits : torch.Tensor — выходной тензор shape [batch_size, seq_len, vocab_size] — массив логитов по токенам и словарю.
|
||||
- new_cache : list или None — обновлённый cache, если используется.
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> logits, new_cache = model(x, use_cache=True, cache=None)
|
||||
>>> logits.shape # [batch_size, seq_len, vocab_size]
|
||||
|
||||
Примечания:
|
||||
-----------
|
||||
- Если используется cache — эффективно для авторегрессионной генерации (token-by-token), например, при диалогах или длинной генерации.
|
||||
- Если входная последовательность длиннее max_seq_len — будет выброшено исключение.
|
||||
- Если нужен только логит последнего токена — используйте slice: logits[:, -1, :]
|
||||
|
||||
"""
|
||||
# Проверка длины последовательности (только при отсутствии кэша)
|
||||
if cache is None and x.size(1) > self._max_seq_len:
|
||||
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
|
||||
|
||||
# Эмбеддинги токенов и позиций
|
||||
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
|
||||
#pos_out = self._position_embeddings(x) # [batch, seq_len, emb_size]
|
||||
|
||||
# Комбинирование
|
||||
out = self._dropout(tok_out) # [batch, seq_len, emb_size]
|
||||
|
||||
# Стек декодеров с передачей кэша
|
||||
new_cache = []
|
||||
for i, decoder in enumerate(self._decoders):
|
||||
decoder_cache = cache[i] if cache is not None else None
|
||||
decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)
|
||||
|
||||
# Извлекаем результат из кортежа
|
||||
if use_cache:
|
||||
out, decoder_new_cache = decoder_result
|
||||
new_cache.append(decoder_new_cache)
|
||||
else:
|
||||
out = decoder_result[0]
|
||||
|
||||
out = self._norm(out)
|
||||
logits = self._linear(out)
|
||||
|
||||
# Возвращаем результат с учетом use_cache
|
||||
if use_cache:
|
||||
return (logits, new_cache)
|
||||
else:
|
||||
return (logits, None)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
max_new_tokens: int,
|
||||
do_sample: bool,
|
||||
temperature: float = 1.0,
|
||||
top_k: int = None,
|
||||
top_p: float = None,
|
||||
use_cache: bool = True,
|
||||
attention_mask: torch.Tensor = None,
|
||||
**kwargs
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k/top-p sampling
|
||||
и ускорением через attention-кэш (KV-cache, важно для inference на длинных текстах).
|
||||
|
||||
Аргументы:
|
||||
x (torch.Tensor): Входной тензор с токенами shape [batch_size, seq_len].
|
||||
max_new_tokens (int): Максимальное количество новых токенов для генерации.
|
||||
do_sample (bool): Если True — вероятность/случайность (random sampling); если False — жадная генерация (argmax).
|
||||
temperature (float): Температура (>0, по умолчанию 1.0); >1.0 — более случайные выборы, <1.0 — более строгие.
|
||||
top_k (int, optional): top-k sampling; при сэмплировании выбираются только top_k наиболее вероятных токенов.
|
||||
top_p (float, optional): nucleus (top-p) sampling; выбираются токены с накопленной вероятностью ≤ top_p.
|
||||
use_cache (bool, по умолчанию True): Использовать ускорение через KV attention cache для autoregressive режима.
|
||||
|
||||
Возвращает:
|
||||
torch.Tensor: Последовательность индексов токенов shape [batch_size, seq_len + max_new_tokens].
|
||||
|
||||
Исключения:
|
||||
ValueError: Если x длиннее max_seq_len модели.
|
||||
ValueError: Если temperature ≤ 0.
|
||||
ValueError: Если одновременно заданы top_k и top_p.
|
||||
ValueError: Если top_k ≤ 0.
|
||||
ValueError: Если top_p не в диапазоне (0, 1].
|
||||
|
||||
Примеры:
|
||||
>>> # Жадная генерация
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=False)
|
||||
>>> # Сэмплирование с температурой
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=0.8)
|
||||
>>> # Top-k sampling
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_k=50)
|
||||
>>> # Top-p (nucleus) sampling
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_p=0.92)
|
||||
>>> # Температура + top-k
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=1.0, top_k=100)
|
||||
|
||||
Примечания:
|
||||
- Одновременно использовать top_k и top_p нельзя.
|
||||
- Параметры temperature, top_k, top_p работают только при do_sample=True.
|
||||
- Для полного воспроизведения результата зафиксируйте seed через torch.manual_seed.
|
||||
- Метод всегда возвращает только индексы токенов; для получения логитов используйте forward.
|
||||
|
||||
Ссылки:
|
||||
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus/top-p sampling): https://arxiv.org/abs/1904.09751
|
||||
- Mistral: https://arxiv.org/abs/2310.06825
|
||||
"""
|
||||
cache = None
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
if use_cache and cache is not None:
|
||||
# Используем кэш - передаем только последний токен
|
||||
x_input = x[:, -1:] # [batch_size, 1]
|
||||
else:
|
||||
# Первая итерация или кэш отключен - передаем всю последовательность
|
||||
x_input = x
|
||||
|
||||
# Прямой проход с кэшем
|
||||
logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)
|
||||
|
||||
# Обновляем кэш для следующей итерации
|
||||
if use_cache:
|
||||
cache = new_cache
|
||||
|
||||
last_logits = logits[:, -1, :] # [batch_size, vocab_size]
|
||||
|
||||
# Масштабируем логиты температурой
|
||||
if temperature > 0:
|
||||
logits_scaled = last_logits / temperature
|
||||
else:
|
||||
logits_scaled = last_logits
|
||||
|
||||
if do_sample == True and top_k != None:
|
||||
_, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)
|
||||
|
||||
# # Заменим все НЕ top-k логиты на -inf
|
||||
masked_logits = logits_scaled.clone()
|
||||
vocab_size = logits_scaled.size(-1)
|
||||
|
||||
# создаём маску: 1, если токен НЕ в topk_indices
|
||||
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||
mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы
|
||||
masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf')
|
||||
|
||||
logits_scaled = masked_logits
|
||||
|
||||
if do_sample == True and top_p != None:
|
||||
# 1. Применим softmax, чтобы получить вероятности:
|
||||
probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]
|
||||
# 2. Отсортируем токены по убыванию вероятностей:
|
||||
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
|
||||
# 3. Посчитаем кумулятивную сумму вероятностей:
|
||||
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
|
||||
# 4. Определим маску: оставить токены, пока сумма < top_p
|
||||
sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size]
|
||||
# Гарантируем, что хотя бы первый токен останется
|
||||
sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1
|
||||
# 5. Преобразуем маску обратно в оригинальный порядок:
|
||||
# Создаём полную маску из 0
|
||||
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||
# Устанавливаем 1 в местах нужных токенов
|
||||
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
||||
# 6. Зануляем логиты токенов вне топ-p:
|
||||
logits_scaled[~mask] = float('-inf')
|
||||
|
||||
# 4. Применяем Softmax
|
||||
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
|
||||
|
||||
|
||||
if do_sample == True:
|
||||
# 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial
|
||||
next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]
|
||||
else:
|
||||
# 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью
|
||||
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
|
||||
|
||||
# 6. Добавляем его к последовательности
|
||||
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
|
||||
return x
|
||||
|
||||
@property
|
||||
def max_seq_len(self) -> int:
|
||||
return self._max_seq_len
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -10,92 +10,94 @@ import json
|
||||
class BaseTokenizer(ABC):
|
||||
"""
|
||||
Абстрактный базовый класс для всех токенизаторов.
|
||||
|
||||
|
||||
Определяет общий интерфейс для токенизации текста.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.vocab: Dict[str, int] = {}
|
||||
self.inverse_vocab: Dict[int, str] = {}
|
||||
self.vocab_size: int = 0
|
||||
|
||||
|
||||
# Специальные токены
|
||||
self.pad_token = "<pad>"
|
||||
self.unk_token = "<unk>"
|
||||
self.bos_token = "<bos>"
|
||||
self.eos_token = "<eos>"
|
||||
|
||||
|
||||
self.pad_token_id: Optional[int] = None
|
||||
self.unk_token_id: Optional[int] = None
|
||||
self.bos_token_id: Optional[int] = None
|
||||
self.eos_token_id: Optional[int] = None
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def train(self, texts: List[str], vocab_size: int = 1000, **kwargs):
|
||||
"""
|
||||
Обучение токенизатора на текстах.
|
||||
|
||||
|
||||
Args:
|
||||
texts: Список текстов для обучения
|
||||
vocab_size: Желаемый размер словаря
|
||||
**kwargs: Дополнительные параметры обучения
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def encode(self, text: str, **kwargs) -> List[int]:
|
||||
"""
|
||||
Кодирование текста в последовательность токенов.
|
||||
|
||||
|
||||
Args:
|
||||
text: Входной текст
|
||||
**kwargs: Дополнительные параметры кодирования
|
||||
|
||||
|
||||
Returns:
|
||||
List[int]: Список идентификаторов токенов
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def decode(self, tokens: List[int], **kwargs) -> str:
|
||||
"""
|
||||
Декодирование последовательности токенов в текст.
|
||||
|
||||
|
||||
Args:
|
||||
tokens: Список идентификаторов токенов
|
||||
**kwargs: Дополнительные параметры декодирования
|
||||
|
||||
|
||||
Returns:
|
||||
str: Декодированный текст
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def tokenize(self, text: str, **kwargs) -> List[str]:
|
||||
"""
|
||||
Токенизация текста в список строковых токенов.
|
||||
|
||||
|
||||
Args:
|
||||
text: Входной текст
|
||||
**kwargs: Дополнительные параметры
|
||||
|
||||
|
||||
Returns:
|
||||
List[str]: Список токенов
|
||||
"""
|
||||
token_ids = self.encode(text, **kwargs)
|
||||
return [self.inverse_vocab.get(token_id, self.unk_token) for token_id in token_ids]
|
||||
|
||||
return [
|
||||
self.inverse_vocab.get(token_id, self.unk_token) for token_id in token_ids
|
||||
]
|
||||
|
||||
def get_vocab(self) -> Dict[str, int]:
|
||||
"""Возвращает словарь токенизатора."""
|
||||
return self.vocab.copy()
|
||||
|
||||
|
||||
def get_vocab_size(self) -> int:
|
||||
"""Возвращает размер словаря."""
|
||||
return self.vocab_size
|
||||
|
||||
|
||||
def add_special_tokens(self, special_tokens: List[str]):
|
||||
"""
|
||||
Добавляет специальные токены в словарь.
|
||||
|
||||
|
||||
Args:
|
||||
special_tokens: Список специальных токенов
|
||||
"""
|
||||
@@ -105,70 +107,70 @@ class BaseTokenizer(ABC):
|
||||
self.vocab[token] = token_id
|
||||
self.inverse_vocab[token_id] = token
|
||||
self.vocab_size += 1
|
||||
|
||||
|
||||
# Обновляем ID специальных токенов
|
||||
self.pad_token_id = self.vocab.get(self.pad_token)
|
||||
self.unk_token_id = self.vocab.get(self.unk_token)
|
||||
self.bos_token_id = self.vocab.get(self.bos_token)
|
||||
self.eos_token_id = self.vocab.get(self.eos_token)
|
||||
|
||||
|
||||
def save(self, filepath: str):
|
||||
"""
|
||||
Сохраняет токенизатор в файл.
|
||||
|
||||
|
||||
Args:
|
||||
filepath: Путь для сохранения
|
||||
"""
|
||||
config = {
|
||||
'vocab': self.vocab,
|
||||
'vocab_size': self.vocab_size,
|
||||
'pad_token': self.pad_token,
|
||||
'unk_token': self.unk_token,
|
||||
'bos_token': self.bos_token,
|
||||
'eos_token': self.eos_token,
|
||||
'tokenizer_type': self.__class__.__name__
|
||||
"vocab": self.vocab,
|
||||
"vocab_size": self.vocab_size,
|
||||
"pad_token": self.pad_token,
|
||||
"unk_token": self.unk_token,
|
||||
"bos_token": self.bos_token,
|
||||
"eos_token": self.eos_token,
|
||||
"tokenizer_type": self.__class__.__name__,
|
||||
}
|
||||
|
||||
with open(filepath, 'w', encoding='utf-8') as f:
|
||||
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
@classmethod
|
||||
def load(cls, filepath: str):
|
||||
"""
|
||||
Загружает токенизатор из файла.
|
||||
|
||||
|
||||
Args:
|
||||
filepath: Путь к файлу
|
||||
|
||||
|
||||
Returns:
|
||||
BaseTokenizer: Загруженный токенизатор
|
||||
"""
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
|
||||
|
||||
# Создаем экземпляр токенизатора
|
||||
tokenizer = cls()
|
||||
tokenizer.vocab = config['vocab']
|
||||
tokenizer.vocab_size = config['vocab_size']
|
||||
tokenizer.pad_token = config['pad_token']
|
||||
tokenizer.unk_token = config['unk_token']
|
||||
tokenizer.bos_token = config['bos_token']
|
||||
tokenizer.eos_token = config['eos_token']
|
||||
|
||||
tokenizer.vocab = config["vocab"]
|
||||
tokenizer.vocab_size = config["vocab_size"]
|
||||
tokenizer.pad_token = config["pad_token"]
|
||||
tokenizer.unk_token = config["unk_token"]
|
||||
tokenizer.bos_token = config["bos_token"]
|
||||
tokenizer.eos_token = config["eos_token"]
|
||||
|
||||
# Создаем обратный словарь
|
||||
tokenizer.inverse_vocab = {v: k for k, v in tokenizer.vocab.items()}
|
||||
|
||||
|
||||
# Обновляем ID специальных токенов
|
||||
tokenizer.pad_token_id = tokenizer.vocab.get(tokenizer.pad_token)
|
||||
tokenizer.unk_token_id = tokenizer.vocab.get(tokenizer.unk_token)
|
||||
tokenizer.bos_token_id = tokenizer.vocab.get(tokenizer.bos_token)
|
||||
tokenizer.eos_token_id = tokenizer.vocab.get(tokenizer.eos_token)
|
||||
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Возвращает размер словаря."""
|
||||
return self.vocab_size
|
||||
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(vocab_size={self.vocab_size})"
|
||||
|
||||
@@ -1,428 +0,0 @@
|
||||
"""
|
||||
BPE (Byte Pair Encoding) токенизатор.
|
||||
|
||||
Реализация алгоритма BPE для токенизации текста.
|
||||
"""
|
||||
|
||||
import re
|
||||
from collections import defaultdict, Counter
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
from .base_tokenizer import BaseTokenizer
|
||||
|
||||
|
||||
class BPETokenizer(BaseTokenizer):
|
||||
"""
|
||||
BPE токенизатор для обработки текста.
|
||||
|
||||
Реализует алгоритм Byte Pair Encoding для создания субсловных токенов.
|
||||
|
||||
Примеры использования:
|
||||
>>> tokenizer = BPETokenizer()
|
||||
>>> tokenizer.train(["пример текста для обучения"], vocab_size=1000)
|
||||
>>> tokens = tokenizer.encode("новый текст")
|
||||
>>> text = tokenizer.decode(tokens)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.merges: Dict[Tuple[str, str], int] = {}
|
||||
self.pattern = r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
|
||||
self.compiled_pattern = re.compile(self.pattern, re.UNICODE)
|
||||
|
||||
def train(self, texts: List[str], vocab_size: int = 1000, **kwargs):
|
||||
"""
|
||||
Обучение BPE токенизатора на текстах.
|
||||
|
||||
Args:
|
||||
texts: Список текстов для обучения
|
||||
vocab_size: Желаемый размер словаря
|
||||
**kwargs: Дополнительные параметры
|
||||
- min_frequency: Минимальная частота для мерджа
|
||||
- special_tokens: Список специальных токенов
|
||||
"""
|
||||
# Инициализация базового словаря
|
||||
self._initialize_vocab()
|
||||
|
||||
# Добавляем специальные токены если указаны
|
||||
special_tokens = kwargs.get('special_tokens', [self.pad_token, self.unk_token, self.bos_token, self.eos_token])
|
||||
self.add_special_tokens(special_tokens)
|
||||
|
||||
# Предобработка текстов
|
||||
words = self._preprocess_texts(texts)
|
||||
|
||||
# Получаем начальные токены
|
||||
vocab = self._get_initial_vocab(words)
|
||||
|
||||
# Выполняем BPE мерджи
|
||||
self._perform_merges(vocab, vocab_size, kwargs.get('min_frequency', 2))
|
||||
|
||||
# Строим финальный словарь
|
||||
self._build_final_vocab()
|
||||
|
||||
def _initialize_vocab(self):
|
||||
"""Инициализирует базовый словарь."""
|
||||
self.vocab.clear()
|
||||
self.inverse_vocab.clear()
|
||||
self.merges.clear()
|
||||
self.vocab_size = 0
|
||||
|
||||
def _preprocess_texts(self, texts: List[str]) -> List[List[str]]:
|
||||
"""
|
||||
Предобработка текстов для обучения.
|
||||
|
||||
Args:
|
||||
texts: Список текстов
|
||||
|
||||
Returns:
|
||||
List[List[str]]: Предобработанные слова
|
||||
"""
|
||||
words = []
|
||||
for text in texts:
|
||||
# Базовая нормализация
|
||||
text = text.lower().strip()
|
||||
# Токенизация на слова
|
||||
tokens = self.compiled_pattern.findall(text)
|
||||
words.append(tokens)
|
||||
return words
|
||||
|
||||
def _get_initial_vocab(self, words: List[List[str]]) -> Dict[str, int]:
|
||||
"""
|
||||
Создает начальный словарь из символов.
|
||||
|
||||
Args:
|
||||
words: Список токенизированных текстов
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: Начальный словарь частот
|
||||
"""
|
||||
vocab = Counter()
|
||||
for word_list in words:
|
||||
for word in word_list:
|
||||
# Разбиваем слово на символы и добавляем специальный символ конца слова
|
||||
chars = list(word) + ['</w>']
|
||||
vocab.update([''.join(chars[i:i+1]) for i in range(len(chars))])
|
||||
return vocab
|
||||
|
||||
def _perform_merges(self, vocab: Dict[str, int], target_vocab_size: int, min_frequency: int):
|
||||
"""
|
||||
Выполняет BPE мерджи до достижения целевого размера словаря.
|
||||
|
||||
Args:
|
||||
vocab: Начальный словарь
|
||||
target_vocab_size: Целевой размер словаря
|
||||
min_frequency: Минимальная частота для мерджа
|
||||
"""
|
||||
current_vocab_size = len(vocab) + len(self.vocab)
|
||||
|
||||
while current_vocab_size < target_vocab_size:
|
||||
# Находим наиболее частую пару
|
||||
pairs = self._get_stats(vocab)
|
||||
if not pairs:
|
||||
break
|
||||
|
||||
best_pair = max(pairs, key=pairs.get)
|
||||
if pairs[best_pair] < min_frequency:
|
||||
break
|
||||
|
||||
# Выполняем мердж
|
||||
vocab = self._merge_vocab(vocab, best_pair)
|
||||
self.merges[best_pair] = len(self.merges)
|
||||
current_vocab_size += 1
|
||||
|
||||
def _get_stats(self, vocab: Dict[str, int]) -> Dict[Tuple[str, str], int]:
|
||||
"""
|
||||
Собирает статистику по парам символов.
|
||||
|
||||
Args:
|
||||
vocab: Словарь токенов
|
||||
|
||||
Returns:
|
||||
Dict[Tuple[str, str], int]: Частоты пар
|
||||
"""
|
||||
pairs = defaultdict(int)
|
||||
for word, freq in vocab.items():
|
||||
symbols = word.split()
|
||||
for i in range(len(symbols) - 1):
|
||||
pairs[symbols[i], symbols[i + 1]] += freq
|
||||
return pairs
|
||||
|
||||
def _merge_vocab(self, vocab: Dict[str, int], pair: Tuple[str, str]) -> Dict[str, int]:
|
||||
"""
|
||||
Объединяет пару символов в словаре.
|
||||
|
||||
Args:
|
||||
vocab: Исходный словарь
|
||||
pair: Пара для объединения
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: Обновленный словарь
|
||||
"""
|
||||
new_vocab = {}
|
||||
bigram = re.compile(r'(?<!\\S)' + re.escape(pair[0]) + r' ' + re.escape(pair[1]) + r'(?!\\S)')
|
||||
replacement = pair[0] + pair[1]
|
||||
|
||||
for word in vocab:
|
||||
new_word = bigram.sub(replacement, word)
|
||||
new_vocab[new_word] = vocab[word]
|
||||
|
||||
return new_vocab
|
||||
|
||||
def _build_final_vocab(self):
|
||||
"""Строит финальный словарь токенизатора."""
|
||||
# Собираем все уникальные токены из мерджей
|
||||
all_tokens = set()
|
||||
|
||||
# Добавляем специальные токены
|
||||
all_tokens.update([self.pad_token, self.unk_token, self.bos_token, self.eos_token])
|
||||
|
||||
# Добавляем токены из мерджей
|
||||
for pair in self.merges:
|
||||
all_tokens.update(pair)
|
||||
|
||||
# Создаем словарь
|
||||
for i, token in enumerate(sorted(all_tokens)):
|
||||
self.vocab[token] = i
|
||||
|
||||
self.inverse_vocab = {v: k for k, v in self.vocab.items()}
|
||||
self.vocab_size = len(self.vocab)
|
||||
|
||||
# Обновляем ID специальных токенов
|
||||
self.pad_token_id = self.vocab.get(self.pad_token)
|
||||
self.unk_token_id = self.vocab.get(self.unk_token)
|
||||
self.bos_token_id = self.vocab.get(self.bos_token)
|
||||
self.eos_token_id = self.vocab.get(self.eos_token)
|
||||
|
||||
def encode(self, text: str, **kwargs) -> List[int]:
|
||||
"""
|
||||
Кодирует текст в последовательность токенов.
|
||||
|
||||
Args:
|
||||
text: Входной текст
|
||||
**kwargs: Дополнительные параметры
|
||||
- add_special_tokens: Добавлять специальные токены
|
||||
|
||||
Returns:
|
||||
List[int]: Список идентификаторов токенов
|
||||
"""
|
||||
add_special_tokens = kwargs.get('add_special_tokens', False)
|
||||
|
||||
# Токенизация текста
|
||||
tokens = self.compiled_pattern.findall(text)
|
||||
|
||||
# Применяем BPE к каждому токену
|
||||
bpe_tokens = []
|
||||
for token in tokens:
|
||||
# Преобразуем токен в BPE представление
|
||||
bpe_token = self._apply_bpe(token)
|
||||
bpe_tokens.extend(bpe_token)
|
||||
|
||||
# Конвертируем в ID
|
||||
token_ids = []
|
||||
for token in bpe_tokens:
|
||||
token_id = self.vocab.get(token, self.unk_token_id)
|
||||
if token_id is not None:
|
||||
token_ids.append(token_id)
|
||||
|
||||
# Добавляем специальные токены если нужно
|
||||
if add_special_tokens:
|
||||
if self.bos_token_id is not None:
|
||||
token_ids.insert(0, self.bos_token_id)
|
||||
if self.eos_token_id is not None:
|
||||
token_ids.append(self.eos_token_id)
|
||||
|
||||
return token_ids
|
||||
|
||||
def _apply_bpe(self, token: str) -> List[str]:
|
||||
"""
|
||||
Применяет BPE к одному токену.
|
||||
|
||||
Args:
|
||||
token: Входной токен
|
||||
|
||||
Returns:
|
||||
List[str]: Список BPE токенов
|
||||
"""
|
||||
# Простая реализация - в реальной реализации нужно применять обученные мерджи
|
||||
word = token + '</w>'
|
||||
tokens = [word[i:i+1] for i in range(len(word))]
|
||||
|
||||
# Применяем мерджи (упрощенная версия)
|
||||
# В полной реализации нужно применять все обученные мерджи
|
||||
for pair in self.merges:
|
||||
i = 0
|
||||
while i < len(tokens) - 1:
|
||||
if tokens[i] == pair[0] and tokens[i + 1] == pair[1]:
|
||||
tokens[i] = tokens[i] + tokens[i + 1]
|
||||
del tokens[i + 1]
|
||||
else:
|
||||
i += 1
|
||||
|
||||
return tokens
|
||||
|
||||
def decode(self, tokens: List[int], **kwargs) -> str:
|
||||
"""
|
||||
Декодирует последовательность токенов в текст.
|
||||
|
||||
Args:
|
||||
tokens: Список идентификаторов токенов
|
||||
**kwargs: Дополнительные параметры
|
||||
- skip_special_tokens: Пропускать специальные токены
|
||||
|
||||
Returns:
|
||||
str: Декодированный текст
|
||||
"""
|
||||
skip_special_tokens = kwargs.get('skip_special_tokens', True)
|
||||
|
||||
# Конвертируем ID в токены
|
||||
token_strings = []
|
||||
for token_id in tokens:
|
||||
token = self.inverse_vocab.get(token_id, self.unk_token)
|
||||
|
||||
# Пропускаем специальные токены если нужно
|
||||
if skip_special_tokens and token in [self.pad_token, self.unk_token, self.bos_token, self.eos_token]:
|
||||
continue
|
||||
|
||||
token_strings.append(token)
|
||||
|
||||
# Объединяем токены в текст
|
||||
text = ''.join(token_strings)
|
||||
|
||||
# Убираем маркер конца слова
|
||||
text = text.replace('</w>', ' ')
|
||||
|
||||
return text.strip()
|
||||
|
||||
def save(self, filepath: str):
|
||||
"""
|
||||
Сохраняет BPE токенизатор в файл.
|
||||
|
||||
Args:
|
||||
filepath: Путь для сохранения
|
||||
"""
|
||||
import json
|
||||
|
||||
config = {
|
||||
'vocab': self.vocab,
|
||||
'merges': {f"{k[0]} {k[1]}": v for k, v in self.merges.items()},
|
||||
'vocab_size': self.vocab_size,
|
||||
'pad_token': self.pad_token,
|
||||
'unk_token': self.unk_token,
|
||||
'bos_token': self.bos_token,
|
||||
'eos_token': self.eos_token,
|
||||
'pattern': self.pattern,
|
||||
'tokenizer_type': self.__class__.__name__
|
||||
}
|
||||
|
||||
with open(filepath, 'w', encoding='utf-8') as f:
|
||||
json.dump(config, f, ensure_ascii=False, indent=2)
|
||||
|
||||
@classmethod
|
||||
def load(cls, filepath: str):
|
||||
"""
|
||||
Загружает BPE токенизатор из файла.
|
||||
|
||||
Args:
|
||||
filepath: Путь к файлу
|
||||
|
||||
Returns:
|
||||
BPETokenizer: Загруженный токенизатор
|
||||
"""
|
||||
import json
|
||||
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
|
||||
tokenizer = cls()
|
||||
tokenizer.vocab = config['vocab']
|
||||
tokenizer.vocab_size = config['vocab_size']
|
||||
tokenizer.pad_token = config['pad_token']
|
||||
tokenizer.unk_token = config['unk_token']
|
||||
tokenizer.bos_token = config['bos_token']
|
||||
tokenizer.eos_token = config['eos_token']
|
||||
tokenizer.pattern = config.get('pattern', tokenizer.pattern)
|
||||
tokenizer.compiled_pattern = re.compile(tokenizer.pattern, re.UNICODE)
|
||||
|
||||
# Восстанавливаем мерджи
|
||||
merges = config.get('merges', {})
|
||||
tokenizer.merges = {}
|
||||
for k, v in merges.items():
|
||||
parts = k.split()
|
||||
if len(parts) == 2:
|
||||
tokenizer.merges[(parts[0], parts[1])] = v
|
||||
|
||||
# Создаем обратный словарь
|
||||
tokenizer.inverse_vocab = {v: k for k, v in tokenizer.vocab.items()}
|
||||
|
||||
# Обновляем ID специальных токенов
|
||||
tokenizer.pad_token_id = tokenizer.vocab.get(tokenizer.pad_token)
|
||||
tokenizer.unk_token_id = tokenizer.vocab.get(tokenizer.unk_token)
|
||||
tokenizer.bos_token_id = tokenizer.vocab.get(tokenizer.bos_token)
|
||||
tokenizer.eos_token_id = tokenizer.vocab.get(tokenizer.eos_token)
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
# Упрощенная версия для быстрого старта
|
||||
class SimpleBPETokenizer(BPETokenizer):
|
||||
"""
|
||||
Упрощенная версия BPE токенизатора для демонстрации.
|
||||
"""
|
||||
|
||||
def train(self, texts: List[str], vocab_size: int = 1000, **kwargs):
|
||||
"""Упрощенное обучение для демонстрации."""
|
||||
# Инициализация базового словаря
|
||||
self._initialize_vocab()
|
||||
|
||||
# Добавляем базовые токены
|
||||
special_tokens = [self.pad_token, self.unk_token, self.bos_token, self.eos_token]
|
||||
self.add_special_tokens(special_tokens)
|
||||
|
||||
# Простая реализация - собираем все символы
|
||||
all_chars = set()
|
||||
for text in texts:
|
||||
all_chars.update(text)
|
||||
|
||||
# Добавляем символы в словарь
|
||||
for char in sorted(all_chars):
|
||||
if char not in self.vocab:
|
||||
self.vocab[char] = len(self.vocab)
|
||||
|
||||
self.inverse_vocab = {v: k for k, v in self.vocab.items()}
|
||||
self.vocab_size = len(self.vocab)
|
||||
|
||||
# Обновляем ID специальных токенов
|
||||
self.pad_token_id = self.vocab.get(self.pad_token)
|
||||
self.unk_token_id = self.vocab.get(self.unk_token)
|
||||
self.bos_token_id = self.vocab.get(self.bos_token)
|
||||
self.eos_token_id = self.vocab.get(self.eos_token)
|
||||
|
||||
def encode(self, text: str, **kwargs) -> List[int]:
|
||||
"""Упрощенное кодирование - разбиваем на символы."""
|
||||
add_special_tokens = kwargs.get('add_special_tokens', False)
|
||||
|
||||
token_ids = []
|
||||
for char in text:
|
||||
token_id = self.vocab.get(char, self.unk_token_id)
|
||||
if token_id is not None:
|
||||
token_ids.append(token_id)
|
||||
|
||||
if add_special_tokens:
|
||||
if self.bos_token_id is not None:
|
||||
token_ids.insert(0, self.bos_token_id)
|
||||
if self.eos_token_id is not None:
|
||||
token_ids.append(self.eos_token_id)
|
||||
|
||||
return token_ids
|
||||
|
||||
def decode(self, tokens: List[int], **kwargs) -> str:
|
||||
"""Упрощенное декодирование."""
|
||||
skip_special_tokens = kwargs.get('skip_special_tokens', True)
|
||||
|
||||
chars = []
|
||||
for token_id in tokens:
|
||||
char = self.inverse_vocab.get(token_id, self.unk_token)
|
||||
if skip_special_tokens and char in [self.pad_token, self.unk_token, self.bos_token, self.eos_token]:
|
||||
continue
|
||||
chars.append(char)
|
||||
|
||||
return ''.join(chars)
|
||||
@@ -10,27 +10,65 @@ from .base_tokenizer import BaseTokenizer
|
||||
|
||||
class BPETokenizer(BaseTokenizer):
|
||||
"""
|
||||
BPE токенизатор для обработки текста.
|
||||
|
||||
Реализует алгоритм Byte Pair Encoding для создания субсловных токенов.
|
||||
Использует вашу реализацию BPE.
|
||||
|
||||
Примеры использования:
|
||||
>>> tokenizer = BPETokenizer()
|
||||
>>> tokenizer.train(["пример текста для обучения"], vocab_size=1000)
|
||||
>>> tokens = tokenizer.encode("новый текст")
|
||||
BpeTokenizer — реализация токенизатора на алгоритме byte pair encoding (BPE).
|
||||
|
||||
Назначение:
|
||||
-----------
|
||||
- Преобразует открытый текст (строки, bytes) в последовательность числовых токенов для подачи в LLM и обратно.
|
||||
- Разбивает текст на сабслова (байтовые пары), эффективно кодируя редкие слова длинными последовательностями, а частые — единичными токенами.
|
||||
- Является стандартом де-факто в современных языковых моделях (GPT, LLaMA, BLOOM, Mistral, HuggingFace).
|
||||
|
||||
Как работает BPE:
|
||||
-----------------
|
||||
1. Строится словарь из наиболее популярных пар символов/субстрок.
|
||||
2. Текст замещается наиболее длинными subword-подстроками из vocabulary (жадно).
|
||||
3. Итог: многомиллионное лексическое пространство сокращается до компактного набора subword pieces.
|
||||
|
||||
Особенности алгоритма:
|
||||
----------------------
|
||||
- Отлично работает на всех языках, включая rare/compound/inflectable.
|
||||
- Гибко масштабируется под размер итогового словаря/token space.
|
||||
- Обычно хранит mapping (str/bytes → int и int → str/bytes) в JSON или словарном файле.
|
||||
- Может использовать кастомные сепараторы, handle unknown.
|
||||
|
||||
Аргументы конструктора:
|
||||
-----------------------
|
||||
vocab_path: str
|
||||
Путь к файлу BPE vocabulary (JSON, txt, в зависимости от реализации).
|
||||
merges_path: str, optional
|
||||
Путь к списку merge-правил (если используется блочное файловое раздельное хранение).
|
||||
unk_token: str, optional
|
||||
Токен для неизвестных последовательностей (по дефолту '[UNK]' или '<unk>').
|
||||
pad_token, bos_token, eos_token: str, optional
|
||||
Special tokens, если нужны для вашей архитектуры.
|
||||
lowercase: bool, optional
|
||||
Приводить ли текст к нижнему регистру перед токенизацией.
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> tokenizer = BpeTokenizer(vocab_path=\"bpe_vocab.json\")
|
||||
>>> tokens = tokenizer.encode(\"Hello, world!\")
|
||||
>>> print(tokens) # [15496, 11, ...]
|
||||
>>> text = tokenizer.decode(tokens)
|
||||
>>> print(text) # 'Hello, world!'
|
||||
|
||||
References:
|
||||
-----------
|
||||
- Sennrich et al, \"Neural Machine Translation of Rare Words with Subword Units\", 2015: https://arxiv.org/abs/1508.07909
|
||||
- GPT-2 tokenization: https://github.com/openai/gpt-2
|
||||
- HuggingFace tokenizers overview: https://huggingface.co/docs/tokenizers/index
|
||||
- Visually: https://guillaume-be.github.io/2021-05-21/byte-pair-encoding/
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.merges: Dict[Tuple[str, str], int] = {}
|
||||
self.vocab_list: List[str] = []
|
||||
|
||||
|
||||
def train(self, texts: List[str], vocab_size: int = 1000, **kwargs):
|
||||
"""
|
||||
Обучение BPE токенизатора на текстах.
|
||||
|
||||
|
||||
Args:
|
||||
texts: Список текстов для обучения
|
||||
vocab_size: Желаемый размер словаря
|
||||
@@ -39,7 +77,7 @@ class BPETokenizer(BaseTokenizer):
|
||||
"""
|
||||
# Объединяем все тексты в одну строку для обучения
|
||||
combined_text = " ".join(texts)
|
||||
|
||||
|
||||
# 1. Получаем уникальные токены (символы)
|
||||
unique_tokens = sorted(set(combined_text))
|
||||
tokens = unique_tokens.copy()
|
||||
@@ -61,7 +99,10 @@ class BPETokenizer(BaseTokenizer):
|
||||
break # нет пар — выходим
|
||||
|
||||
# Находим самую частую пару (в случае равенства — та, что встретилась первой)
|
||||
most_frequent_pair = max(pair_freq.items(), key=lambda x: (x[1], -self._pair_first_index(sequence, x[0])))[0]
|
||||
most_frequent_pair = max(
|
||||
pair_freq.items(),
|
||||
key=lambda x: (x[1], -self._pair_first_index(sequence, x[0])),
|
||||
)[0]
|
||||
|
||||
# Создаем новый токен
|
||||
new_token = most_frequent_pair[0] + most_frequent_pair[1]
|
||||
@@ -71,45 +112,57 @@ class BPETokenizer(BaseTokenizer):
|
||||
new_sequence = []
|
||||
|
||||
while i < len(sequence):
|
||||
if i < len(sequence) - 1 and (sequence[i], sequence[i + 1]) == most_frequent_pair:
|
||||
if (
|
||||
i < len(sequence) - 1
|
||||
and (sequence[i], sequence[i + 1]) == most_frequent_pair
|
||||
):
|
||||
new_sequence.append(new_token)
|
||||
i += 2 # пропускаем два символа — заменённую пару
|
||||
else:
|
||||
new_sequence.append(sequence[i])
|
||||
i += 1
|
||||
sequence = new_sequence
|
||||
|
||||
|
||||
# 4. Создаем словари
|
||||
self.vocab_list = tokens.copy()
|
||||
self.vocab = dict(zip(tokens, range(vocab_size)))
|
||||
self.inverse_vocab = dict(zip(range(vocab_size), tokens))
|
||||
self.vocab_size = len(self.vocab)
|
||||
|
||||
|
||||
# Добавляем специальные токены если указаны
|
||||
special_tokens = kwargs.get('special_tokens', [self.pad_token, self.unk_token, self.bos_token, self.eos_token])
|
||||
special_tokens = kwargs.get(
|
||||
"special_tokens",
|
||||
[self.pad_token, self.unk_token, self.bos_token, self.eos_token],
|
||||
)
|
||||
self.add_special_tokens(special_tokens)
|
||||
|
||||
|
||||
def _pair_first_index(self, sequence, pair):
|
||||
"""Находит первый индекс пары в последовательности."""
|
||||
for i in range(len(sequence) - 1):
|
||||
if (sequence[i], sequence[i + 1]) == pair:
|
||||
return i
|
||||
return float('inf') # если пара не найдена (в теории не должно случиться)
|
||||
return float("inf") # если пара не найдена (в теории не должно случиться)
|
||||
|
||||
def encode(self, text: str, **kwargs) -> List[int]:
|
||||
"""
|
||||
Кодирует текст в последовательность токенов.
|
||||
|
||||
Токенизирует входной текст в список числовых токенов (индексов).
|
||||
|
||||
Args:
|
||||
text: Входной текст
|
||||
**kwargs: Дополнительные параметры
|
||||
- add_special_tokens: Добавлять специальные токены
|
||||
|
||||
-----
|
||||
text: str
|
||||
Входная строка/текст для токенизации.
|
||||
|
||||
Returns:
|
||||
List[int]: Список идентификаторов токенов
|
||||
--------
|
||||
List[int] — последовательность индексов из vocabulary.
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> ids = tokenizer.encode(\"The quick brown fox\")
|
||||
>>> print(ids)
|
||||
"""
|
||||
add_special_tokens = kwargs.get('add_special_tokens', False)
|
||||
|
||||
add_special_tokens = kwargs.get("add_special_tokens", False)
|
||||
|
||||
# 1. Разбиваем текст на токены-символы
|
||||
sequence = list(text)
|
||||
# 2. Инициализация пустого списка токенов
|
||||
@@ -119,7 +172,9 @@ class BPETokenizer(BaseTokenizer):
|
||||
while i < len(text):
|
||||
# 3.1 Найти все токены в словаре, начинающиеся с text[i]
|
||||
start_char = text[i]
|
||||
result = [token for token in self.vocab_list if token.startswith(start_char)]
|
||||
result = [
|
||||
token for token in self.vocab_list if token.startswith(start_char)
|
||||
]
|
||||
# 3.2 Выбрать самый длинный подходящий токен
|
||||
find_token = self._find_max_matching_token(text[i:], result)
|
||||
if find_token is None:
|
||||
@@ -134,19 +189,19 @@ class BPETokenizer(BaseTokenizer):
|
||||
|
||||
# 4. Заменить токены на их ID
|
||||
token_ids = self._tokens_to_ids(tokens)
|
||||
|
||||
|
||||
# Заменяем -1 на unk_token_id
|
||||
token_ids = [tid if tid != -1 else self.unk_token_id for tid in token_ids]
|
||||
|
||||
|
||||
# Добавляем специальные токены если нужно
|
||||
if add_special_tokens:
|
||||
if self.bos_token_id is not None:
|
||||
token_ids.insert(0, self.bos_token_id)
|
||||
if self.eos_token_id is not None:
|
||||
token_ids.append(self.eos_token_id)
|
||||
|
||||
|
||||
return token_ids
|
||||
|
||||
|
||||
def _find_max_matching_token(self, text: str, tokens: list) -> Optional[str]:
|
||||
"""Находит самый длинный токен из списка, с которого начинается текст"""
|
||||
matching = [token for token in tokens if text.startswith(token)]
|
||||
@@ -161,33 +216,48 @@ class BPETokenizer(BaseTokenizer):
|
||||
else:
|
||||
ids.append(-1) # Специальное значение
|
||||
return ids
|
||||
|
||||
|
||||
def decode(self, tokens: List[int], **kwargs) -> str:
|
||||
"""
|
||||
Декодирует последовательность токенов в текст.
|
||||
|
||||
Декодирует последовательность токенов обратно в текстовую строку.
|
||||
|
||||
Args:
|
||||
tokens: Список идентификаторов токенов
|
||||
**kwargs: Дополнительные параметры
|
||||
- skip_special_tokens: Пропускать специальные токены
|
||||
|
||||
-----
|
||||
ids: List[int]
|
||||
Список токен-индексов для распаковки.
|
||||
|
||||
Returns:
|
||||
str: Декодированный текст
|
||||
--------
|
||||
text: str
|
||||
Оригинальный (или приближённый) раскодированный текст.
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> tokens = [15496, 11, 318, ...]
|
||||
>>> text = tokenizer.decode(tokens)
|
||||
"""
|
||||
skip_special_tokens = kwargs.get('skip_special_tokens', True)
|
||||
|
||||
skip_special_tokens = kwargs.get("skip_special_tokens", True)
|
||||
|
||||
# Фильтруем специальные токены если нужно
|
||||
if skip_special_tokens:
|
||||
tokens = [tid for tid in tokens if tid not in [
|
||||
self.pad_token_id, self.unk_token_id, self.bos_token_id, self.eos_token_id
|
||||
]]
|
||||
|
||||
tokens = [
|
||||
tid
|
||||
for tid in tokens
|
||||
if tid
|
||||
not in [
|
||||
self.pad_token_id,
|
||||
self.unk_token_id,
|
||||
self.bos_token_id,
|
||||
self.eos_token_id,
|
||||
]
|
||||
]
|
||||
|
||||
# Конвертируем ID в токены
|
||||
token_strings = self._ids_to_tokens(tokens)
|
||||
|
||||
|
||||
# Объединяем токены в текст
|
||||
return ''.join(token_strings)
|
||||
|
||||
return "".join(token_strings)
|
||||
|
||||
def _ids_to_tokens(self, ids: List[int]) -> List[str]:
|
||||
"""Конвертирует список Ids в их tokens"""
|
||||
tokens = []
|
||||
@@ -197,76 +267,76 @@ class BPETokenizer(BaseTokenizer):
|
||||
else:
|
||||
tokens.append(self.unk_token) # Специальное значение
|
||||
return tokens
|
||||
|
||||
|
||||
def save(self, filepath: str):
|
||||
"""
|
||||
Сохраняет токенизатор в файл.
|
||||
|
||||
|
||||
Args:
|
||||
filepath: Путь для сохранения
|
||||
"""
|
||||
import json
|
||||
|
||||
|
||||
# Преобразуем кортежи в строки для JSON сериализации
|
||||
merges_serializable = {f"{k[0]},{k[1]}": v for k, v in self.merges.items()}
|
||||
|
||||
|
||||
config = {
|
||||
'vocab': self.vocab,
|
||||
'vocab_size': self.vocab_size,
|
||||
'pad_token': self.pad_token,
|
||||
'unk_token': self.unk_token,
|
||||
'bos_token': self.bos_token,
|
||||
'eos_token': self.eos_token,
|
||||
'tokenizer_type': self.__class__.__name__,
|
||||
'merges': merges_serializable,
|
||||
'vocab_list': self.vocab_list
|
||||
"vocab": self.vocab,
|
||||
"vocab_size": self.vocab_size,
|
||||
"pad_token": self.pad_token,
|
||||
"unk_token": self.unk_token,
|
||||
"bos_token": self.bos_token,
|
||||
"eos_token": self.eos_token,
|
||||
"tokenizer_type": self.__class__.__name__,
|
||||
"merges": merges_serializable,
|
||||
"vocab_list": self.vocab_list,
|
||||
}
|
||||
|
||||
with open(filepath, 'w', encoding='utf-8') as f:
|
||||
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
@classmethod
|
||||
def load(cls, filepath: str):
|
||||
"""
|
||||
Загружает токенизатор из файла.
|
||||
|
||||
|
||||
Args:
|
||||
filepath: Путь к файлу
|
||||
|
||||
|
||||
Returns:
|
||||
BPETokenizer: Загруженный токенизатор
|
||||
"""
|
||||
import json
|
||||
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
|
||||
|
||||
# Создаем экземпляр токенизатора
|
||||
tokenizer = cls()
|
||||
tokenizer.vocab = config['vocab']
|
||||
tokenizer.vocab_size = config['vocab_size']
|
||||
tokenizer.pad_token = config['pad_token']
|
||||
tokenizer.unk_token = config['unk_token']
|
||||
tokenizer.bos_token = config['bos_token']
|
||||
tokenizer.eos_token = config['eos_token']
|
||||
tokenizer.vocab_list = config['vocab_list']
|
||||
|
||||
tokenizer.vocab = config["vocab"]
|
||||
tokenizer.vocab_size = config["vocab_size"]
|
||||
tokenizer.pad_token = config["pad_token"]
|
||||
tokenizer.unk_token = config["unk_token"]
|
||||
tokenizer.bos_token = config["bos_token"]
|
||||
tokenizer.eos_token = config["eos_token"]
|
||||
tokenizer.vocab_list = config["vocab_list"]
|
||||
|
||||
# Восстанавливаем кортежи из строк
|
||||
tokenizer.merges = {}
|
||||
for k, v in config['merges'].items():
|
||||
parts = k.split(',')
|
||||
for k, v in config["merges"].items():
|
||||
parts = k.split(",")
|
||||
if len(parts) == 2:
|
||||
tokenizer.merges[(parts[0], parts[1])] = v
|
||||
|
||||
|
||||
# Создаем обратный словарь
|
||||
tokenizer.inverse_vocab = {v: k for k, v in tokenizer.vocab.items()}
|
||||
|
||||
|
||||
# Обновляем ID специальных токенов
|
||||
tokenizer.pad_token_id = tokenizer.vocab.get(tokenizer.pad_token)
|
||||
tokenizer.unk_token_id = tokenizer.vocab.get(tokenizer.unk_token)
|
||||
tokenizer.bos_token_id = tokenizer.vocab.get(tokenizer.bos_token)
|
||||
tokenizer.eos_token_id = tokenizer.vocab.get(tokenizer.eos_token)
|
||||
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
@@ -275,4 +345,5 @@ class SimpleBPETokenizer(BPETokenizer):
|
||||
Упрощенная версия BPE токенизатора для демонстрации.
|
||||
Наследует вашу реализацию, но может быть упрощена при необходимости.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
@@ -1,142 +0,0 @@
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from typing import List, Any
|
||||
|
||||
|
||||
class TextDataset(Dataset):
|
||||
"""
|
||||
Простой датасет для языкового моделирования (LLM).
|
||||
Работает с любым токенизатором, реализующим интерфейс BaseTokenizer.
|
||||
"""
|
||||
|
||||
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128):
|
||||
"""
|
||||
Инициализация датасета.
|
||||
|
||||
Args:
|
||||
texts: Список текстов для обучения
|
||||
tokenizer: Токенизатор с методами encode/decode
|
||||
block_size: Максимальная длина последовательности
|
||||
"""
|
||||
self.examples = []
|
||||
self.tokenizer = tokenizer
|
||||
self.block_size = block_size
|
||||
|
||||
for text in texts:
|
||||
# Кодируем текст в токены
|
||||
input_ids = tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
# Обрезаем или дополняем до нужной длины
|
||||
if len(input_ids) > block_size:
|
||||
input_ids = input_ids[:block_size]
|
||||
else:
|
||||
# Дополняем pad_token_id
|
||||
pad_token_id = getattr(tokenizer, 'pad_token_id', 0)
|
||||
input_ids = input_ids + [pad_token_id] * (block_size - len(input_ids))
|
||||
|
||||
self.examples.append(input_ids)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.examples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
|
||||
labels = input_ids.clone()
|
||||
return {"input_ids": input_ids, "labels": labels}
|
||||
|
||||
|
||||
class StreamingTextDataset(Dataset):
|
||||
"""
|
||||
Датасет для потоковой обработки больших текстов.
|
||||
Токенизация происходит на лету, что экономит память.
|
||||
"""
|
||||
|
||||
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128):
|
||||
self.texts = texts
|
||||
self.tokenizer = tokenizer
|
||||
self.block_size = block_size
|
||||
|
||||
# Получаем pad_token_id из токенизатора
|
||||
self.pad_token_id = getattr(tokenizer, 'pad_token_id', 0)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.texts)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
text = self.texts[idx]
|
||||
|
||||
# Токенизация на лету
|
||||
input_ids = self.tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
# Обрезаем или дополняем до нужной длины
|
||||
if len(input_ids) > self.block_size:
|
||||
input_ids = input_ids[:self.block_size]
|
||||
else:
|
||||
input_ids = input_ids + [self.pad_token_id] * (self.block_size - len(input_ids))
|
||||
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
||||
labels = input_ids.clone()
|
||||
|
||||
return {"input_ids": input_ids, "labels": labels}
|
||||
|
||||
|
||||
class TextDatasetWithSpecialTokens(TextDataset):
|
||||
"""
|
||||
Расширенная версия TextDataset с поддержкой специальных токенов.
|
||||
"""
|
||||
|
||||
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128,
|
||||
add_bos: bool = False, add_eos: bool = False):
|
||||
"""
|
||||
Args:
|
||||
texts: Список текстов
|
||||
tokenizer: Токенизатор
|
||||
block_size: Максимальная длина
|
||||
add_bos: Добавлять токен начала последовательности
|
||||
add_eos: Добавлять токен конца последовательности
|
||||
"""
|
||||
self.examples = []
|
||||
self.tokenizer = tokenizer
|
||||
self.block_size = block_size
|
||||
self.add_bos = add_bos
|
||||
self.add_eos = add_eos
|
||||
|
||||
for text in texts:
|
||||
# Кодируем с специальными токенами
|
||||
input_ids = tokenizer.encode(
|
||||
text,
|
||||
add_special_tokens=True,
|
||||
add_bos_token=add_bos,
|
||||
add_eos_token=eos
|
||||
)
|
||||
|
||||
# Учитываем специальные токены при обрезке/дополнении
|
||||
effective_block_size = block_size
|
||||
if add_bos:
|
||||
effective_block_size -= 1
|
||||
if add_eos:
|
||||
effective_block_size -= 1
|
||||
|
||||
if len(input_ids) > effective_block_size:
|
||||
input_ids = input_ids[:effective_block_size]
|
||||
|
||||
# Добавляем специальные токены если нужно
|
||||
if add_bos and hasattr(tokenizer, 'bos_token_id') and tokenizer.bos_token_id is not None:
|
||||
input_ids = [tokenizer.bos_token_id] + input_ids
|
||||
if add_eos and hasattr(tokenizer, 'eos_token_id') and tokenizer.eos_token_id is not None:
|
||||
input_ids = input_ids + [tokenizer.eos_token_id]
|
||||
|
||||
# Дополняем до полной длины
|
||||
pad_token_id = getattr(tokenizer, 'pad_token_id', 0)
|
||||
if len(input_ids) < block_size:
|
||||
input_ids = input_ids + [pad_token_id] * (block_size - len(input_ids))
|
||||
|
||||
self.examples.append(input_ids)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.examples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
|
||||
labels = input_ids.clone()
|
||||
return {"input_ids": input_ids, "labels": labels}
|
||||
@@ -1,8 +1,71 @@
|
||||
"""
|
||||
Модуль оптимизации для обучения нейронных сетей.
|
||||
|
||||
В данном модуле реализована функция выбора и инициализации оптимизаторов, наиболее популярных при обучении глубоких нейросетей:
|
||||
- AdamW
|
||||
- Adam
|
||||
- SGD
|
||||
|
||||
Теоретическое обоснование:
|
||||
--------------------------
|
||||
Задача оптимизации в обучении нейросети заключается в минимизации функции потерь (Loss) по параметрам модели W. Современные методы базируются на стохастическом градиентном спуске (SGD), а также на его адаптивных модификациях (Adam, AdamW).
|
||||
|
||||
**SGD** (Stochastic Gradient Descent) — стохастический градиентный спуск:
|
||||
W_{t+1} = W_t - \eta \nabla_W L(W_t)
|
||||
Здесь \eta — шаг обучения, \nabla_W — градиент по параметрам. SGD позволяет случайно выбирать подмножество обучающих данных для каждой итерации, что ускоряет процесс и уменьшает избыточную корреляцию между примерами.
|
||||
|
||||
**Adam** (Adaptive Moment Estimation) — адаптивный алгоритм, который использует скользящую среднюю не только градиентов, но и их квадратов:
|
||||
m_t = \beta_1 m_{t-1} + (1-\beta_1) \nabla_W L(W_t)
|
||||
v_t = \beta_2 v_{t-1} + (1-\beta_2) (\nabla_W L(W_t))^2
|
||||
W_{t+1} = W_t - \eta m_t/(\sqrt{v_t}+\epsilon)
|
||||
Где \beta_1, \beta_2 — коэффициенты экспоненциального сглаживания.
|
||||
|
||||
**AdamW** — модификация Adam, в которой weight decay (имплицитная L2-регуляризация) вводится корректно, отдельно от шага градиента, что улучшает обобщающую способность моделей:
|
||||
W_{t+1} = W_t - \eta [ m_t/(\sqrt{v_t}+\epsilon) + \lambda W_t ]
|
||||
Где \lambda — коэффициент weight decay.
|
||||
|
||||
Детальное описание: https://arxiv.org/abs/1711.05101
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> optimizer = get_optimizer(model, lr=3e-4, weight_decay=0.01, optimizer_type="adamw")
|
||||
>>> for batch in dataloader:
|
||||
... loss = model(batch)
|
||||
... loss.backward()
|
||||
... optimizer.step()
|
||||
... optimizer.zero_grad()
|
||||
|
||||
"""
|
||||
import torch.optim as optim
|
||||
|
||||
|
||||
def get_optimizer(model, lr=3e-4, weight_decay=0.01, optimizer_type="adamw"):
|
||||
"""
|
||||
Возвращает оптимизатор для обучения модели.
|
||||
Фабричная функция для создания оптимизатора PyTorch по выбранному типу.
|
||||
|
||||
Параметры
|
||||
---------
|
||||
model : torch.nn.Module
|
||||
Модель, параметры которой требуется оптимизировать.
|
||||
lr : float, по умолчанию 3e-4
|
||||
Шаг обучения (learning rate).
|
||||
weight_decay : float, по умолчанию 0.01
|
||||
Коэффициент weight decay (L2-регуляризации).
|
||||
optimizer_type : str, по умолчанию 'adamw'
|
||||
Тип оптимизатора: 'adamw', 'adam' или 'sgd'.
|
||||
|
||||
Возвращаемое значение
|
||||
---------------------
|
||||
torch.optim.Optimizer
|
||||
Объект-оптимизатор, готовый к использованию.
|
||||
|
||||
Исключения
|
||||
----------
|
||||
ValueError: Если передан неизвестный тип оптимизатора.
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> optimizer = get_optimizer(model, lr=1e-3, optimizer_type='sgd')
|
||||
"""
|
||||
if optimizer_type.lower() == "adamw":
|
||||
return optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
|
||||
|
||||
@@ -1,13 +1,66 @@
|
||||
"""
|
||||
Модуль для управления динамикой шага обучения (learning rate scheduling) при обучении нейронных сетей.
|
||||
|
||||
Теоретическое обоснование:
|
||||
--------------------------
|
||||
Плавная динамика шага обучения существенно влияет на сходимость и итоговое качество моделей. Введение этапа "разогрева" (warmup) — техники, при которой шаг обучения начинается с нуля и постепенно увеличивается до целевого значения, снижает вероятность неустойчивых градиентов на старте обучения. Подобная стратегия показала свою эффективность для крупных нейронных сетей, особенно в трансформерах (Vaswani et al, 2017, https://arxiv.org/abs/1706.03762).
|
||||
|
||||
Линейный scheduler с warmup задаёт динамику learning rate по формуле:
|
||||
- если current_step < num_warmup_steps:
|
||||
lr = lr_init * (current_step / num_warmup_steps)
|
||||
- иначе:
|
||||
lr = lr_init * max(0, (num_training_steps - current_step) / (num_training_steps - num_warmup_steps))
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> optimizer = get_optimizer(model, lr=3e-4)
|
||||
>>> scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=10000)
|
||||
>>> for step in range(num_training_steps):
|
||||
... optimizer.step()
|
||||
... scheduler.step()
|
||||
"""
|
||||
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
|
||||
"""
|
||||
Линейный планировщик обучения с warmup.
|
||||
Создаёт линейный планировщик изменения шага обучения (learning rate) с этапом warmup для оптимизатора PyTorch.
|
||||
|
||||
Аргументы
|
||||
---------
|
||||
optimizer : torch.optim.Optimizer
|
||||
Оптимизатор, для которого применяется scheduler.
|
||||
num_warmup_steps : int
|
||||
Количество шагов разогрева (warmup) — начиная с нулевого шага и плавного увеличения lr до номинального значения.
|
||||
num_training_steps : int
|
||||
Общее количество шагов (эпох/итераций) обучения модели.
|
||||
|
||||
Возвращаемое значение
|
||||
---------------------
|
||||
torch.optim.lr_scheduler.LambdaLR
|
||||
Планировщик lr, который следует вызывать после каждого optimizer.step() во время обучения.
|
||||
|
||||
Теоретическая справка
|
||||
---------------------
|
||||
Такой scheduler позволяет повысить стабильность и устойчивость обучения крупных моделей (особенно трансформеров), предотвращая резкие скачки градиентов в начале.
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> optimizer = get_optimizer(model, lr=3e-4)
|
||||
>>> scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=10000)
|
||||
>>> for step in range(num_training_steps):
|
||||
... optimizer.step()
|
||||
... scheduler.step()
|
||||
"""
|
||||
|
||||
def lr_lambda(current_step):
|
||||
# Линейный рост lr на этапе разогрева
|
||||
if current_step < num_warmup_steps:
|
||||
return float(current_step) / float(max(1, num_warmup_steps))
|
||||
return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
|
||||
|
||||
# Линейное затухание lr после разогрева
|
||||
return max(
|
||||
0.0,
|
||||
float(num_training_steps - current_step)
|
||||
/ float(max(1, num_training_steps - num_warmup_steps)),
|
||||
)
|
||||
return LambdaLR(optimizer, lr_lambda)
|
||||
|
||||
@@ -1,3 +1,22 @@
|
||||
"""
|
||||
Модуль для организации процесса обучения больших языковых моделей (LLM).
|
||||
|
||||
Научное и техническое обоснование
|
||||
----------------------------------
|
||||
Эффективное обучение современных трансформеров (GPT, LLaMA, Mistral и др.) опирается на принципы языкового моделирования (Language Modeling):
|
||||
- Предсказание вероятности следующего токена на основе предыдущих.
|
||||
- Использование функции потерь кросс-энтропии (cross-entropy) с маскированием паддингов.
|
||||
- Циклы обратного распространения ошибки (backpropagation), оптимизационные алгоритмы (например, AdamW), управление шагом обучения (scheduler с warmup), обрезка градиентов (grad clipping).
|
||||
|
||||
Реализация объединяет лучшие практики обучения LLM, универсальный API к моделям, датасетам, оптимизаторам и lr-схемам.
|
||||
|
||||
Подробнее: Vaswani et al. "Attention is All You Need" (2017), Radford et al. "Language Models are Unsupervised Multitask Learners" (2019)
|
||||
|
||||
Пример использования
|
||||
--------------------
|
||||
>>> trainer = Trainer(model, train_dataset, val_dataset, lr=3e-4, batch_size=8, num_epochs=3, warmup_steps=100)
|
||||
>>> trainer.train()
|
||||
"""
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
@@ -5,15 +24,75 @@ from tqdm import tqdm
|
||||
from llm.training.optimizer import get_optimizer
|
||||
from llm.training.scheduler import get_linear_schedule_with_warmup
|
||||
|
||||
|
||||
class Trainer:
|
||||
"""
|
||||
Универсальный класс обучения LLM (GPT, LLaMA, Mistral и т.д.)
|
||||
Универсальный и расширяемый класс для обучения больших языковых моделей (Large Language Models, LLM).
|
||||
|
||||
Поддерживаются архитектуры семейства GPT, LLaMA, Mistral и другие автогрессивные модели.
|
||||
Объединяет:
|
||||
- Тренировку по задаче языкового моделирования (Causal LM)
|
||||
- Cross-entropy loss с автоматическим сдвигом логитов/меток
|
||||
- Поддержку Grad Clipping, Scheduler, Validation
|
||||
- Унифицированный даталоадер, автоматический выбор устройства (CPU/GPU)
|
||||
|
||||
Атрибуты
|
||||
--------
|
||||
model : torch.nn.Module
|
||||
Модель для обучения языковому моделированию
|
||||
train_loader : torch.utils.data.DataLoader
|
||||
Даталоадер обучающего набора
|
||||
val_loader : torch.utils.data.DataLoader или None
|
||||
Даталоадер валидационного набора (если задан)
|
||||
optimizer : torch.optim.Optimizer
|
||||
Оптимизатор параметров модели
|
||||
scheduler : torch.optim.lr_scheduler.LambdaLR
|
||||
Планировщик learning rate (инициализируется в train)
|
||||
device : torch.device
|
||||
Устройство (CPU или CUDA), куда помещается модель
|
||||
num_epochs : int
|
||||
Количество эпох обучения
|
||||
warmup_steps : int
|
||||
Число шагов warmup для scheduler
|
||||
"""
|
||||
|
||||
def __init__(self, model, train_dataset, val_dataset=None, lr=3e-4, batch_size=8, num_epochs=3, warmup_steps=100):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
train_dataset,
|
||||
val_dataset=None,
|
||||
lr=3e-4,
|
||||
batch_size=8,
|
||||
num_epochs=3,
|
||||
warmup_steps=100,
|
||||
):
|
||||
"""
|
||||
Инициализация обучающего класса Trainer.
|
||||
|
||||
Аргументы
|
||||
---------
|
||||
model : torch.nn.Module
|
||||
Модель для обучения (например, GPT, LLaMA, Mistral).
|
||||
train_dataset : torch.utils.data.Dataset
|
||||
Обучающий датасет с полями input_ids и labels.
|
||||
val_dataset : torch.utils.data.Dataset, optional
|
||||
Валидационный датасет для контроля качества обучения.
|
||||
lr : float, default=3e-4
|
||||
Начальный шаг обучения.
|
||||
batch_size : int, default=8
|
||||
Размер обучающего мини-батча.
|
||||
num_epochs : int, default=3
|
||||
Количество эпох обучения.
|
||||
warmup_steps : int, default=100
|
||||
Количество шагов разогрева (warmup) learning rate.
|
||||
"""
|
||||
self.model = model
|
||||
self.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
||||
self.val_loader = DataLoader(val_dataset, batch_size=batch_size) if val_dataset else None
|
||||
self.train_loader = DataLoader(
|
||||
train_dataset, batch_size=batch_size, shuffle=True
|
||||
)
|
||||
self.val_loader = (
|
||||
DataLoader(val_dataset, batch_size=batch_size) if val_dataset else None
|
||||
)
|
||||
self.optimizer = get_optimizer(model, lr=lr)
|
||||
self.scheduler = None
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
@@ -23,44 +102,74 @@ class Trainer:
|
||||
|
||||
def compute_lm_loss(self, logits, labels):
|
||||
"""
|
||||
Вычисляет loss для языкового моделирования.
|
||||
Сдвигает логиты и метки для предсказания следующего токена.
|
||||
Вычисляет функцию потерь (loss) для задачи автогрессивного языкового моделирования.
|
||||
|
||||
Производит сдвиг логитов и меток: предсказания делаются для следующего токена.
|
||||
Используется кросс-энтропия (CrossEntropyLoss), что соответствует максимизации логарифма правдоподобия:
|
||||
L = -log P(w_{t+1} | w_1,...,w_t)
|
||||
|
||||
Аргументы
|
||||
---------
|
||||
logits : torch.Tensor
|
||||
Логиты модели: (batch_size, seq_len, vocab_size)
|
||||
labels : torch.Tensor
|
||||
Правильные метки: (batch_size, seq_len)
|
||||
Возвращаемое значение
|
||||
---------------------
|
||||
loss : torch.Tensor
|
||||
Средний loss по batch.
|
||||
"""
|
||||
# Сдвигаем логиты и метки для языкового моделирования
|
||||
# Сдвигаем логиты и метки для языкового моделирования (автогрессия)
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
|
||||
# Вычисляем cross-entropy loss
|
||||
|
||||
# CrossEntropyLoss (игнорируем паддинги: ignore_index=-100)
|
||||
loss = F.cross_entropy(
|
||||
shift_logits.view(-1, shift_logits.size(-1)),
|
||||
shift_labels.view(-1),
|
||||
ignore_index=-100 # Игнорируем padding tokens
|
||||
ignore_index=-100, # Padding токены не участвуют в loss
|
||||
)
|
||||
return loss
|
||||
|
||||
def train(self):
|
||||
"""
|
||||
Запускает процесс обучения модели по заданному числу эпох.
|
||||
|
||||
В процессе:
|
||||
- Применяет optimizer, scheduler с warmup и decay, grad clipping (обрезка градиентов)
|
||||
- Вызывает функцию потерь для языкового моделирования
|
||||
- Показывает динамику процесса (tqdm)
|
||||
- После каждой эпохи возможно проведение валидации
|
||||
|
||||
Параметры задаются на этапе инициализации Trainer.
|
||||
"""
|
||||
total_steps = len(self.train_loader) * self.num_epochs
|
||||
self.scheduler = get_linear_schedule_with_warmup(self.optimizer, self.warmup_steps, total_steps)
|
||||
self.scheduler = get_linear_schedule_with_warmup(
|
||||
self.optimizer, self.warmup_steps, total_steps
|
||||
)
|
||||
self.loss_history = [] # добавлено: лог средних потерь
|
||||
|
||||
for epoch in range(self.num_epochs):
|
||||
self.model.train()
|
||||
total_loss = 0
|
||||
|
||||
progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.num_epochs}")
|
||||
progress_bar = tqdm(
|
||||
self.train_loader, desc=f"Epoch {epoch+1}/{self.num_epochs}"
|
||||
)
|
||||
for batch in progress_bar:
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
input_ids = batch["input_ids"].to(self.device)
|
||||
labels = batch["labels"].to(self.device)
|
||||
|
||||
# Универсально обрабатываем выход (tuple/logits)
|
||||
# Универсально обрабатываем выходы модели: tuple или просто tensor (logits)
|
||||
outputs = self.model(input_ids)
|
||||
if isinstance(outputs, tuple):
|
||||
logits = outputs[0]
|
||||
else:
|
||||
logits = outputs
|
||||
|
||||
# Trainer вычисляет loss
|
||||
|
||||
# Вычисляем loss автогрессивной LM-задачи
|
||||
loss = self.compute_lm_loss(logits, labels)
|
||||
loss.backward()
|
||||
|
||||
@@ -72,12 +181,19 @@ class Trainer:
|
||||
progress_bar.set_postfix(loss=loss.item())
|
||||
|
||||
avg_loss = total_loss / len(self.train_loader)
|
||||
self.loss_history.append(avg_loss) # добавлено: запоминаем loss
|
||||
print(f"Epoch {epoch+1} finished — avg loss: {avg_loss:.4f}")
|
||||
|
||||
if self.val_loader:
|
||||
self.evaluate()
|
||||
|
||||
def evaluate(self):
|
||||
"""
|
||||
Оценивает модель на валидационном датасете (если задан).
|
||||
|
||||
В режиме eval() модели отключается dropout и все стохастические элементы.
|
||||
Возвращает среднее значение функции потерь (loss) по всему validation set.
|
||||
"""
|
||||
self.model.eval()
|
||||
total_loss = 0
|
||||
|
||||
@@ -85,7 +201,7 @@ class Trainer:
|
||||
for batch in self.val_loader:
|
||||
input_ids = batch["input_ids"].to(self.device)
|
||||
labels = batch["labels"].to(self.device)
|
||||
|
||||
|
||||
outputs = self.model(input_ids)
|
||||
if isinstance(outputs, tuple):
|
||||
logits = outputs[0]
|
||||
@@ -95,4 +211,4 @@ class Trainer:
|
||||
total_loss += loss.item()
|
||||
|
||||
avg_loss = total_loss / len(self.val_loader)
|
||||
print(f"Validation loss: {avg_loss:.4f}")
|
||||
print(f"Validation loss: {avg_loss:.4f}")
|
||||
@@ -58,7 +58,7 @@ def gpt_config(vocab_size, embed_dim, num_heads, num_layers):
|
||||
"num_heads": num_heads,
|
||||
"num_layers": num_layers,
|
||||
"max_position_embeddings": 1024,
|
||||
"dropout": 0.1
|
||||
"dropout": 0.1,
|
||||
}
|
||||
|
||||
|
||||
@@ -68,12 +68,14 @@ def random_inputs(batch_size, seq_len, vocab_size):
|
||||
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
|
||||
return input_ids
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def random_float_inputs(batch_size, seq_len, embed_dim):
|
||||
"""Generate random floating point input tensors for testing feed forward."""
|
||||
inputs = torch.randn(batch_size, seq_len, embed_dim)
|
||||
return inputs
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def random_embeddings(batch_size, seq_len, embed_dim):
|
||||
"""Generate random embedding tensors for testing attention modules."""
|
||||
|
||||
65
llm/tests/core/test_cached_decoder.py
Normal file
65
llm/tests/core/test_cached_decoder.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import torch
|
||||
import pytest
|
||||
from llm.core.cached_decoder import CachedDecoder
|
||||
from llm.core.feed_forward import FeedForward
|
||||
|
||||
@pytest.fixture
|
||||
def decoder_config():
|
||||
return dict(
|
||||
num_heads=4,
|
||||
emb_size=32,
|
||||
head_size=8,
|
||||
feed_forward_layer=FeedForward(emb_size=32, dropout=0.1, activation="gelu"),
|
||||
max_seq_len=64,
|
||||
dropout=0.1
|
||||
)
|
||||
|
||||
def test_cached_decoder_init(decoder_config):
|
||||
model = CachedDecoder(**decoder_config)
|
||||
assert model is not None
|
||||
# Main attention block is usually stored as _heads or _attention (which itself includes _q _k _v)
|
||||
assert hasattr(model, '_heads') or hasattr(model, '_attention')
|
||||
assert hasattr(model, '_ff') or hasattr(model, 'feed_forward_layer')
|
||||
|
||||
def test_cached_decoder_forward_shape(decoder_config):
|
||||
model = CachedDecoder(**decoder_config)
|
||||
batch, seq_len, emb_size = 3, 10, decoder_config['emb_size']
|
||||
x = torch.randn(batch, seq_len, emb_size)
|
||||
output, cache = model(x, use_cache=True)
|
||||
assert output.shape == (batch, seq_len, emb_size)
|
||||
assert cache is not None
|
||||
|
||||
def test_cached_decoder_forward_no_cache(decoder_config):
|
||||
model = CachedDecoder(**decoder_config)
|
||||
batch, seq_len, emb_size = 2, 12, decoder_config['emb_size']
|
||||
x = torch.randn(batch, seq_len, emb_size)
|
||||
output, cache = model(x, use_cache=False)
|
||||
assert output.shape == (batch, seq_len, emb_size)
|
||||
assert cache is None
|
||||
|
||||
def test_cached_decoder_error_on_long_seq(decoder_config):
|
||||
model = CachedDecoder(**decoder_config)
|
||||
batch, seq_len, emb_size = 1, decoder_config['max_seq_len'] + 1, decoder_config['emb_size']
|
||||
x = torch.randn(batch, seq_len, emb_size)
|
||||
with pytest.raises(ValueError):
|
||||
model(x)
|
||||
|
||||
def test_cached_decoder_backward(decoder_config):
|
||||
model = CachedDecoder(**decoder_config)
|
||||
batch, seq_len, emb_size = 2, 7, decoder_config['emb_size']
|
||||
x = torch.randn(batch, seq_len, emb_size, requires_grad=True)
|
||||
output, cache = model(x)
|
||||
loss = output.sum()
|
||||
loss.backward()
|
||||
assert x.grad is not None
|
||||
|
||||
def test_cached_decoder_kv_cache_chain(decoder_config):
|
||||
model = CachedDecoder(**decoder_config)
|
||||
batch, seq_len, emb_size = 1, 4, decoder_config['emb_size']
|
||||
x = torch.randn(batch, seq_len, emb_size)
|
||||
# Первый проход — кэша нет
|
||||
_, cache = model(x, use_cache=True)
|
||||
# Второй проход — передаём кэш, добавляем еще токен:
|
||||
next_x = torch.randn(batch, 1, emb_size)
|
||||
_, cache2 = model(next_x, use_cache=True, cache=cache)
|
||||
assert cache2 is not None
|
||||
@@ -10,168 +10,178 @@ from llm.core.feed_forward import FeedForward
|
||||
|
||||
class TestFeedForward:
|
||||
"""Test cases for FeedForward."""
|
||||
|
||||
|
||||
def test_initialization(self, embed_dim):
|
||||
"""Test that FeedForward can be initialized."""
|
||||
ff = FeedForward(embed_dim)
|
||||
assert ff is not None
|
||||
|
||||
|
||||
# Check internal layers
|
||||
assert hasattr(ff, '_layer1')
|
||||
assert hasattr(ff, '_layer2')
|
||||
assert hasattr(ff, '_activation')
|
||||
assert hasattr(ff, '_dropout')
|
||||
|
||||
assert hasattr(ff, "_layer1")
|
||||
assert hasattr(ff, "_layer2")
|
||||
assert hasattr(ff, "_activation")
|
||||
assert hasattr(ff, "_dropout")
|
||||
|
||||
# Check layer dimensions
|
||||
expected_hidden_dim = embed_dim * 4 # Default expansion factor
|
||||
assert ff._layer1.weight.shape == (expected_hidden_dim, embed_dim)
|
||||
assert ff._layer2.weight.shape == (embed_dim, expected_hidden_dim)
|
||||
|
||||
|
||||
def test_forward_pass(self, embed_dim, random_float_inputs):
|
||||
"""Test forward pass of FeedForward."""
|
||||
ff = FeedForward(embed_dim)
|
||||
|
||||
|
||||
# Forward pass
|
||||
output = ff(random_float_inputs)
|
||||
|
||||
|
||||
# Check output shape
|
||||
assert output.shape == random_float_inputs.shape
|
||||
assert isinstance(output, torch.Tensor)
|
||||
|
||||
|
||||
def test_custom_hidden_dim(self, embed_dim):
|
||||
"""Test FeedForward with custom hidden dimension."""
|
||||
# FeedForward doesn't support custom hidden_dim in current implementation
|
||||
# This test is not applicable
|
||||
ff = FeedForward(embed_dim)
|
||||
|
||||
|
||||
# Check layer dimensions (fixed 4x expansion)
|
||||
expected_hidden_dim = embed_dim * 4
|
||||
assert ff._layer1.weight.shape == (expected_hidden_dim, embed_dim)
|
||||
assert ff._layer2.weight.shape == (embed_dim, expected_hidden_dim)
|
||||
|
||||
|
||||
def test_dropout(self, embed_dim, random_float_inputs):
|
||||
"""Test that dropout is applied during training."""
|
||||
ff = FeedForward(embed_dim, dropout=0.5)
|
||||
ff.train() # Set to training mode
|
||||
|
||||
|
||||
output = ff(random_float_inputs)
|
||||
|
||||
|
||||
# In training mode with dropout, some values should be zeroed
|
||||
# This is probabilistic, so we can't assert exact zeros,
|
||||
# but we can check the structure is preserved
|
||||
assert output.shape == random_float_inputs.shape
|
||||
|
||||
|
||||
def test_no_dropout_in_eval(self, embed_dim, random_float_inputs):
|
||||
"""Test that dropout is not applied during evaluation."""
|
||||
ff = FeedForward(embed_dim, dropout=0.5)
|
||||
ff.eval() # Set to evaluation mode
|
||||
|
||||
|
||||
# Run forward pass multiple times - outputs should be identical
|
||||
output1 = ff(random_float_inputs)
|
||||
output2 = ff(random_float_inputs)
|
||||
|
||||
|
||||
assert torch.allclose(output1, output2)
|
||||
|
||||
|
||||
def test_activation_function(self, embed_dim, random_float_inputs):
|
||||
"""Test that activation function is applied."""
|
||||
ff = FeedForward(embed_dim)
|
||||
|
||||
|
||||
# Manually compute expected output without dropout for deterministic comparison
|
||||
hidden = ff._layer1(random_float_inputs)
|
||||
activated = ff._activation(hidden)
|
||||
expected_output = ff._layer2(activated)
|
||||
|
||||
|
||||
# Compare with forward pass in eval mode (no dropout)
|
||||
ff.eval()
|
||||
actual_output = ff(random_float_inputs)
|
||||
|
||||
|
||||
assert torch.allclose(actual_output, expected_output, rtol=1e-4)
|
||||
|
||||
|
||||
def test_gradient_flow(self, embed_dim, random_float_inputs):
|
||||
"""Test that gradients flow through FeedForward."""
|
||||
ff = FeedForward(embed_dim)
|
||||
|
||||
|
||||
# Forward pass
|
||||
output = ff(random_float_inputs)
|
||||
|
||||
|
||||
# Create a dummy loss and backward pass
|
||||
loss = output.sum()
|
||||
loss.backward()
|
||||
|
||||
|
||||
# Check that gradients are computed for learnable parameters
|
||||
assert ff._layer1.weight.grad is not None
|
||||
assert ff._layer2.weight.grad is not None
|
||||
assert not torch.allclose(ff._layer1.weight.grad,
|
||||
torch.zeros_like(ff._layer1.weight.grad))
|
||||
assert not torch.allclose(ff._layer2.weight.grad,
|
||||
torch.zeros_like(ff._layer2.weight.grad))
|
||||
|
||||
assert not torch.allclose(
|
||||
ff._layer1.weight.grad, torch.zeros_like(ff._layer1.weight.grad)
|
||||
)
|
||||
assert not torch.allclose(
|
||||
ff._layer2.weight.grad, torch.zeros_like(ff._layer2.weight.grad)
|
||||
)
|
||||
|
||||
def test_device_consistency(self, embed_dim, random_float_inputs, device):
|
||||
"""Test that FeedForward works on correct device."""
|
||||
ff = FeedForward(embed_dim).to(device)
|
||||
inputs = random_float_inputs.to(device)
|
||||
|
||||
|
||||
# Forward pass
|
||||
output = ff(inputs)
|
||||
|
||||
|
||||
# Check device consistency
|
||||
assert output.device == device
|
||||
assert ff._layer1.weight.device == device
|
||||
assert ff._layer2.weight.device == device
|
||||
|
||||
|
||||
def test_different_embed_dims(self):
|
||||
"""Test FeedForward with different embedding dimensions."""
|
||||
test_cases = [64, 128, 256, 512]
|
||||
|
||||
|
||||
for embed_dim in test_cases:
|
||||
ff = FeedForward(embed_dim)
|
||||
batch_size, seq_len = 2, 16
|
||||
inputs = torch.randn(batch_size, seq_len, embed_dim)
|
||||
|
||||
|
||||
output = ff(inputs)
|
||||
|
||||
|
||||
assert output.shape == inputs.shape
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,seq_len", [(1, 8), (2, 16), (4, 32)])
|
||||
def test_different_input_shapes(self, embed_dim, batch_size, seq_len):
|
||||
"""Test FeedForward with different input shapes."""
|
||||
ff = FeedForward(embed_dim)
|
||||
|
||||
|
||||
inputs = torch.randn(batch_size, seq_len, embed_dim)
|
||||
output = ff(inputs)
|
||||
|
||||
|
||||
assert output.shape == (batch_size, seq_len, embed_dim)
|
||||
|
||||
|
||||
def test_non_linearity(self, embed_dim, random_float_inputs):
|
||||
"""Test that FeedForward introduces non-linearity."""
|
||||
ff = FeedForward(embed_dim)
|
||||
|
||||
|
||||
# Create a simple linear transformation for comparison
|
||||
linear_layer = nn.Linear(embed_dim, embed_dim)
|
||||
|
||||
|
||||
# Copy weights to make comparison fair
|
||||
with torch.no_grad():
|
||||
linear_layer.weight.copy_(ff._layer2.weight @ ff._layer1.weight)
|
||||
if linear_layer.bias is not None:
|
||||
linear_layer.bias.zero_()
|
||||
|
||||
|
||||
linear_output = linear_layer(random_float_inputs)
|
||||
ff_output = ff(random_float_inputs)
|
||||
|
||||
|
||||
# FeedForward output should be different from pure linear transformation
|
||||
# due to activation function
|
||||
assert not torch.allclose(ff_output, linear_output, rtol=1e-4)
|
||||
|
||||
|
||||
def test_parameter_initialization(self, embed_dim):
|
||||
"""Test that parameters are properly initialized."""
|
||||
ff = FeedForward(embed_dim)
|
||||
|
||||
|
||||
# Check that weights are not all zeros
|
||||
assert not torch.allclose(ff._layer1.weight, torch.zeros_like(ff._layer1.weight))
|
||||
assert not torch.allclose(ff._layer2.weight, torch.zeros_like(ff._layer2.weight))
|
||||
|
||||
assert not torch.allclose(
|
||||
ff._layer1.weight, torch.zeros_like(ff._layer1.weight)
|
||||
)
|
||||
assert not torch.allclose(
|
||||
ff._layer2.weight, torch.zeros_like(ff._layer2.weight)
|
||||
)
|
||||
|
||||
# Check that biases are not all zeros (they should be initialized with some values)
|
||||
if ff._layer1.bias is not None:
|
||||
assert not torch.allclose(ff._layer1.bias, torch.zeros_like(ff._layer1.bias))
|
||||
assert not torch.allclose(
|
||||
ff._layer1.bias, torch.zeros_like(ff._layer1.bias)
|
||||
)
|
||||
if ff._layer2.bias is not None:
|
||||
assert not torch.allclose(ff._layer2.bias, torch.zeros_like(ff._layer2.bias))
|
||||
assert not torch.allclose(
|
||||
ff._layer2.bias, torch.zeros_like(ff._layer2.bias)
|
||||
)
|
||||
|
||||
60
llm/tests/core/test_geglu.py
Normal file
60
llm/tests/core/test_geglu.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
import pytest
|
||||
from llm.core.geglu import GeGLU
|
||||
|
||||
@pytest.fixture
|
||||
def geglu():
|
||||
return GeGLU(emb_size=16, dropout=0.1)
|
||||
|
||||
def test_forward_shape(geglu):
|
||||
x = torch.randn(2, 5, 16)
|
||||
y = geglu(x)
|
||||
assert y.shape == x.shape
|
||||
|
||||
def test_forward_no_batch(geglu):
|
||||
x = torch.randn(1, 16)
|
||||
y = geglu(x.unsqueeze(0))
|
||||
assert y.shape == (1, 1, 16)
|
||||
|
||||
@pytest.mark.skip(reason="float16 not supported without parameter casting")
|
||||
def test_forward_dtype_fp16():
|
||||
geglu = GeGLU(emb_size=8, dropout=0.0)
|
||||
x = torch.randn(2, 4, 8).half()
|
||||
y = geglu(x)
|
||||
assert y.shape == x.shape
|
||||
assert y.dtype == torch.float16
|
||||
|
||||
def test_forward_no_dropout():
|
||||
geglu = GeGLU(emb_size=4, dropout=0.0)
|
||||
x = torch.randn(3, 2, 4)
|
||||
y = geglu(x)
|
||||
assert not torch.isnan(y).any()
|
||||
assert not torch.isinf(y).any()
|
||||
|
||||
def test_gradient_flow(geglu):
|
||||
x = torch.randn(3, 8, 16, requires_grad=True)
|
||||
y = geglu(x)
|
||||
y.sum().backward()
|
||||
assert x.grad is not None
|
||||
assert x.grad.shape == x.shape
|
||||
|
||||
def test_forward_repeatability():
|
||||
torch.manual_seed(42)
|
||||
geglu = GeGLU(emb_size=8, dropout=0.0)
|
||||
x = torch.randn(3, 2, 8)
|
||||
y1 = geglu(x)
|
||||
torch.manual_seed(42)
|
||||
geglu2 = GeGLU(emb_size=8, dropout=0.0)
|
||||
x2 = torch.randn(3, 2, 8)
|
||||
y2 = geglu2(x2)
|
||||
assert torch.allclose(y1, y2, atol=1e-5)
|
||||
|
||||
def test_edge_small_large():
|
||||
geglu = GeGLU(emb_size=2, dropout=0.0)
|
||||
x = torch.randn(2, 2, 2)
|
||||
y = geglu(x)
|
||||
assert y.shape == x.shape
|
||||
geglu = GeGLU(emb_size=256, dropout=0.0)
|
||||
x = torch.randn(1, 1, 256)
|
||||
y = geglu(x)
|
||||
assert y.shape == x.shape
|
||||
46
llm/tests/core/test_gelu.py
Normal file
46
llm/tests/core/test_gelu.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import torch
|
||||
import pytest
|
||||
from llm.core.gelu import GELU
|
||||
|
||||
def test_gelu_shapes_and_dtype():
|
||||
gelu = GELU()
|
||||
x = torch.randn(4, 16, 8)
|
||||
y = gelu(x)
|
||||
assert y.shape == x.shape
|
||||
assert y.dtype == x.dtype
|
||||
|
||||
def test_gelu_known_values():
|
||||
gelu = GELU()
|
||||
x = torch.tensor([-3.0, 0.0, 3.0])
|
||||
y = gelu(x)
|
||||
# Сравнение с PyTorch F.gelu (которая использует точный алгоритм)
|
||||
y_ref = torch.nn.functional.gelu(x)
|
||||
diff = (y - y_ref).abs().max().item()
|
||||
assert diff < 5e-3, f"Max difference {diff} exceeds threshold"
|
||||
|
||||
def test_gelu_is_smooth_and_monotonic():
|
||||
gelu = GELU()
|
||||
x = torch.linspace(-5, 5, 100)
|
||||
y = gelu(x)
|
||||
dy = y[1:] - y[:-1]
|
||||
# Проверяем, что функция GELU хотя бы локально монотонна на большинстве промежутков
|
||||
assert (dy.mean() > 0 or dy.mean() < 0)
|
||||
|
||||
def test_gelu_gradients():
|
||||
gelu = GELU()
|
||||
x = torch.randn(3, 5, requires_grad=True)
|
||||
y = gelu(x)
|
||||
loss = y.sum()
|
||||
loss.backward()
|
||||
assert x.grad is not None
|
||||
assert x.grad.shape == x.shape
|
||||
|
||||
def test_gelu_large_vs_small():
|
||||
gelu = GELU()
|
||||
x_pos = torch.tensor([100.0])
|
||||
x_neg = torch.tensor([-100.0])
|
||||
y_pos = gelu(x_pos)
|
||||
y_neg = gelu(x_neg)
|
||||
# Для больших положительных GELU(x) ~ x, для больших отрицательных ~0
|
||||
assert torch.allclose(y_pos, x_pos, rtol=1e-4, atol=1e-4)
|
||||
assert torch.allclose(y_neg, torch.zeros_like(x_neg), rtol=1e-4, atol=1e-4)
|
||||
67
llm/tests/core/test_gemma_decoder.py
Normal file
67
llm/tests/core/test_gemma_decoder.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import torch
|
||||
import pytest
|
||||
from llm.core.gemma_decoder import GemmaDecoder
|
||||
from llm.core.rope import RoPE
|
||||
|
||||
@pytest.fixture
|
||||
def gemma_decoder():
|
||||
rope = RoPE(head_size=4, max_seq_len=32)
|
||||
return GemmaDecoder(
|
||||
num_q_heads=4,
|
||||
emb_size=16,
|
||||
head_size=4,
|
||||
max_seq_len=32,
|
||||
rope=rope,
|
||||
dropout=0.1,
|
||||
)
|
||||
|
||||
def test_forward_shape(gemma_decoder):
|
||||
x = torch.randn(2, 12, 16)
|
||||
out, cache = gemma_decoder(x)
|
||||
assert out.shape == (2, 12, 16)
|
||||
assert isinstance(cache, tuple) or cache is None
|
||||
|
||||
def test_forward_masked(gemma_decoder):
|
||||
x = torch.randn(1, 8, 16)
|
||||
mask = torch.ones(1, 8, 8, dtype=torch.bool)
|
||||
out, _ = gemma_decoder(x, mask=mask)
|
||||
assert out.shape == x.shape
|
||||
|
||||
def test_forward_with_cache_flag(gemma_decoder):
|
||||
x = torch.randn(2, 7, 16)
|
||||
out, cache = gemma_decoder(x, use_cache=True, cache=None)
|
||||
assert out.shape == (2, 7, 16)
|
||||
|
||||
def test_forward_wrong_seq_len_raises(gemma_decoder):
|
||||
x = torch.randn(1, 100, 16)
|
||||
with pytest.raises(Exception):
|
||||
gemma_decoder(x)
|
||||
|
||||
def test_gradient_flow(gemma_decoder):
|
||||
x = torch.randn(3, 9, 16, requires_grad=True)
|
||||
y, _ = gemma_decoder(x)
|
||||
y.sum().backward()
|
||||
assert x.grad is not None
|
||||
assert x.grad.shape == x.shape
|
||||
|
||||
def test_various_shapes(gemma_decoder):
|
||||
for b, s in [(1, 1), (2, 5), (2, 32)]:
|
||||
x = torch.randn(b, s, 16)
|
||||
y, _ = gemma_decoder(x)
|
||||
assert y.shape == (b, s, 16)
|
||||
|
||||
def test_forward_repeatability():
|
||||
torch.manual_seed(42)
|
||||
rope = RoPE(head_size=4, max_seq_len=32)
|
||||
decoder = GemmaDecoder(
|
||||
num_q_heads=4, emb_size=16, head_size=4, max_seq_len=32, rope=rope, dropout=0.0,
|
||||
)
|
||||
x = torch.randn(2, 8, 16)
|
||||
y1, _ = decoder(x)
|
||||
torch.manual_seed(42)
|
||||
decoder2 = GemmaDecoder(
|
||||
num_q_heads=4, emb_size=16, head_size=4, max_seq_len=32, rope=rope, dropout=0.0,
|
||||
)
|
||||
x2 = torch.randn(2, 8, 16)
|
||||
y2, _ = decoder2(x2)
|
||||
assert torch.allclose(y1, y2, atol=1e-5)
|
||||
@@ -4,185 +4,238 @@ Tests for decoder block.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from llm.core.decoder import Decoder
|
||||
from llm.core.gpt_decoder import GptDecoder
|
||||
|
||||
|
||||
class TestDecoder:
|
||||
class TestGptDecoder:
|
||||
"""Test cases for Decoder."""
|
||||
|
||||
|
||||
def test_initialization(self, embed_dim, num_heads):
|
||||
"""Test that Decoder can be initialized."""
|
||||
head_size = embed_dim // num_heads
|
||||
max_seq_len = 1024
|
||||
decoder = Decoder(num_heads=num_heads, emb_size=embed_dim, head_size=head_size, max_seq_len=max_seq_len)
|
||||
decoder = GptDecoder(
|
||||
num_heads=num_heads,
|
||||
emb_size=embed_dim,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len,
|
||||
)
|
||||
assert decoder is not None
|
||||
|
||||
|
||||
# Check internal components
|
||||
assert hasattr(decoder, '_heads')
|
||||
assert hasattr(decoder, '_ff')
|
||||
assert hasattr(decoder, '_norm1')
|
||||
assert hasattr(decoder, '_norm2')
|
||||
|
||||
assert hasattr(decoder, "_heads")
|
||||
assert hasattr(decoder, "_ff")
|
||||
assert hasattr(decoder, "_norm1")
|
||||
assert hasattr(decoder, "_norm2")
|
||||
|
||||
def test_forward_pass(self, embed_dim, num_heads, random_embeddings):
|
||||
"""Test forward pass of Decoder."""
|
||||
head_size = embed_dim // num_heads
|
||||
max_seq_len = 1024
|
||||
decoder = Decoder(num_heads=num_heads, emb_size=embed_dim, head_size=head_size, max_seq_len=max_seq_len)
|
||||
|
||||
decoder = GptDecoder(
|
||||
num_heads=num_heads,
|
||||
emb_size=embed_dim,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len,
|
||||
)
|
||||
|
||||
# Forward pass
|
||||
output = decoder(random_embeddings)
|
||||
|
||||
output, _ = decoder(random_embeddings)
|
||||
|
||||
# Check output shape
|
||||
assert output.shape == random_embeddings.shape
|
||||
assert isinstance(output, torch.Tensor)
|
||||
|
||||
|
||||
def test_forward_with_causal_mask(self, embed_dim, num_heads, random_embeddings):
|
||||
"""Test forward pass with causal mask."""
|
||||
head_size = embed_dim // num_heads
|
||||
max_seq_len = 1024
|
||||
decoder = Decoder(num_heads=num_heads, emb_size=embed_dim, head_size=head_size, max_seq_len=max_seq_len)
|
||||
|
||||
decoder = GptDecoder(
|
||||
num_heads=num_heads,
|
||||
emb_size=embed_dim,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len,
|
||||
)
|
||||
|
||||
batch_size, seq_len = random_embeddings.shape[:2]
|
||||
# Create causal mask
|
||||
mask = torch.tril(torch.ones(seq_len, seq_len))
|
||||
|
||||
|
||||
# Forward pass with causal mask
|
||||
output = decoder(random_embeddings, mask=mask)
|
||||
|
||||
output, _ = decoder(random_embeddings, attention_mask=mask)
|
||||
|
||||
# Check output shape
|
||||
assert output.shape == random_embeddings.shape
|
||||
|
||||
|
||||
def test_residual_connections(self, embed_dim, num_heads, random_embeddings):
|
||||
"""Test that residual connections are properly applied."""
|
||||
head_size = embed_dim // num_heads
|
||||
max_seq_len = 1024
|
||||
decoder = Decoder(num_heads=num_heads, emb_size=embed_dim, head_size=head_size, max_seq_len=max_seq_len)
|
||||
|
||||
output = decoder(random_embeddings)
|
||||
|
||||
decoder = GptDecoder(
|
||||
num_heads=num_heads,
|
||||
emb_size=embed_dim,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len,
|
||||
)
|
||||
|
||||
output, _ = decoder(random_embeddings)
|
||||
|
||||
# With residual connections and layer norm, the output shouldn't be
|
||||
# too different from input (in terms of scale/distribution)
|
||||
input_norm = random_embeddings.norm(dim=-1).mean()
|
||||
output_norm = output.norm(dim=-1).mean()
|
||||
|
||||
|
||||
# Norms should be of similar magnitude (not exact due to transformations)
|
||||
assert 0.1 < (output_norm / input_norm) < 10.0
|
||||
|
||||
|
||||
def test_layer_norm(self, embed_dim, num_heads, random_embeddings):
|
||||
"""Test that layer normalization is applied."""
|
||||
head_size = embed_dim // num_heads
|
||||
max_seq_len = 1024
|
||||
decoder = Decoder(num_heads=num_heads, emb_size=embed_dim, head_size=head_size, max_seq_len=max_seq_len)
|
||||
|
||||
output = decoder(random_embeddings)
|
||||
|
||||
decoder = GptDecoder(
|
||||
num_heads=num_heads,
|
||||
emb_size=embed_dim,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len,
|
||||
)
|
||||
|
||||
output, _ = decoder(random_embeddings)
|
||||
|
||||
# Check that output has reasonable statistics (due to layer norm)
|
||||
# Mean should be close to 0, std close to 1 for each sequence position
|
||||
output_mean = output.mean(dim=-1)
|
||||
output_std = output.std(dim=-1)
|
||||
|
||||
|
||||
# These are approximate checks since the data goes through multiple transformations
|
||||
assert torch.allclose(output_mean, torch.zeros_like(output_mean), atol=1.0)
|
||||
assert torch.allclose(output_std, torch.ones_like(output_std), atol=2.0)
|
||||
|
||||
|
||||
def test_gradient_flow(self, embed_dim, num_heads, random_embeddings):
|
||||
"""Test that gradients flow through Decoder."""
|
||||
head_size = embed_dim // num_heads
|
||||
max_seq_len = 1024
|
||||
decoder = Decoder(num_heads=num_heads, emb_size=embed_dim, head_size=head_size, max_seq_len=max_seq_len)
|
||||
|
||||
decoder = GptDecoder(
|
||||
num_heads=num_heads,
|
||||
emb_size=embed_dim,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len,
|
||||
)
|
||||
|
||||
# Forward pass
|
||||
output = decoder(random_embeddings)
|
||||
|
||||
output, _ = decoder(random_embeddings)
|
||||
|
||||
# Create a dummy loss and backward pass
|
||||
loss = output.sum()
|
||||
loss.backward()
|
||||
|
||||
|
||||
# Check that gradients are computed for learnable parameters
|
||||
# in attention and feed forward components
|
||||
assert decoder._heads._layer.weight.grad is not None
|
||||
assert decoder._ff._layer1.weight.grad is not None
|
||||
assert decoder._norm1.weight.grad is not None
|
||||
assert decoder._norm2.weight.grad is not None
|
||||
|
||||
|
||||
def test_device_consistency(self, embed_dim, num_heads, random_embeddings, device):
|
||||
"""Test that Decoder works on correct device."""
|
||||
head_size = embed_dim // num_heads
|
||||
max_seq_len = 1024
|
||||
decoder = Decoder(num_heads=num_heads, emb_size=embed_dim, head_size=head_size, max_seq_len=max_seq_len).to(device)
|
||||
decoder = GptDecoder(
|
||||
num_heads=num_heads,
|
||||
emb_size=embed_dim,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len,
|
||||
).to(device)
|
||||
inputs = random_embeddings.to(device)
|
||||
|
||||
|
||||
# Forward pass
|
||||
output = decoder(inputs)
|
||||
|
||||
output, _ = decoder(inputs)
|
||||
|
||||
# Check device consistency
|
||||
assert output.device == device
|
||||
assert decoder._heads._layer.weight.device == device
|
||||
|
||||
|
||||
def test_different_configurations(self):
|
||||
"""Test Decoder with different configurations."""
|
||||
test_cases = [
|
||||
(64, 2), # embed_dim=64, num_heads=2
|
||||
(64, 2), # embed_dim=64, num_heads=2
|
||||
(128, 4), # embed_dim=128, num_heads=4
|
||||
(256, 8), # embed_dim=256, num_heads=8
|
||||
]
|
||||
|
||||
|
||||
for embed_dim, num_heads in test_cases:
|
||||
head_size = embed_dim // num_heads
|
||||
max_seq_len = 1024
|
||||
decoder = Decoder(num_heads=num_heads, emb_size=embed_dim, head_size=head_size, max_seq_len=max_seq_len)
|
||||
decoder = GptDecoder(
|
||||
num_heads=num_heads,
|
||||
emb_size=embed_dim,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len,
|
||||
)
|
||||
batch_size, seq_len = 2, 16
|
||||
inputs = torch.randn(batch_size, seq_len, embed_dim)
|
||||
|
||||
output = decoder(inputs)
|
||||
|
||||
|
||||
output, _ = decoder(inputs)
|
||||
|
||||
assert output.shape == inputs.shape
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,seq_len", [(1, 8), (2, 16), (4, 32)])
|
||||
def test_different_input_shapes(self, embed_dim, num_heads, batch_size, seq_len):
|
||||
"""Test Decoder with different input shapes."""
|
||||
head_size = embed_dim // num_heads
|
||||
max_seq_len = 1024
|
||||
decoder = Decoder(num_heads=num_heads, emb_size=embed_dim, head_size=head_size, max_seq_len=max_seq_len)
|
||||
|
||||
decoder = GptDecoder(
|
||||
num_heads=num_heads,
|
||||
emb_size=embed_dim,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len,
|
||||
)
|
||||
|
||||
inputs = torch.randn(batch_size, seq_len, embed_dim)
|
||||
output = decoder(inputs)
|
||||
|
||||
output, _ = decoder(inputs)
|
||||
|
||||
assert output.shape == (batch_size, seq_len, embed_dim)
|
||||
|
||||
|
||||
def test_training_vs_evaluation(self, embed_dim, num_heads, random_embeddings):
|
||||
"""Test that Decoder behaves differently in train vs eval mode."""
|
||||
head_size = embed_dim // num_heads
|
||||
max_seq_len = 1024
|
||||
decoder = Decoder(num_heads=num_heads, emb_size=embed_dim, head_size=head_size, max_seq_len=max_seq_len, dropout=0.5)
|
||||
|
||||
decoder = GptDecoder(
|
||||
num_heads=num_heads,
|
||||
emb_size=embed_dim,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len,
|
||||
dropout=0.5,
|
||||
)
|
||||
|
||||
# Training mode
|
||||
decoder.train()
|
||||
output_train = decoder(random_embeddings)
|
||||
|
||||
output_train, _ = decoder(random_embeddings)
|
||||
|
||||
# Evaluation mode
|
||||
decoder.eval()
|
||||
output_eval = decoder(random_embeddings)
|
||||
|
||||
output_eval, _ = decoder(random_embeddings)
|
||||
|
||||
# Outputs should be different due to dropout
|
||||
assert not torch.allclose(output_train, output_eval)
|
||||
|
||||
|
||||
def test_parameter_initialization(self, embed_dim, num_heads):
|
||||
"""Test that parameters are properly initialized."""
|
||||
head_size = embed_dim // num_heads
|
||||
max_seq_len = 1024
|
||||
decoder = Decoder(num_heads=num_heads, emb_size=embed_dim, head_size=head_size, max_seq_len=max_seq_len)
|
||||
|
||||
decoder = GptDecoder(
|
||||
num_heads=num_heads,
|
||||
emb_size=embed_dim,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len,
|
||||
)
|
||||
|
||||
# Check that various components have non-zero parameters
|
||||
assert not torch.allclose(
|
||||
decoder._heads._layer.weight,
|
||||
torch.zeros_like(decoder._heads._layer.weight)
|
||||
decoder._heads._layer.weight, torch.zeros_like(decoder._heads._layer.weight)
|
||||
)
|
||||
assert not torch.allclose(
|
||||
decoder._ff._layer1.weight,
|
||||
torch.zeros_like(decoder._ff._layer1.weight)
|
||||
decoder._ff._layer1.weight, torch.zeros_like(decoder._ff._layer1.weight)
|
||||
)
|
||||
assert not torch.allclose(
|
||||
decoder._norm1.weight,
|
||||
torch.zeros_like(decoder._norm1.weight)
|
||||
decoder._norm1.weight, torch.zeros_like(decoder._norm1.weight)
|
||||
)
|
||||
85
llm/tests/core/test_group_query_attention.py
Normal file
85
llm/tests/core/test_group_query_attention.py
Normal file
@@ -0,0 +1,85 @@
|
||||
# llm/tests/core/test_group_query_attention.py
|
||||
|
||||
import torch
|
||||
import pytest
|
||||
from llm.core.group_query_attention import GroupedQueryAttention
|
||||
from llm.core.rope import RoPE
|
||||
|
||||
@pytest.fixture
|
||||
def params():
|
||||
return {
|
||||
'num_q_heads': 4,
|
||||
'num_kv_heads': 2,
|
||||
'emb_size': 16,
|
||||
'head_size': 4,
|
||||
'max_seq_len': 32,
|
||||
'window_size': 8,
|
||||
'dropout': 0.0
|
||||
}
|
||||
|
||||
def test_initialization(params):
|
||||
attn = GroupedQueryAttention(**params)
|
||||
assert isinstance(attn, GroupedQueryAttention)
|
||||
|
||||
def test_forward_shape(params):
|
||||
batch, seq = 2, 10
|
||||
x = torch.randn(batch, seq, params['emb_size'])
|
||||
attn = GroupedQueryAttention(**params)
|
||||
y, cache = attn(x)
|
||||
assert y.shape == (batch, seq, params['emb_size'])
|
||||
assert cache is not None
|
||||
assert isinstance(y, torch.Tensor)
|
||||
|
||||
def test_forward_shape_with_mask(params):
|
||||
batch, seq = 2, 10
|
||||
x = torch.randn(batch, seq, params['emb_size'])
|
||||
mask = torch.tril(torch.ones(seq, seq)).bool()
|
||||
attn = GroupedQueryAttention(**params)
|
||||
y, _ = attn(x, mask=mask)
|
||||
assert y.shape == (batch, seq, params['emb_size'])
|
||||
|
||||
def test_kv_repetition(params):
|
||||
batch, seq = 1, 3
|
||||
attn = GroupedQueryAttention(**params)
|
||||
kv = torch.randn(batch, params['num_kv_heads'], seq, params['head_size'])
|
||||
rep = attn._repeat_kv_heads(kv, params['num_q_heads'], params['num_kv_heads'])
|
||||
assert rep.shape == (batch, params['num_q_heads'], seq, params['head_size'])
|
||||
|
||||
def test_window_mask(params):
|
||||
attn = GroupedQueryAttention(**params)
|
||||
mask = attn._create_sliding_window_mask(8, 3)
|
||||
assert mask.shape == (8, 8)
|
||||
# Проверим булеву маску окна в позиции 4
|
||||
expected = torch.tensor([True, True, True, True, False, False])
|
||||
assert torch.equal(mask[4, 1:7], expected)
|
||||
|
||||
def test_forward_with_rope(params):
|
||||
batch, seq = 2, 12
|
||||
x = torch.randn(batch, seq, params['emb_size'])
|
||||
rope = RoPE(head_size=params['head_size'], max_seq_len=params['max_seq_len'])
|
||||
params2 = params.copy()
|
||||
params2['rope'] = rope
|
||||
attn = GroupedQueryAttention(**params2)
|
||||
y, _ = attn(x)
|
||||
assert y.shape == (batch, seq, params['emb_size'])
|
||||
|
||||
def test_cache_usage(params):
|
||||
batch, seq = 1, 5
|
||||
x = torch.randn(batch, seq, params['emb_size'])
|
||||
attn = GroupedQueryAttention(**params)
|
||||
# Первый проход - получаем кэш
|
||||
_, cache = attn(x)
|
||||
# Второй проход с кэшем (имитируем автокомплит seq_len=1)
|
||||
x2 = torch.randn(batch, 1, params['emb_size'])
|
||||
y2, cache2 = attn(x2, cache=cache)
|
||||
assert cache2 is not None
|
||||
assert y2.shape == (batch, 1, params['emb_size'])
|
||||
|
||||
def test_gradient_backward(params):
|
||||
batch, seq = 2, 6
|
||||
x = torch.randn(batch, seq, params['emb_size'], requires_grad=True)
|
||||
attn = GroupedQueryAttention(**params)
|
||||
y, _ = attn(x)
|
||||
y.sum().backward()
|
||||
for param in attn.parameters():
|
||||
assert param.grad is not None
|
||||
66
llm/tests/core/test_mistral_decoder.py
Normal file
66
llm/tests/core/test_mistral_decoder.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import torch
|
||||
import pytest
|
||||
from llm.core.mistral_decoder import MistralDecoder
|
||||
from llm.core.rope import RoPE
|
||||
|
||||
@pytest.fixture
|
||||
def decoder_config():
|
||||
# Current MistralDecoder is a single block (not a stack).
|
||||
return dict(
|
||||
num_q_heads=4,
|
||||
num_kv_heads=2,
|
||||
emb_size=32,
|
||||
head_size=8,
|
||||
max_seq_len=128,
|
||||
window_size=16,
|
||||
rope=RoPE(head_size=8, max_seq_len=128),
|
||||
dropout=0.0
|
||||
)
|
||||
|
||||
def test_mistral_decoder_init(decoder_config):
|
||||
model = MistralDecoder(**decoder_config)
|
||||
assert model is not None
|
||||
|
||||
def test_mistral_decoder_forward_shapes(decoder_config):
|
||||
model = MistralDecoder(**decoder_config)
|
||||
batch, seq_len, emb_size = 2, 10, decoder_config['emb_size']
|
||||
x = torch.randn(batch, seq_len, emb_size)
|
||||
output, cache = model(x, use_cache=True)
|
||||
assert output.shape == (batch, seq_len, emb_size)
|
||||
assert cache is not None
|
||||
|
||||
def test_mistral_decoder_forward_no_cache(decoder_config):
|
||||
model = MistralDecoder(**decoder_config)
|
||||
batch, seq_len, emb_size = 2, 10, decoder_config['emb_size']
|
||||
x = torch.randn(batch, seq_len, emb_size)
|
||||
output, cache = model(x, use_cache=False)
|
||||
assert output.shape == (batch, seq_len, emb_size)
|
||||
assert cache is None
|
||||
|
||||
def test_mistral_decoder_cache_shapes(decoder_config):
|
||||
model = MistralDecoder(**decoder_config)
|
||||
batch, seq_len, emb_size = 2, 8, decoder_config['emb_size']
|
||||
x = torch.randn(batch, seq_len, emb_size)
|
||||
# Первый проход — без кэша
|
||||
_, cache = model(x, use_cache=True)
|
||||
# Второй проход — заполняем кэш
|
||||
x_next = torch.randn(batch, 1, emb_size)
|
||||
_, cache2 = model(x_next, use_cache=True, cache=cache)
|
||||
# Можно проверить, что кэш не None и корректной структуры:
|
||||
assert cache2 is not None
|
||||
|
||||
def test_mistral_decoder_shape_error(decoder_config):
|
||||
model = MistralDecoder(**decoder_config)
|
||||
batch, seq_len, emb_size = 2, decoder_config['max_seq_len'] + 1, decoder_config['emb_size']
|
||||
x = torch.randn(batch, seq_len, emb_size)
|
||||
with pytest.raises(ValueError):
|
||||
model(x)
|
||||
|
||||
def test_mistral_decoder_backward(decoder_config):
|
||||
model = MistralDecoder(**decoder_config)
|
||||
batch, seq_len, emb_size = 2, 10, decoder_config['emb_size']
|
||||
x = torch.randn(batch, seq_len, emb_size, requires_grad=True)
|
||||
output, _ = model(x, use_cache=False)
|
||||
loss = output.sum()
|
||||
loss.backward()
|
||||
assert x.grad is not None
|
||||
80
llm/tests/core/test_mixtral_decoder.py
Normal file
80
llm/tests/core/test_mixtral_decoder.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import torch
|
||||
import pytest
|
||||
from llm.core.mixtral_decoder import MixtralDecoder
|
||||
from llm.core.rope import RoPE
|
||||
|
||||
@pytest.fixture
|
||||
def basic_decoder():
|
||||
emb_size = 16
|
||||
num_q_heads = 4
|
||||
num_kv_heads = 2
|
||||
head_size = 4
|
||||
max_seq_len = 32
|
||||
num_experts = 4
|
||||
top_k_experts = 2
|
||||
window_size = 8
|
||||
rope = RoPE(head_size=head_size, max_seq_len=max_seq_len)
|
||||
return MixtralDecoder(
|
||||
num_q_heads=num_q_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
emb_size=emb_size,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len,
|
||||
num_experts=num_experts,
|
||||
top_k_experts=top_k_experts,
|
||||
window_size=window_size,
|
||||
rope=rope,
|
||||
dropout=0.0,
|
||||
)
|
||||
|
||||
def test_forward_shape(basic_decoder):
|
||||
x = torch.randn(2, 10, 16)
|
||||
out, cache = basic_decoder(x)
|
||||
assert out.shape == (2, 10, 16)
|
||||
assert cache is None or isinstance(cache, (tuple, list))
|
||||
|
||||
def test_forward_masked(basic_decoder):
|
||||
x = torch.randn(3, 7, 16)
|
||||
mask = torch.ones(3, 7, 7, dtype=torch.bool)
|
||||
out, cache = basic_decoder(x, mask=mask)
|
||||
assert out.shape == (3, 7, 16)
|
||||
|
||||
def test_forward_with_cache_flag(basic_decoder):
|
||||
x = torch.randn(2, 8, 16)
|
||||
out, cache = basic_decoder(x, use_cache=True, cache=None)
|
||||
assert out.shape == (2, 8, 16)
|
||||
assert isinstance(cache, (tuple, list)) or cache is None
|
||||
|
||||
def test_backprop_pass(basic_decoder):
|
||||
x = torch.randn(2, 5, 16, requires_grad=True)
|
||||
out, _ = basic_decoder(x)
|
||||
y = out.sum()
|
||||
y.backward()
|
||||
assert x.grad is not None
|
||||
assert x.grad.shape == x.shape
|
||||
|
||||
def test_seq_too_long_raises(basic_decoder):
|
||||
x = torch.randn(1, 40, 16) # seq_len > max_seq_len
|
||||
with pytest.raises(Exception):
|
||||
basic_decoder(x)
|
||||
|
||||
def test_different_config():
|
||||
rope = RoPE(head_size=2, max_seq_len=12)
|
||||
decoder = MixtralDecoder(
|
||||
num_q_heads=2, num_kv_heads=2, emb_size=4, head_size=2,
|
||||
max_seq_len=12, num_experts=2, top_k_experts=1, window_size=4, rope=rope, dropout=0.1
|
||||
)
|
||||
x = torch.randn(1, 8, 4)
|
||||
out, cache = decoder(x)
|
||||
assert out.shape == x.shape
|
||||
|
||||
def test_forward_no_dropout():
|
||||
# Проверка на корректность shape при отсутствии Dropout
|
||||
rope = RoPE(head_size=2, max_seq_len=12)
|
||||
decoder = MixtralDecoder(
|
||||
num_q_heads=2, num_kv_heads=1, emb_size=4, head_size=2,
|
||||
max_seq_len=12, num_experts=2, top_k_experts=1, window_size=3, rope=rope, dropout=0.0
|
||||
)
|
||||
x = torch.randn(2, 3, 4)
|
||||
out, cache = decoder(x)
|
||||
assert out.shape == x.shape
|
||||
61
llm/tests/core/test_moe.py
Normal file
61
llm/tests/core/test_moe.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import torch
|
||||
import pytest
|
||||
from llm.core.moe import MoE
|
||||
|
||||
@pytest.fixture
|
||||
def moe():
|
||||
# Базовая MoE для коротких тестов
|
||||
return MoE(emb_size=16, num_experts=4, top_k_experts=2, dropout=0.0)
|
||||
|
||||
def test_forward_shape(moe):
|
||||
x = torch.randn(3, 5, 16) # [batch, seq, emb]
|
||||
y = moe(x)
|
||||
assert y.shape == x.shape
|
||||
|
||||
def test_forward_grad(moe):
|
||||
x = torch.randn(2, 4, 16, requires_grad=True)
|
||||
y = moe(x)
|
||||
(y.sum()).backward()
|
||||
assert x.grad is not None
|
||||
assert x.grad.shape == x.shape
|
||||
|
||||
def test_top_k_larger_than_experts():
|
||||
# top_k_experts > num_experts должно падать
|
||||
with pytest.raises(ValueError):
|
||||
MoE(emb_size=8, num_experts=2, top_k_experts=4)
|
||||
|
||||
def test_single_expert_no_error():
|
||||
# один эксперт, один топ-к — модель всё ещё валидна
|
||||
moe = MoE(emb_size=8, num_experts=1, top_k_experts=1)
|
||||
x = torch.randn(2, 2, 8)
|
||||
y = moe(x)
|
||||
assert y.shape == x.shape
|
||||
|
||||
def test_forward_trivial_weights():
|
||||
"""Проверяет, что при одинаковых весах роутера MoE возвращает усреднённое по экспертам."""
|
||||
class DummyMoE(MoE):
|
||||
def forward(self, x):
|
||||
# Роутер отдаёт всегда единичные логиты = softmax -> uniform
|
||||
self._router = torch.nn.Linear(x.size(-1), self._num_experts, bias=False)
|
||||
torch.nn.init.constant_(self._router.weight, 0.0)
|
||||
return super().forward(x)
|
||||
moe = DummyMoE(emb_size=4, num_experts=2, top_k_experts=2)
|
||||
x = torch.zeros(1, 2, 4)
|
||||
y = moe(x)
|
||||
assert y.shape == x.shape
|
||||
|
||||
def test_forward_deterministic_seed(moe):
|
||||
torch.manual_seed(42)
|
||||
x = torch.randn(2, 3, 16)
|
||||
y1 = moe(x)
|
||||
torch.manual_seed(42)
|
||||
y2 = moe(x)
|
||||
assert torch.allclose(y1, y2, atol=1e-5)
|
||||
|
||||
def test_forward_no_dropout():
|
||||
"""Без dropout MoE не меняет shape и не даёт NaN."""
|
||||
moe = MoE(emb_size=5, num_experts=3, top_k_experts=2, dropout=0.0)
|
||||
x = torch.randn(2, 7, 5)
|
||||
y = moe(x)
|
||||
assert y.shape == x.shape
|
||||
assert not torch.isnan(y).any()
|
||||
@@ -9,157 +9,183 @@ from llm.core.multi_head_attention import MultiHeadAttention
|
||||
|
||||
class TestMultiHeadAttention:
|
||||
"""Test cases for MultiHeadAttention."""
|
||||
|
||||
|
||||
def test_initialization(self, embed_dim, num_heads):
|
||||
"""Test that MultiHeadAttention can be initialized."""
|
||||
head_size = embed_dim // num_heads
|
||||
attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
|
||||
attention = MultiHeadAttention(
|
||||
num_heads, embed_dim, head_size, max_seq_len=1024
|
||||
)
|
||||
assert attention is not None
|
||||
|
||||
|
||||
# Check internal attributes
|
||||
assert len(attention._heads) == num_heads
|
||||
assert attention._num_heads == num_heads
|
||||
assert attention._layer.in_features == embed_dim
|
||||
assert attention._layer.out_features == embed_dim
|
||||
|
||||
|
||||
def test_forward_pass(self, embed_dim, num_heads, random_embeddings):
|
||||
"""Test forward pass of MultiHeadAttention."""
|
||||
head_size = embed_dim // num_heads
|
||||
attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
|
||||
|
||||
attention = MultiHeadAttention(
|
||||
num_heads, embed_dim, head_size, max_seq_len=1024
|
||||
)
|
||||
|
||||
# Forward pass
|
||||
output, _ = attention(random_embeddings)
|
||||
|
||||
|
||||
# Check output shape
|
||||
assert output.shape == random_embeddings.shape
|
||||
assert isinstance(output, torch.Tensor)
|
||||
|
||||
|
||||
def test_forward_with_mask(self, embed_dim, num_heads, random_embeddings):
|
||||
"""Test forward pass with attention mask."""
|
||||
head_size = embed_dim // num_heads
|
||||
attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
|
||||
|
||||
attention = MultiHeadAttention(
|
||||
num_heads, embed_dim, head_size, max_seq_len=1024
|
||||
)
|
||||
|
||||
# Create a simple mask
|
||||
seq_len = random_embeddings.shape[1]
|
||||
mask = torch.tril(torch.ones(seq_len, seq_len)) # Causal mask
|
||||
|
||||
|
||||
# Forward pass with mask
|
||||
output, _ = attention(random_embeddings, mask=mask)
|
||||
|
||||
|
||||
# Check output shape
|
||||
assert output.shape == random_embeddings.shape
|
||||
|
||||
|
||||
def test_causal_mask(self, embed_dim, num_heads, random_embeddings):
|
||||
"""Test that causal mask prevents attending to future positions."""
|
||||
head_size = embed_dim // num_heads
|
||||
attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
|
||||
|
||||
attention = MultiHeadAttention(
|
||||
num_heads, embed_dim, head_size, max_seq_len=1024
|
||||
)
|
||||
|
||||
# Create causal mask
|
||||
seq_len = random_embeddings.shape[1]
|
||||
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
|
||||
|
||||
|
||||
# Forward pass with causal mask
|
||||
output, _ = attention(random_embeddings, mask=causal_mask)
|
||||
|
||||
|
||||
# Check output shape
|
||||
assert output.shape == random_embeddings.shape
|
||||
|
||||
def test_attention_weights_normalization(self, embed_dim, num_heads, random_embeddings):
|
||||
|
||||
def test_attention_weights_normalization(
|
||||
self, embed_dim, num_heads, random_embeddings
|
||||
):
|
||||
"""Test that attention weights are properly normalized."""
|
||||
head_size = embed_dim // num_heads
|
||||
attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
|
||||
|
||||
attention = MultiHeadAttention(
|
||||
num_heads, embed_dim, head_size, max_seq_len=1024
|
||||
)
|
||||
|
||||
# Forward pass
|
||||
output, _ = attention(random_embeddings)
|
||||
|
||||
|
||||
# Check output shape
|
||||
assert output.shape == random_embeddings.shape
|
||||
|
||||
|
||||
def test_gradient_flow(self, embed_dim, num_heads, random_embeddings):
|
||||
"""Test that gradients flow through MultiHeadAttention."""
|
||||
head_size = embed_dim // num_heads
|
||||
attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
|
||||
|
||||
attention = MultiHeadAttention(
|
||||
num_heads, embed_dim, head_size, max_seq_len=1024
|
||||
)
|
||||
|
||||
# Forward pass
|
||||
output, _ = attention(random_embeddings)
|
||||
|
||||
|
||||
# Create a dummy loss and backward pass
|
||||
loss = output.sum()
|
||||
loss.backward()
|
||||
|
||||
|
||||
# Check that gradients are computed for learnable parameters
|
||||
assert attention._layer.weight.grad is not None
|
||||
if len(attention._heads) > 0:
|
||||
assert attention._heads[0]._q.weight.grad is not None
|
||||
|
||||
# Проверяем, что также у градиентов весов q/k/v есть значения
|
||||
assert attention._q.weight.grad is not None
|
||||
assert attention._k.weight.grad is not None
|
||||
assert attention._v.weight.grad is not None
|
||||
|
||||
def test_device_consistency(self, embed_dim, num_heads, random_embeddings, device):
|
||||
"""Test that MultiHeadAttention works on correct device."""
|
||||
head_size = embed_dim // num_heads
|
||||
attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024).to(device)
|
||||
attention = MultiHeadAttention(
|
||||
num_heads, embed_dim, head_size, max_seq_len=1024
|
||||
).to(device)
|
||||
inputs = random_embeddings.to(device)
|
||||
|
||||
|
||||
# Forward pass
|
||||
output, _ = attention(inputs)
|
||||
|
||||
|
||||
# Check device consistency
|
||||
assert output.device == device
|
||||
assert attention._layer.weight.device == device
|
||||
|
||||
|
||||
def test_different_embed_dim_and_heads(self):
|
||||
"""Test MultiHeadAttention with different embed_dim and num_heads combinations."""
|
||||
test_cases = [
|
||||
(64, 2), # embed_dim=64, num_heads=2
|
||||
(64, 2), # embed_dim=64, num_heads=2
|
||||
(128, 4), # embed_dim=128, num_heads=4
|
||||
(256, 8), # embed_dim=256, num_heads=8
|
||||
(512, 16), # embed_dim=512, num_heads=16
|
||||
(512, 16), # embed_dim=512, num_heads=16
|
||||
]
|
||||
|
||||
|
||||
for embed_dim, num_heads in test_cases:
|
||||
head_size = embed_dim // num_heads
|
||||
attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
|
||||
attention = MultiHeadAttention(
|
||||
num_heads, embed_dim, head_size, max_seq_len=1024
|
||||
)
|
||||
batch_size, seq_len = 2, 16
|
||||
inputs = torch.randn(batch_size, seq_len, embed_dim)
|
||||
|
||||
|
||||
output, _ = attention(inputs)
|
||||
|
||||
|
||||
assert output.shape == inputs.shape
|
||||
|
||||
|
||||
def test_attention_output_range(self, embed_dim, num_heads, random_embeddings):
|
||||
"""Test that attention output is in reasonable range."""
|
||||
head_size = embed_dim // num_heads
|
||||
attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
|
||||
|
||||
attention = MultiHeadAttention(
|
||||
num_heads, embed_dim, head_size, max_seq_len=1024
|
||||
)
|
||||
|
||||
output, _ = attention(random_embeddings)
|
||||
|
||||
|
||||
# Output shouldn't have extreme values
|
||||
assert output.abs().max() < 100 # Reasonable upper bound
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,seq_len", [(1, 8), (2, 16), (4, 32)])
|
||||
def test_different_input_shapes(self, embed_dim, num_heads, batch_size, seq_len):
|
||||
"""Test MultiHeadAttention with different input shapes."""
|
||||
head_size = embed_dim // num_heads
|
||||
attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
|
||||
|
||||
attention = MultiHeadAttention(
|
||||
num_heads, embed_dim, head_size, max_seq_len=1024
|
||||
)
|
||||
|
||||
inputs = torch.randn(batch_size, seq_len, embed_dim)
|
||||
output, _ = attention(inputs)
|
||||
|
||||
|
||||
assert output.shape == (batch_size, seq_len, embed_dim)
|
||||
|
||||
|
||||
def test_parameter_sharing(self, embed_dim, num_heads):
|
||||
"""Test that parameters are properly shared across the sequence."""
|
||||
head_size = embed_dim // num_heads
|
||||
attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024, dropout=0.0) # No dropout for deterministic test
|
||||
|
||||
attention = MultiHeadAttention(
|
||||
num_heads, embed_dim, head_size, max_seq_len=1024, dropout=0.0
|
||||
) # No dropout for deterministic test
|
||||
|
||||
# Create two identical sequences
|
||||
seq_len = 10
|
||||
base_sequence = torch.randn(1, seq_len, embed_dim)
|
||||
identical_sequence = base_sequence.clone()
|
||||
|
||||
|
||||
# Set to eval mode to disable dropout
|
||||
attention.eval()
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
output1, _ = attention(base_sequence)
|
||||
output2, _ = attention(identical_sequence)
|
||||
|
||||
|
||||
# With identical inputs and same parameters, outputs should be identical
|
||||
assert torch.allclose(output1, output2, rtol=1e-5)
|
||||
|
||||
71
llm/tests/core/test_multi_query_attention.py
Normal file
71
llm/tests/core/test_multi_query_attention.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import torch
|
||||
import pytest
|
||||
from llm.core.multi_query_attention import MultiQueryAttention
|
||||
from llm.core.rope import RoPE
|
||||
|
||||
@pytest.fixture
|
||||
def mqa_rope():
|
||||
return MultiQueryAttention(
|
||||
num_q_heads=4, emb_size=16, head_size=4, max_seq_len=32, rope=RoPE(head_size=4, max_seq_len=32), dropout=0.1
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def mqa_no_rope():
|
||||
return MultiQueryAttention(
|
||||
num_q_heads=2, emb_size=8, head_size=4, max_seq_len=16, rope=None, dropout=0.0
|
||||
)
|
||||
|
||||
def test_forward_shape(mqa_rope):
|
||||
x = torch.randn(2, 10, 16)
|
||||
out, cache = mqa_rope(x)
|
||||
assert out.shape == (2, 10, 16)
|
||||
assert isinstance(cache, tuple) and len(cache) == 2
|
||||
|
||||
def test_forward_masked(mqa_rope):
|
||||
x = torch.randn(2, 8, 16)
|
||||
mask = torch.ones(2, 8, 8, dtype=torch.bool)
|
||||
out, cache = mqa_rope(x, mask=mask)
|
||||
assert out.shape == (2, 8, 16)
|
||||
|
||||
def test_forward_cache(mqa_rope):
|
||||
x = torch.randn(1, 4, 16)
|
||||
# Первый вызов — кэша нет
|
||||
out1, cache1 = mqa_rope(x)
|
||||
# Повторяем: подаем x второй раз — теперь добавим cache
|
||||
out2, cache2 = mqa_rope(x, use_cache=True, cache=cache1)
|
||||
assert out2.shape == (1, 4, 16)
|
||||
assert isinstance(cache2, tuple) and len(cache2) == 2
|
||||
# Проверка, что длина k_cache увеличилась
|
||||
assert cache2[0].shape[2] == cache1[0].shape[2] + x.shape[1] # по длине seq
|
||||
|
||||
def test_forward_no_rope(mqa_no_rope):
|
||||
x = torch.randn(3, 6, 8)
|
||||
out, _ = mqa_no_rope(x)
|
||||
assert out.shape == (3, 6, 8)
|
||||
|
||||
def test_forward_different_batch_seq(mqa_rope):
|
||||
for batch, seq in [(1, 1), (2, 5), (3, 32)]:
|
||||
x = torch.randn(batch, seq, 16)
|
||||
out, _ = mqa_rope(x)
|
||||
assert out.shape == (batch, seq, 16)
|
||||
|
||||
def test_forward_raise_on_long_seq(mqa_rope):
|
||||
x = torch.randn(2, 40, 16) # seq_len > max_seq_len
|
||||
with pytest.raises(ValueError):
|
||||
mqa_rope(x)
|
||||
|
||||
def test_forward_grad(mqa_rope):
|
||||
x = torch.randn(2, 7, 16, requires_grad=True)
|
||||
out, _ = mqa_rope(x)
|
||||
y = out.sum()
|
||||
y.backward()
|
||||
assert x.grad is not None
|
||||
assert x.grad.shape == x.shape
|
||||
|
||||
def test_dropout_applied():
|
||||
mqa = MultiQueryAttention(num_q_heads=2, emb_size=8, head_size=4, max_seq_len=12, rope=None, dropout=0.99)
|
||||
x = torch.ones(1, 3, 8)
|
||||
mqa.train()
|
||||
y, _ = mqa(x)
|
||||
# При очень большом dropout почти всё обнуляется
|
||||
assert (torch.abs(y) < 1e-5).float().mean() > 0.6 or y.sum() < 1e-2
|
||||
@@ -10,127 +10,134 @@ from llm.core.positional_embeddings import PositionalEmbeddings
|
||||
|
||||
class TestPositionalEmbeddings:
|
||||
"""Test cases for PositionalEmbeddings."""
|
||||
|
||||
|
||||
def test_initialization(self, embed_dim):
|
||||
"""Test that PositionalEmbeddings can be initialized."""
|
||||
max_seq_len = 1024
|
||||
embeddings = PositionalEmbeddings(max_seq_len, embed_dim)
|
||||
assert embeddings is not None
|
||||
|
||||
|
||||
# Check that positional embeddings are created
|
||||
assert hasattr(embeddings, 'embedding')
|
||||
assert hasattr(embeddings, "embedding")
|
||||
assert embeddings.embedding.weight.shape == (max_seq_len, embed_dim)
|
||||
|
||||
|
||||
def test_forward_pass(self, embed_dim):
|
||||
"""Test forward pass of PositionalEmbeddings."""
|
||||
max_seq_len = 1024
|
||||
seq_len = 64
|
||||
embeddings = PositionalEmbeddings(max_seq_len, embed_dim)
|
||||
|
||||
|
||||
# Forward pass - takes sequence length, not input tensor
|
||||
output = embeddings(seq_len)
|
||||
|
||||
|
||||
# Check output shape
|
||||
expected_shape = (seq_len, embed_dim)
|
||||
assert output.shape == expected_shape
|
||||
assert isinstance(output, torch.Tensor)
|
||||
|
||||
|
||||
def test_positional_encoding_values(self, embed_dim):
|
||||
"""Test that positional encoding values are computed correctly."""
|
||||
max_seq_len = 10
|
||||
embeddings = PositionalEmbeddings(max_seq_len, embed_dim)
|
||||
|
||||
|
||||
# Get embeddings for all positions
|
||||
pe = embeddings(max_seq_len) # Shape: [max_seq_len, embed_dim]
|
||||
|
||||
|
||||
# Check that different positions have different embeddings
|
||||
# (since these are learnable embeddings, not fixed sine/cosine)
|
||||
for pos in range(max_seq_len):
|
||||
for i in range(pos + 1, max_seq_len):
|
||||
assert not torch.allclose(pe[pos], pe[i], rtol=1e-4)
|
||||
|
||||
|
||||
def test_different_sequence_lengths(self, embed_dim):
|
||||
"""Test PositionalEmbeddings with different sequence lengths."""
|
||||
test_cases = [
|
||||
(10, 5), # seq_len < max_seq_len
|
||||
(10, 5), # seq_len < max_seq_len
|
||||
(10, 10), # seq_len == max_seq_len
|
||||
]
|
||||
|
||||
|
||||
for max_seq_len, seq_len in test_cases:
|
||||
embeddings = PositionalEmbeddings(max_seq_len, embed_dim)
|
||||
|
||||
|
||||
# Get embeddings for specific sequence length
|
||||
output = embeddings(seq_len)
|
||||
|
||||
|
||||
# Output should have shape [seq_len, embed_dim]
|
||||
assert output.shape == (seq_len, embed_dim)
|
||||
|
||||
|
||||
def test_gradient_flow(self, embed_dim):
|
||||
"""Test that gradients flow through PositionalEmbeddings."""
|
||||
max_seq_len = 64
|
||||
seq_len = 32
|
||||
embeddings = PositionalEmbeddings(max_seq_len, embed_dim)
|
||||
|
||||
|
||||
# Forward pass
|
||||
output = embeddings(seq_len)
|
||||
|
||||
|
||||
# Create a dummy loss and backward pass
|
||||
loss = output.sum()
|
||||
loss.backward()
|
||||
|
||||
|
||||
# Positional embeddings should have gradients (they're learnable)
|
||||
assert embeddings.embedding.weight.grad is not None
|
||||
assert not torch.allclose(embeddings.embedding.weight.grad,
|
||||
torch.zeros_like(embeddings.embedding.weight.grad))
|
||||
|
||||
assert not torch.allclose(
|
||||
embeddings.embedding.weight.grad,
|
||||
torch.zeros_like(embeddings.embedding.weight.grad),
|
||||
)
|
||||
|
||||
def test_device_consistency(self, embed_dim, device):
|
||||
"""Test that PositionalEmbeddings works on correct device."""
|
||||
max_seq_len = 64
|
||||
seq_len = 32
|
||||
embeddings = PositionalEmbeddings(max_seq_len, embed_dim).to(device)
|
||||
|
||||
|
||||
# Forward pass
|
||||
output = embeddings(seq_len)
|
||||
|
||||
|
||||
# Check device consistency
|
||||
assert output.device == device
|
||||
assert embeddings.embedding.weight.device == device
|
||||
|
||||
|
||||
def test_reproducibility(self, embed_dim):
|
||||
"""Test that positional embeddings are reproducible."""
|
||||
max_seq_len = 100
|
||||
embeddings1 = PositionalEmbeddings(max_seq_len, embed_dim)
|
||||
embeddings2 = PositionalEmbeddings(max_seq_len, embed_dim)
|
||||
|
||||
|
||||
# Different instances should have different embeddings (random initialization)
|
||||
assert not torch.allclose(embeddings1.embedding.weight, embeddings2.embedding.weight)
|
||||
|
||||
assert not torch.allclose(
|
||||
embeddings1.embedding.weight, embeddings2.embedding.weight
|
||||
)
|
||||
|
||||
# But same instance should produce same output for same input
|
||||
seq_len = 50
|
||||
output1 = embeddings1(seq_len)
|
||||
output2 = embeddings1(seq_len) # Same instance, same input
|
||||
assert torch.allclose(output1, output2)
|
||||
|
||||
|
||||
def test_positional_pattern(self, embed_dim):
|
||||
"""Test that positional embeddings create a meaningful pattern."""
|
||||
max_seq_len = 50
|
||||
embeddings = PositionalEmbeddings(max_seq_len, embed_dim)
|
||||
pe = embeddings(max_seq_len) # Get all positional embeddings
|
||||
|
||||
|
||||
# Check that different positions have different embeddings
|
||||
# (with high probability due to random initialization)
|
||||
assert not torch.allclose(pe[0], pe[1], rtol=1e-4)
|
||||
assert not torch.allclose(pe[10], pe[20], rtol=1e-4)
|
||||
|
||||
@pytest.mark.parametrize("max_seq_len,seq_len,embed_dim", [
|
||||
(64, 10, 64),
|
||||
(128, 50, 128),
|
||||
(256, 100, 256),
|
||||
])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"max_seq_len,seq_len,embed_dim",
|
||||
[
|
||||
(64, 10, 64),
|
||||
(128, 50, 128),
|
||||
(256, 100, 256),
|
||||
],
|
||||
)
|
||||
def test_different_configurations(self, max_seq_len, seq_len, embed_dim):
|
||||
"""Test PositionalEmbeddings with different configurations."""
|
||||
embeddings = PositionalEmbeddings(max_seq_len, embed_dim)
|
||||
|
||||
|
||||
output = embeddings(seq_len)
|
||||
|
||||
|
||||
assert output.shape == (seq_len, embed_dim)
|
||||
|
||||
47
llm/tests/core/test_rms_norm.py
Normal file
47
llm/tests/core/test_rms_norm.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import torch
|
||||
import pytest
|
||||
from llm.core.rms_norm import RMSNorm
|
||||
|
||||
def test_rmsnorm_shape_preservation():
|
||||
norm = RMSNorm(64)
|
||||
x = torch.randn(3, 5, 64)
|
||||
y = norm(x)
|
||||
assert y.shape == x.shape
|
||||
|
||||
def test_rmsnorm_dtype_and_device():
|
||||
norm = RMSNorm(32)
|
||||
x = torch.randn(8, 32, device='cpu', dtype=torch.float64)
|
||||
y = norm(x)
|
||||
assert y.dtype == torch.float64
|
||||
assert y.device == x.device
|
||||
|
||||
def test_rmsnorm_mean_no_shift():
|
||||
norm = RMSNorm(32)
|
||||
x = torch.randn(3, 128, 32)
|
||||
y = norm(x)
|
||||
rms = torch.sqrt((y ** 2).mean(dim=-1))
|
||||
w_mean = norm._w.mean().item()
|
||||
assert torch.allclose(rms.mean(), torch.tensor(w_mean), rtol=0.2, atol=0.2)
|
||||
|
||||
def test_rmsnorm_backward():
|
||||
norm = RMSNorm(16)
|
||||
x = torch.randn(2, 15, 16, requires_grad=True)
|
||||
y = norm(x)
|
||||
loss = y.sum()
|
||||
loss.backward()
|
||||
assert x.grad is not None
|
||||
assert norm._w.grad is not None
|
||||
|
||||
def test_rmsnorm_fp16():
|
||||
norm = RMSNorm(8).half()
|
||||
x = torch.randn(2, 6, 8).half()
|
||||
y = norm(x)
|
||||
assert y.shape == x.shape
|
||||
assert y.dtype == torch.float16
|
||||
|
||||
def test_rmsnorm_large_eps_stability():
|
||||
norm = RMSNorm(16, eps=1)
|
||||
x = torch.zeros(2, 5, 16)
|
||||
y = norm(x)
|
||||
assert not torch.isnan(y).any()
|
||||
assert not torch.isinf(y).any()
|
||||
55
llm/tests/core/test_rope.py
Normal file
55
llm/tests/core/test_rope.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import torch
|
||||
import pytest
|
||||
from llm.core.rope import RoPE
|
||||
|
||||
def test_rope_shapes_and_dtype():
|
||||
rope = RoPE(head_size=8, max_seq_len=32)
|
||||
x = torch.randn(2, 4, 16, 8) # [batch, num_heads, seq_len, head_size]
|
||||
y = rope(x)
|
||||
assert y.shape == x.shape
|
||||
assert y.dtype == x.dtype
|
||||
|
||||
def test_rope_raises_on_bad_ndim():
|
||||
rope = RoPE(head_size=8, max_seq_len=16)
|
||||
x = torch.randn(2, 16, 8) # [batch, seq_len, head_size] (3D)
|
||||
with pytest.raises(AssertionError):
|
||||
_ = rope(x)
|
||||
|
||||
def test_rope_preserves_norm():
|
||||
rope = RoPE(head_size=8, max_seq_len=16)
|
||||
x = torch.randn(2, 3, 7, 8)
|
||||
x_norm = x.norm(dim=-1)
|
||||
y = rope(x)
|
||||
y_norm = y.norm(dim=-1)
|
||||
# Нормы могут немного отличаться из-за float, сравниваем с допуском
|
||||
assert torch.allclose(x_norm, y_norm, rtol=1e-5, atol=1e-7)
|
||||
|
||||
def test_rope_backward_pass():
|
||||
rope = RoPE(head_size=8, max_seq_len=16)
|
||||
x = torch.randn(2, 2, 8, 8, requires_grad=True)
|
||||
out = rope(x)
|
||||
loss = out.sum()
|
||||
loss.backward()
|
||||
assert x.grad is not None
|
||||
assert x.grad.shape == x.shape
|
||||
|
||||
@pytest.mark.parametrize("batch,num_heads,seq_len,head_size", [
|
||||
(1, 1, 4, 8),
|
||||
(2, 4, 16, 8),
|
||||
(3, 2, 7, 8),
|
||||
])
|
||||
def test_rope_various_shapes(batch, num_heads, seq_len, head_size):
|
||||
rope = RoPE(head_size=head_size, max_seq_len=32)
|
||||
x = torch.randn(batch, num_heads, seq_len, head_size)
|
||||
y = rope(x)
|
||||
assert y.shape == x.shape
|
||||
|
||||
def test_rope_start_pos():
|
||||
rope = RoPE(head_size=8, max_seq_len=32)
|
||||
x_full = torch.randn(1, 2, 8, 8)
|
||||
# Сравниваем участок результата для разных start_pos
|
||||
out1 = rope(x_full)
|
||||
out2 = rope(x_full, start_pos=2)
|
||||
assert not torch.allclose(out1, out2)
|
||||
# Для одинакового start_pos и x должны совпадать
|
||||
assert torch.allclose(rope(x_full, start_pos=1), rope(x_full, start_pos=1))
|
||||
42
llm/tests/core/test_silu.py
Normal file
42
llm/tests/core/test_silu.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import torch
|
||||
import pytest
|
||||
from llm.core.silu import SiLU
|
||||
|
||||
def test_silu_shape_and_dtype():
|
||||
silu = SiLU()
|
||||
x = torch.randn(3, 10, 8)
|
||||
y = silu(x)
|
||||
assert y.shape == x.shape
|
||||
assert y.dtype == x.dtype
|
||||
|
||||
def test_silu_known_values():
|
||||
silu = SiLU()
|
||||
x = torch.tensor([-2.0, 0.0, 2.0])
|
||||
y = silu(x)
|
||||
# PyTorch эталон
|
||||
y_ref = torch.nn.functional.silu(x)
|
||||
assert torch.allclose(y, y_ref, atol=1e-6)
|
||||
|
||||
def test_silu_large_vs_small():
|
||||
silu = SiLU()
|
||||
x_pos = torch.tensor([100.0])
|
||||
x_neg = torch.tensor([-100.0])
|
||||
y_pos = silu(x_pos)
|
||||
y_neg = silu(x_neg)
|
||||
assert torch.allclose(y_pos, x_pos, rtol=1e-4, atol=1e-4) # SiLU(x) ~ x для больших x>0
|
||||
assert torch.allclose(y_neg, torch.zeros_like(x_neg), rtol=1e-4, atol=1e-4) # SiLU(x) ~ 0 для x<0
|
||||
|
||||
def test_silu_gradients():
|
||||
silu = SiLU()
|
||||
x = torch.randn(4, 4, requires_grad=True)
|
||||
y = silu(x)
|
||||
loss = y.sum()
|
||||
loss.backward()
|
||||
assert x.grad is not None
|
||||
assert x.grad.shape == x.shape
|
||||
|
||||
def test_silu_broadcast():
|
||||
silu = SiLU()
|
||||
x = torch.randn(3, 1, 16)
|
||||
y = silu(x)
|
||||
assert y.shape == x.shape
|
||||
39
llm/tests/core/test_swi_glu.py
Normal file
39
llm/tests/core/test_swi_glu.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import torch
|
||||
import pytest
|
||||
from llm.core.swi_glu import SwiGLU
|
||||
|
||||
def test_swiglu_shape_and_dtype():
|
||||
swiglu = SwiGLU(emb_size=32, dropout=0.1)
|
||||
x = torch.randn(4, 10, 32)
|
||||
y = swiglu(x)
|
||||
assert y.shape == x.shape
|
||||
assert y.dtype == x.dtype
|
||||
|
||||
def test_swiglu_forward_range():
|
||||
swiglu = SwiGLU(emb_size=16, dropout=0.0)
|
||||
x = torch.randn(3, 7, 16)
|
||||
y = swiglu(x)
|
||||
assert y.abs().max() < 20
|
||||
|
||||
def test_swiglu_gradients():
|
||||
swiglu = SwiGLU(emb_size=8, dropout=0.0)
|
||||
x = torch.randn(2, 5, 8, requires_grad=True)
|
||||
out = swiglu(x)
|
||||
loss = out.pow(2).sum()
|
||||
loss.backward()
|
||||
assert x.grad is not None
|
||||
assert x.grad.shape == x.shape
|
||||
|
||||
def test_swiglu_fp16():
|
||||
swiglu = SwiGLU(emb_size=16, dropout=0.0).half()
|
||||
x = torch.randn(1, 8, 16).half()
|
||||
y = swiglu(x)
|
||||
assert y.shape == x.shape
|
||||
assert y.dtype == torch.float16
|
||||
|
||||
def test_swiglu_reproducibility():
|
||||
swiglu = SwiGLU(emb_size=8, dropout=0.0)
|
||||
x = torch.ones(2, 4, 8)
|
||||
y1 = swiglu(x)
|
||||
y2 = swiglu(x)
|
||||
assert torch.allclose(y1, y2)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user