mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
Добавление механизма внимания HeadAttention
- Реализация одного головного внимания из Transformer - Полная документация на русском языке - Пример использования с визуализацией - Обновление README с ссылками
This commit is contained in:
12
README.md
12
README.md
@@ -88,9 +88,17 @@ git+https://github.com/yourusername/simple-llm.git
|
||||
pip install git+https://github.com/yourusername/simple-llm.git
|
||||
```
|
||||
|
||||
## Примеры
|
||||
## Примеры использования
|
||||
|
||||
Дополнительные примеры использования смотрите в папке [example](/example):
|
||||
Дополнительные примеры:
|
||||
- [Базовый BPE](/example/example_bpe.py)
|
||||
- [Токенные эмбеддинги](/example/example_token_embeddings.py)
|
||||
- [Механизм внимания](/example/head_attention_example.py)
|
||||
|
||||
Документация:
|
||||
- [Токенизация](/doc/bpe_algorithm.md)
|
||||
- [Эмбеддинги](/doc/token_embeddings_ru.md)
|
||||
- [Внимание](/doc/head_attention_ru.md)
|
||||
- Сравнение SimpleBPE и OptimizeBPE
|
||||
- Работа с разными языками
|
||||
- Настройка параметров токенизации
|
||||
|
||||
83
doc/head_attention_ru.md
Normal file
83
doc/head_attention_ru.md
Normal file
@@ -0,0 +1,83 @@
|
||||
# HeadAttention - Механизм самовнимания одной головы
|
||||
|
||||
## Назначение
|
||||
Модуль реализует механизм внимания одной головы из архитектуры Transformer. Основные применения:
|
||||
- Моделирование зависимостей в последовательностях
|
||||
- Обработка естественного языка (NLP)
|
||||
- Генерация текста с учетом контекста
|
||||
- Анализ временных рядов
|
||||
|
||||
## Алгоритм работы
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[Входной тензор x] --> B[Вычисление Q, K, V]
|
||||
B --> C["Scores = Q·Kᵀ / √d_k"]
|
||||
C --> D[Применение нижнетреугольной маски]
|
||||
D --> E[Softmax]
|
||||
E --> F[Взвешенная сумма значений V]
|
||||
F --> G[Выходной тензор]
|
||||
```
|
||||
|
||||
1. **Линейные преобразования**:
|
||||
```python
|
||||
Q = W_q·x, K = W_k·x, V = W_v·x
|
||||
```
|
||||
|
||||
2. **Вычисление attention scores**:
|
||||
```python
|
||||
scores = matmul(Q, K.transpose(-2, -1)) / sqrt(head_size)
|
||||
```
|
||||
|
||||
3. **Маскирование**:
|
||||
```python
|
||||
scores.masked_fill_(mask == 0, -inf) # Causal masking
|
||||
```
|
||||
|
||||
4. **Взвешивание**:
|
||||
```python
|
||||
weights = softmax(scores, dim=-1)
|
||||
output = matmul(weights, V)
|
||||
```
|
||||
|
||||
## Пример использования
|
||||
```python
|
||||
import torch
|
||||
from simple_llm.transformer.head_attention import HeadAttention
|
||||
|
||||
# Параметры
|
||||
emb_size = 512
|
||||
head_size = 64
|
||||
max_seq_len = 1024
|
||||
|
||||
# Инициализация
|
||||
attn_head = HeadAttention(emb_size, head_size, max_seq_len)
|
||||
|
||||
# Пример входа (batch_size=2, seq_len=10)
|
||||
x = torch.randn(2, 10, emb_size)
|
||||
output = attn_head(x) # [2, 10, head_size]
|
||||
```
|
||||
|
||||
## Особенности реализации
|
||||
|
||||
### Ключевые компоненты
|
||||
| Компонент | Назначение |
|
||||
|-----------------|-------------------------------------|
|
||||
| `self._q` | Линейный слой для Query |
|
||||
| `self._k` | Линейный слой для Key |
|
||||
| `self._v` | Линейный слой для Value |
|
||||
| `self._tril_mask`| Нижнетреугольная маска |
|
||||
|
||||
### Ограничения
|
||||
- Требует O(n²) памяти для матрицы внимания
|
||||
- Поддерживает только causal-режим
|
||||
- Фиксированный максимальный размер последовательности
|
||||
|
||||
## Рекомендации по использованию
|
||||
1. Размер головы (`head_size`) обычно выбирают 64-128
|
||||
2. Для длинных последовательностей (>512) используйте оптимизации:
|
||||
- Локальное внимание
|
||||
- Разреженные паттерны
|
||||
3. Сочетайте с MultiHeadAttention для лучшего качества
|
||||
|
||||
[Дополнительные примеры](/example/attention_examples.py)
|
||||
88
example/head_attention_example.py
Normal file
88
example/head_attention_example.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""
|
||||
Пример использования механизма внимания HeadAttention
|
||||
с визуализацией матрицы внимания и анализом работы.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from simple_llm.transformer.head_attention import HeadAttention
|
||||
|
||||
import os
|
||||
os.makedirs("example_output", exist_ok=True)
|
||||
|
||||
def plot_attention(weights, tokens=None, filename="attention_plot.png"):
|
||||
"""Сохранение матрицы внимания в файл"""
|
||||
plt.figure(figsize=(10, 8))
|
||||
plt.imshow(weights, cmap='viridis')
|
||||
|
||||
if tokens:
|
||||
plt.xticks(range(len(tokens)), tokens, rotation=90)
|
||||
plt.yticks(range(len(tokens)), tokens)
|
||||
|
||||
plt.colorbar()
|
||||
plt.title("Матрица весов внимания")
|
||||
plt.xlabel("Key Positions")
|
||||
plt.ylabel("Query Positions")
|
||||
plt.savefig(f"example_output/{filename}")
|
||||
plt.close()
|
||||
|
||||
def simulate_text_attention():
|
||||
"""Пример с имитацией текстовых данных"""
|
||||
# Параметры
|
||||
emb_size = 64
|
||||
head_size = 32
|
||||
seq_len = 8
|
||||
|
||||
# Имитация токенов
|
||||
tokens = ["[CLS]", "мама", "мыла", "раму", ",", "папа", "пил", "какао"]
|
||||
|
||||
# Инициализация
|
||||
torch.manual_seed(42)
|
||||
attention = HeadAttention(emb_size, head_size, max_seq_len=seq_len)
|
||||
|
||||
# Случайные эмбеддинги (в реальности - выход слоя токенизации)
|
||||
x = torch.randn(1, seq_len, emb_size)
|
||||
|
||||
# Прямой проход + получение весов
|
||||
with torch.no_grad():
|
||||
output = attention(x)
|
||||
q, k = attention._q(x), attention._k(x)
|
||||
scores = (q @ k.transpose(-2, -1)) / np.sqrt(head_size)
|
||||
weights = torch.softmax(scores, dim=-1).squeeze()
|
||||
|
||||
# Визуализация
|
||||
print("\nПример для фразы:", " ".join(tokens))
|
||||
print("Форма выходного тензора:", output.shape)
|
||||
plot_attention(weights.numpy(), tokens)
|
||||
|
||||
def technical_demo():
|
||||
"""Техническая демонстрация работы механизма"""
|
||||
print("\nТехническая демонстрация HeadAttention")
|
||||
attention = HeadAttention(emb_size=16, head_size=8, max_seq_len=10)
|
||||
|
||||
# Создаем тензор с ручными значениями для анализа
|
||||
x = torch.zeros(1, 4, 16)
|
||||
x[0, 0, :] = 1.0 # Яркий токен
|
||||
x[0, 3, :] = 0.5 # Слабый токен
|
||||
|
||||
# Анализ весов
|
||||
with torch.no_grad():
|
||||
output = attention(x)
|
||||
q = attention._q(x)
|
||||
k = attention._k(x)
|
||||
print("\nQuery векторы (первые 5 значений):")
|
||||
print(q[0, :, :5])
|
||||
|
||||
print("\nKey векторы (первые 5 значений):")
|
||||
print(k[0, :, :5])
|
||||
|
||||
weights = torch.softmax((q @ k.transpose(-2, -1)) / np.sqrt(8), dim=-1)
|
||||
print("\nМатрица внимания:")
|
||||
print(weights.squeeze().round(decimals=3))
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Демонстрация работы HeadAttention")
|
||||
simulate_text_attention()
|
||||
technical_demo()
|
||||
print("\nГотово! Проверьте графики матрицы внимания.")
|
||||
84
simple_llm/transformer/head_attention.py
Normal file
84
simple_llm/transformer/head_attention.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from math import sqrt
|
||||
|
||||
class HeadAttention(nn.Module):
|
||||
"""
|
||||
Реализация одного головного механизма внимания из архитектуры Transformer.
|
||||
Выполняет scaled dot-product attention с маскированием будущих позиций (causal attention).
|
||||
|
||||
Основной алгоритм:
|
||||
1. Линейные преобразования входных данных в Q (query), K (key), V (value)
|
||||
2. Вычисление scores = Q·K^T / sqrt(d_k)
|
||||
3. Применение causal маски (заполнение -inf будущих позиций)
|
||||
4. Softmax для получения весов внимания
|
||||
5. Умножение весов на значения V
|
||||
|
||||
Пример использования:
|
||||
>>> attention = HeadAttention(emb_size=64, head_size=32, max_seq_len=128)
|
||||
>>> x = torch.randn(1, 10, 64) # [batch_size, seq_len, emb_size]
|
||||
>>> output = attention(x) # [1, 10, 32]
|
||||
|
||||
Параметры:
|
||||
emb_size (int): Размер входного эмбеддинга
|
||||
head_size (int): Размерность выхода головы внимания
|
||||
max_seq_len (int): Максимальная длина последовательности
|
||||
|
||||
Примечания:
|
||||
- Использует нижнетреугольную маску для предотвращения "заглядывания в будущее"
|
||||
- Автоматически адаптируется к разным версиям PyTorch
|
||||
- Поддерживает batch-обработку входных данных
|
||||
"""
|
||||
def __init__(self, emb_size: int, head_size: int, max_seq_len: int):
|
||||
super().__init__()
|
||||
self._emb_size = emb_size
|
||||
self._head_size = head_size
|
||||
self._max_seq_len = max_seq_len
|
||||
|
||||
# Линейные преобразования для Q, K, V
|
||||
self._k = nn.Linear(emb_size, head_size)
|
||||
self._q = nn.Linear(emb_size, head_size)
|
||||
self._v = nn.Linear(emb_size, head_size)
|
||||
|
||||
# Создание causal маски
|
||||
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
|
||||
self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Прямой проход через слой внимания.
|
||||
|
||||
Аргументы:
|
||||
x (torch.Tensor): Входной тензор формы [batch_size, seq_len, emb_size]
|
||||
|
||||
Возвращает:
|
||||
torch.Tensor: Выходной тензор формы [batch_size, seq_len, head_size]
|
||||
|
||||
Исключения:
|
||||
ValueError: Если длина последовательности превышает max_seq_len
|
||||
|
||||
Пример внутренних преобразований:
|
||||
Для входа x.shape = [2, 5, 64]:
|
||||
1. Q/K/V преобразования -> [2, 5, 32]
|
||||
2. Scores = Q·K^T -> [2, 5, 5]
|
||||
3. После маски и softmax -> [2, 5, 5]
|
||||
4. Умножение на V -> [2, 5, 32]
|
||||
"""
|
||||
seq_len = x.shape[1]
|
||||
if seq_len > self._max_seq_len:
|
||||
raise ValueError(f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}")
|
||||
|
||||
# 1. Линейные преобразования
|
||||
k = self._k(x) # [B, T, hs]
|
||||
q = self._q(x) # [B, T, hs]
|
||||
|
||||
# 2. Вычисление scores
|
||||
scores = q @ k.transpose(-2, -1) / sqrt(self._head_size)
|
||||
|
||||
# 3. Применение causal маски
|
||||
scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))
|
||||
|
||||
# 4. Softmax и умножение на V
|
||||
weights = F.softmax(scores, dim=-1)
|
||||
return weights @ self._v(x)
|
||||
Reference in New Issue
Block a user