mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
Добавление механизма внимания HeadAttention
- Реализация одного головного внимания из Transformer - Полная документация на русском языке - Пример использования с визуализацией - Обновление README с ссылками
This commit is contained in:
12
README.md
12
README.md
@@ -88,9 +88,17 @@ git+https://github.com/yourusername/simple-llm.git
|
|||||||
pip install git+https://github.com/yourusername/simple-llm.git
|
pip install git+https://github.com/yourusername/simple-llm.git
|
||||||
```
|
```
|
||||||
|
|
||||||
## Примеры
|
## Примеры использования
|
||||||
|
|
||||||
Дополнительные примеры использования смотрите в папке [example](/example):
|
Дополнительные примеры:
|
||||||
|
- [Базовый BPE](/example/example_bpe.py)
|
||||||
|
- [Токенные эмбеддинги](/example/example_token_embeddings.py)
|
||||||
|
- [Механизм внимания](/example/head_attention_example.py)
|
||||||
|
|
||||||
|
Документация:
|
||||||
|
- [Токенизация](/doc/bpe_algorithm.md)
|
||||||
|
- [Эмбеддинги](/doc/token_embeddings_ru.md)
|
||||||
|
- [Внимание](/doc/head_attention_ru.md)
|
||||||
- Сравнение SimpleBPE и OptimizeBPE
|
- Сравнение SimpleBPE и OptimizeBPE
|
||||||
- Работа с разными языками
|
- Работа с разными языками
|
||||||
- Настройка параметров токенизации
|
- Настройка параметров токенизации
|
||||||
|
|||||||
83
doc/head_attention_ru.md
Normal file
83
doc/head_attention_ru.md
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
# HeadAttention - Механизм самовнимания одной головы
|
||||||
|
|
||||||
|
## Назначение
|
||||||
|
Модуль реализует механизм внимания одной головы из архитектуры Transformer. Основные применения:
|
||||||
|
- Моделирование зависимостей в последовательностях
|
||||||
|
- Обработка естественного языка (NLP)
|
||||||
|
- Генерация текста с учетом контекста
|
||||||
|
- Анализ временных рядов
|
||||||
|
|
||||||
|
## Алгоритм работы
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
flowchart TD
|
||||||
|
A[Входной тензор x] --> B[Вычисление Q, K, V]
|
||||||
|
B --> C["Scores = Q·Kᵀ / √d_k"]
|
||||||
|
C --> D[Применение нижнетреугольной маски]
|
||||||
|
D --> E[Softmax]
|
||||||
|
E --> F[Взвешенная сумма значений V]
|
||||||
|
F --> G[Выходной тензор]
|
||||||
|
```
|
||||||
|
|
||||||
|
1. **Линейные преобразования**:
|
||||||
|
```python
|
||||||
|
Q = W_q·x, K = W_k·x, V = W_v·x
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Вычисление attention scores**:
|
||||||
|
```python
|
||||||
|
scores = matmul(Q, K.transpose(-2, -1)) / sqrt(head_size)
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Маскирование**:
|
||||||
|
```python
|
||||||
|
scores.masked_fill_(mask == 0, -inf) # Causal masking
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **Взвешивание**:
|
||||||
|
```python
|
||||||
|
weights = softmax(scores, dim=-1)
|
||||||
|
output = matmul(weights, V)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Пример использования
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from simple_llm.transformer.head_attention import HeadAttention
|
||||||
|
|
||||||
|
# Параметры
|
||||||
|
emb_size = 512
|
||||||
|
head_size = 64
|
||||||
|
max_seq_len = 1024
|
||||||
|
|
||||||
|
# Инициализация
|
||||||
|
attn_head = HeadAttention(emb_size, head_size, max_seq_len)
|
||||||
|
|
||||||
|
# Пример входа (batch_size=2, seq_len=10)
|
||||||
|
x = torch.randn(2, 10, emb_size)
|
||||||
|
output = attn_head(x) # [2, 10, head_size]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Особенности реализации
|
||||||
|
|
||||||
|
### Ключевые компоненты
|
||||||
|
| Компонент | Назначение |
|
||||||
|
|-----------------|-------------------------------------|
|
||||||
|
| `self._q` | Линейный слой для Query |
|
||||||
|
| `self._k` | Линейный слой для Key |
|
||||||
|
| `self._v` | Линейный слой для Value |
|
||||||
|
| `self._tril_mask`| Нижнетреугольная маска |
|
||||||
|
|
||||||
|
### Ограничения
|
||||||
|
- Требует O(n²) памяти для матрицы внимания
|
||||||
|
- Поддерживает только causal-режим
|
||||||
|
- Фиксированный максимальный размер последовательности
|
||||||
|
|
||||||
|
## Рекомендации по использованию
|
||||||
|
1. Размер головы (`head_size`) обычно выбирают 64-128
|
||||||
|
2. Для длинных последовательностей (>512) используйте оптимизации:
|
||||||
|
- Локальное внимание
|
||||||
|
- Разреженные паттерны
|
||||||
|
3. Сочетайте с MultiHeadAttention для лучшего качества
|
||||||
|
|
||||||
|
[Дополнительные примеры](/example/attention_examples.py)
|
||||||
88
example/head_attention_example.py
Normal file
88
example/head_attention_example.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
"""
|
||||||
|
Пример использования механизма внимания HeadAttention
|
||||||
|
с визуализацией матрицы внимания и анализом работы.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from simple_llm.transformer.head_attention import HeadAttention
|
||||||
|
|
||||||
|
import os
|
||||||
|
os.makedirs("example_output", exist_ok=True)
|
||||||
|
|
||||||
|
def plot_attention(weights, tokens=None, filename="attention_plot.png"):
|
||||||
|
"""Сохранение матрицы внимания в файл"""
|
||||||
|
plt.figure(figsize=(10, 8))
|
||||||
|
plt.imshow(weights, cmap='viridis')
|
||||||
|
|
||||||
|
if tokens:
|
||||||
|
plt.xticks(range(len(tokens)), tokens, rotation=90)
|
||||||
|
plt.yticks(range(len(tokens)), tokens)
|
||||||
|
|
||||||
|
plt.colorbar()
|
||||||
|
plt.title("Матрица весов внимания")
|
||||||
|
plt.xlabel("Key Positions")
|
||||||
|
plt.ylabel("Query Positions")
|
||||||
|
plt.savefig(f"example_output/{filename}")
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
def simulate_text_attention():
|
||||||
|
"""Пример с имитацией текстовых данных"""
|
||||||
|
# Параметры
|
||||||
|
emb_size = 64
|
||||||
|
head_size = 32
|
||||||
|
seq_len = 8
|
||||||
|
|
||||||
|
# Имитация токенов
|
||||||
|
tokens = ["[CLS]", "мама", "мыла", "раму", ",", "папа", "пил", "какао"]
|
||||||
|
|
||||||
|
# Инициализация
|
||||||
|
torch.manual_seed(42)
|
||||||
|
attention = HeadAttention(emb_size, head_size, max_seq_len=seq_len)
|
||||||
|
|
||||||
|
# Случайные эмбеддинги (в реальности - выход слоя токенизации)
|
||||||
|
x = torch.randn(1, seq_len, emb_size)
|
||||||
|
|
||||||
|
# Прямой проход + получение весов
|
||||||
|
with torch.no_grad():
|
||||||
|
output = attention(x)
|
||||||
|
q, k = attention._q(x), attention._k(x)
|
||||||
|
scores = (q @ k.transpose(-2, -1)) / np.sqrt(head_size)
|
||||||
|
weights = torch.softmax(scores, dim=-1).squeeze()
|
||||||
|
|
||||||
|
# Визуализация
|
||||||
|
print("\nПример для фразы:", " ".join(tokens))
|
||||||
|
print("Форма выходного тензора:", output.shape)
|
||||||
|
plot_attention(weights.numpy(), tokens)
|
||||||
|
|
||||||
|
def technical_demo():
|
||||||
|
"""Техническая демонстрация работы механизма"""
|
||||||
|
print("\nТехническая демонстрация HeadAttention")
|
||||||
|
attention = HeadAttention(emb_size=16, head_size=8, max_seq_len=10)
|
||||||
|
|
||||||
|
# Создаем тензор с ручными значениями для анализа
|
||||||
|
x = torch.zeros(1, 4, 16)
|
||||||
|
x[0, 0, :] = 1.0 # Яркий токен
|
||||||
|
x[0, 3, :] = 0.5 # Слабый токен
|
||||||
|
|
||||||
|
# Анализ весов
|
||||||
|
with torch.no_grad():
|
||||||
|
output = attention(x)
|
||||||
|
q = attention._q(x)
|
||||||
|
k = attention._k(x)
|
||||||
|
print("\nQuery векторы (первые 5 значений):")
|
||||||
|
print(q[0, :, :5])
|
||||||
|
|
||||||
|
print("\nKey векторы (первые 5 значений):")
|
||||||
|
print(k[0, :, :5])
|
||||||
|
|
||||||
|
weights = torch.softmax((q @ k.transpose(-2, -1)) / np.sqrt(8), dim=-1)
|
||||||
|
print("\nМатрица внимания:")
|
||||||
|
print(weights.squeeze().round(decimals=3))
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("Демонстрация работы HeadAttention")
|
||||||
|
simulate_text_attention()
|
||||||
|
technical_demo()
|
||||||
|
print("\nГотово! Проверьте графики матрицы внимания.")
|
||||||
84
simple_llm/transformer/head_attention.py
Normal file
84
simple_llm/transformer/head_attention.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from math import sqrt
|
||||||
|
|
||||||
|
class HeadAttention(nn.Module):
|
||||||
|
"""
|
||||||
|
Реализация одного головного механизма внимания из архитектуры Transformer.
|
||||||
|
Выполняет scaled dot-product attention с маскированием будущих позиций (causal attention).
|
||||||
|
|
||||||
|
Основной алгоритм:
|
||||||
|
1. Линейные преобразования входных данных в Q (query), K (key), V (value)
|
||||||
|
2. Вычисление scores = Q·K^T / sqrt(d_k)
|
||||||
|
3. Применение causal маски (заполнение -inf будущих позиций)
|
||||||
|
4. Softmax для получения весов внимания
|
||||||
|
5. Умножение весов на значения V
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
>>> attention = HeadAttention(emb_size=64, head_size=32, max_seq_len=128)
|
||||||
|
>>> x = torch.randn(1, 10, 64) # [batch_size, seq_len, emb_size]
|
||||||
|
>>> output = attention(x) # [1, 10, 32]
|
||||||
|
|
||||||
|
Параметры:
|
||||||
|
emb_size (int): Размер входного эмбеддинга
|
||||||
|
head_size (int): Размерность выхода головы внимания
|
||||||
|
max_seq_len (int): Максимальная длина последовательности
|
||||||
|
|
||||||
|
Примечания:
|
||||||
|
- Использует нижнетреугольную маску для предотвращения "заглядывания в будущее"
|
||||||
|
- Автоматически адаптируется к разным версиям PyTorch
|
||||||
|
- Поддерживает batch-обработку входных данных
|
||||||
|
"""
|
||||||
|
def __init__(self, emb_size: int, head_size: int, max_seq_len: int):
|
||||||
|
super().__init__()
|
||||||
|
self._emb_size = emb_size
|
||||||
|
self._head_size = head_size
|
||||||
|
self._max_seq_len = max_seq_len
|
||||||
|
|
||||||
|
# Линейные преобразования для Q, K, V
|
||||||
|
self._k = nn.Linear(emb_size, head_size)
|
||||||
|
self._q = nn.Linear(emb_size, head_size)
|
||||||
|
self._v = nn.Linear(emb_size, head_size)
|
||||||
|
|
||||||
|
# Создание causal маски
|
||||||
|
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
|
||||||
|
self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Прямой проход через слой внимания.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
x (torch.Tensor): Входной тензор формы [batch_size, seq_len, emb_size]
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
torch.Tensor: Выходной тензор формы [batch_size, seq_len, head_size]
|
||||||
|
|
||||||
|
Исключения:
|
||||||
|
ValueError: Если длина последовательности превышает max_seq_len
|
||||||
|
|
||||||
|
Пример внутренних преобразований:
|
||||||
|
Для входа x.shape = [2, 5, 64]:
|
||||||
|
1. Q/K/V преобразования -> [2, 5, 32]
|
||||||
|
2. Scores = Q·K^T -> [2, 5, 5]
|
||||||
|
3. После маски и softmax -> [2, 5, 5]
|
||||||
|
4. Умножение на V -> [2, 5, 32]
|
||||||
|
"""
|
||||||
|
seq_len = x.shape[1]
|
||||||
|
if seq_len > self._max_seq_len:
|
||||||
|
raise ValueError(f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}")
|
||||||
|
|
||||||
|
# 1. Линейные преобразования
|
||||||
|
k = self._k(x) # [B, T, hs]
|
||||||
|
q = self._q(x) # [B, T, hs]
|
||||||
|
|
||||||
|
# 2. Вычисление scores
|
||||||
|
scores = q @ k.transpose(-2, -1) / sqrt(self._head_size)
|
||||||
|
|
||||||
|
# 3. Применение causal маски
|
||||||
|
scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))
|
||||||
|
|
||||||
|
# 4. Softmax и умножение на V
|
||||||
|
weights = F.softmax(scores, dim=-1)
|
||||||
|
return weights @ self._v(x)
|
||||||
Reference in New Issue
Block a user