Добавление механизма внимания HeadAttention

- Реализация одного головного внимания из Transformer
- Полная документация на русском языке
- Пример использования с визуализацией
- Обновление README с ссылками
This commit is contained in:
Sergey Penkovsky
2025-07-19 11:35:11 +03:00
parent 9765140f67
commit a150828665
4 changed files with 265 additions and 2 deletions

View File

@@ -88,9 +88,17 @@ git+https://github.com/yourusername/simple-llm.git
pip install git+https://github.com/yourusername/simple-llm.git
```
## Примеры
## Примеры использования
Дополнительные примеры использования смотрите в папке [example](/example):
Дополнительные примеры:
- [Базовый BPE](/example/example_bpe.py)
- [Токенные эмбеддинги](/example/example_token_embeddings.py)
- [Механизм внимания](/example/head_attention_example.py)
Документация:
- [Токенизация](/doc/bpe_algorithm.md)
- [Эмбеддинги](/doc/token_embeddings_ru.md)
- [Внимание](/doc/head_attention_ru.md)
- Сравнение SimpleBPE и OptimizeBPE
- Работа с разными языками
- Настройка параметров токенизации

83
doc/head_attention_ru.md Normal file
View File

@@ -0,0 +1,83 @@
# HeadAttention - Механизм самовнимания одной головы
## Назначение
Модуль реализует механизм внимания одной головы из архитектуры Transformer. Основные применения:
- Моделирование зависимостей в последовательностях
- Обработка естественного языка (NLP)
- Генерация текста с учетом контекста
- Анализ временных рядов
## Алгоритм работы
```mermaid
flowchart TD
A[Входной тензор x] --> B[Вычисление Q, K, V]
B --> C["Scores = Q·Kᵀ / √d_k"]
C --> D[Применение нижнетреугольной маски]
D --> E[Softmax]
E --> F[Взвешенная сумма значений V]
F --> G[Выходной тензор]
```
1. **Линейные преобразования**:
```python
Q = W_q·x, K = W_k·x, V = W_v·x
```
2. **Вычисление attention scores**:
```python
scores = matmul(Q, K.transpose(-2, -1)) / sqrt(head_size)
```
3. **Маскирование**:
```python
scores.masked_fill_(mask == 0, -inf) # Causal masking
```
4. **Взвешивание**:
```python
weights = softmax(scores, dim=-1)
output = matmul(weights, V)
```
## Пример использования
```python
import torch
from simple_llm.transformer.head_attention import HeadAttention
# Параметры
emb_size = 512
head_size = 64
max_seq_len = 1024
# Инициализация
attn_head = HeadAttention(emb_size, head_size, max_seq_len)
# Пример входа (batch_size=2, seq_len=10)
x = torch.randn(2, 10, emb_size)
output = attn_head(x) # [2, 10, head_size]
```
## Особенности реализации
### Ключевые компоненты
| Компонент | Назначение |
|-----------------|-------------------------------------|
| `self._q` | Линейный слой для Query |
| `self._k` | Линейный слой для Key |
| `self._v` | Линейный слой для Value |
| `self._tril_mask`| Нижнетреугольная маска |
### Ограничения
- Требует O(n²) памяти для матрицы внимания
- Поддерживает только causal-режим
- Фиксированный максимальный размер последовательности
## Рекомендации по использованию
1. Размер головы (`head_size`) обычно выбирают 64-128
2. Для длинных последовательностей (>512) используйте оптимизации:
- Локальное внимание
- Разреженные паттерны
3. Сочетайте с MultiHeadAttention для лучшего качества
[Дополнительные примеры](/example/attention_examples.py)

View File

@@ -0,0 +1,88 @@
"""
Пример использования механизма внимания HeadAttention
с визуализацией матрицы внимания и анализом работы.
"""
import torch
import matplotlib.pyplot as plt
import numpy as np
from simple_llm.transformer.head_attention import HeadAttention
import os
os.makedirs("example_output", exist_ok=True)
def plot_attention(weights, tokens=None, filename="attention_plot.png"):
"""Сохранение матрицы внимания в файл"""
plt.figure(figsize=(10, 8))
plt.imshow(weights, cmap='viridis')
if tokens:
plt.xticks(range(len(tokens)), tokens, rotation=90)
plt.yticks(range(len(tokens)), tokens)
plt.colorbar()
plt.title("Матрица весов внимания")
plt.xlabel("Key Positions")
plt.ylabel("Query Positions")
plt.savefig(f"example_output/{filename}")
plt.close()
def simulate_text_attention():
"""Пример с имитацией текстовых данных"""
# Параметры
emb_size = 64
head_size = 32
seq_len = 8
# Имитация токенов
tokens = ["[CLS]", "мама", "мыла", "раму", ",", "папа", "пил", "какао"]
# Инициализация
torch.manual_seed(42)
attention = HeadAttention(emb_size, head_size, max_seq_len=seq_len)
# Случайные эмбеддинги (в реальности - выход слоя токенизации)
x = torch.randn(1, seq_len, emb_size)
# Прямой проход + получение весов
with torch.no_grad():
output = attention(x)
q, k = attention._q(x), attention._k(x)
scores = (q @ k.transpose(-2, -1)) / np.sqrt(head_size)
weights = torch.softmax(scores, dim=-1).squeeze()
# Визуализация
print("\nПример для фразы:", " ".join(tokens))
print("Форма выходного тензора:", output.shape)
plot_attention(weights.numpy(), tokens)
def technical_demo():
"""Техническая демонстрация работы механизма"""
print("\nТехническая демонстрация HeadAttention")
attention = HeadAttention(emb_size=16, head_size=8, max_seq_len=10)
# Создаем тензор с ручными значениями для анализа
x = torch.zeros(1, 4, 16)
x[0, 0, :] = 1.0 # Яркий токен
x[0, 3, :] = 0.5 # Слабый токен
# Анализ весов
with torch.no_grad():
output = attention(x)
q = attention._q(x)
k = attention._k(x)
print("\nQuery векторы (первые 5 значений):")
print(q[0, :, :5])
print("\nKey векторы (первые 5 значений):")
print(k[0, :, :5])
weights = torch.softmax((q @ k.transpose(-2, -1)) / np.sqrt(8), dim=-1)
print("\nМатрица внимания:")
print(weights.squeeze().round(decimals=3))
if __name__ == "__main__":
print("Демонстрация работы HeadAttention")
simulate_text_attention()
technical_demo()
print("\nГотово! Проверьте графики матрицы внимания.")

View File

@@ -0,0 +1,84 @@
import torch
from torch import nn
import torch.nn.functional as F
from math import sqrt
class HeadAttention(nn.Module):
"""
Реализация одного головного механизма внимания из архитектуры Transformer.
Выполняет scaled dot-product attention с маскированием будущих позиций (causal attention).
Основной алгоритм:
1. Линейные преобразования входных данных в Q (query), K (key), V (value)
2. Вычисление scores = Q·K^T / sqrt(d_k)
3. Применение causal маски (заполнение -inf будущих позиций)
4. Softmax для получения весов внимания
5. Умножение весов на значения V
Пример использования:
>>> attention = HeadAttention(emb_size=64, head_size=32, max_seq_len=128)
>>> x = torch.randn(1, 10, 64) # [batch_size, seq_len, emb_size]
>>> output = attention(x) # [1, 10, 32]
Параметры:
emb_size (int): Размер входного эмбеддинга
head_size (int): Размерность выхода головы внимания
max_seq_len (int): Максимальная длина последовательности
Примечания:
- Использует нижнетреугольную маску для предотвращения "заглядывания в будущее"
- Автоматически адаптируется к разным версиям PyTorch
- Поддерживает batch-обработку входных данных
"""
def __init__(self, emb_size: int, head_size: int, max_seq_len: int):
super().__init__()
self._emb_size = emb_size
self._head_size = head_size
self._max_seq_len = max_seq_len
# Линейные преобразования для Q, K, V
self._k = nn.Linear(emb_size, head_size)
self._q = nn.Linear(emb_size, head_size)
self._v = nn.Linear(emb_size, head_size)
# Создание causal маски
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Прямой проход через слой внимания.
Аргументы:
x (torch.Tensor): Входной тензор формы [batch_size, seq_len, emb_size]
Возвращает:
torch.Tensor: Выходной тензор формы [batch_size, seq_len, head_size]
Исключения:
ValueError: Если длина последовательности превышает max_seq_len
Пример внутренних преобразований:
Для входа x.shape = [2, 5, 64]:
1. Q/K/V преобразования -> [2, 5, 32]
2. Scores = Q·K^T -> [2, 5, 5]
3. После маски и softmax -> [2, 5, 5]
4. Умножение на V -> [2, 5, 32]
"""
seq_len = x.shape[1]
if seq_len > self._max_seq_len:
raise ValueError(f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}")
# 1. Линейные преобразования
k = self._k(x) # [B, T, hs]
q = self._q(x) # [B, T, hs]
# 2. Вычисление scores
scores = q @ k.transpose(-2, -1) / sqrt(self._head_size)
# 3. Применение causal маски
scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))
# 4. Softmax и умножение на V
weights = F.softmax(scores, dim=-1)
return weights @ self._v(x)