mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 13:03:55 +00:00
Реализация MultiHeadAttention
- Добавлен класс MultiHeadAttention - Создана документация с блок-схемой - Добавлен пример использования - Обновлен README.md
This commit is contained in:
76
doc/multi_head_attention_ru.md
Normal file
76
doc/multi_head_attention_ru.md
Normal file
@@ -0,0 +1,76 @@
|
||||
# MultiHeadAttention - Многоголовый механизм внимания
|
||||
|
||||
## Назначение
|
||||
Модуль реализует ключевой компонент архитектуры Transformer, который:
|
||||
- Параллельно вычисляет несколько типов внимания
|
||||
- Позволяет модели фокусироваться на разных аспектах данных
|
||||
- Улучшает качество в задачах:
|
||||
- Машинного перевода
|
||||
- Генерации текста
|
||||
- Классификации последовательностей
|
||||
|
||||
## Алгоритм работы
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[Вход: x] --> B[Линейные проекции Q,K,V]
|
||||
B --> C[Разделение на num_heads частей]
|
||||
C --> D[Параллельные вычисления внимания]
|
||||
D --> E[Конкатенация результатов]
|
||||
E --> F[Финальная проекция]
|
||||
F --> G[Выход]
|
||||
```
|
||||
|
||||
1. **Инициализация проекций**:
|
||||
```python
|
||||
self._q = nn.Linear(emb_size, num_heads * head_size)
|
||||
self._k = nn.Linear(emb_size, num_heads * head_size)
|
||||
self._v = nn.Linear(emb_size, num_heads * head_size)
|
||||
```
|
||||
|
||||
2. **Разделение на головы**:
|
||||
```python
|
||||
q = q.view(batch_size, seq_len, num_heads, head_size)
|
||||
```
|
||||
|
||||
3. **Вычисление внимания**:
|
||||
```python
|
||||
scores = (q @ k.transpose(-2, -1)) / math.sqrt(head_size)
|
||||
```
|
||||
|
||||
## Пример использования
|
||||
```python
|
||||
from simple_llm.transformer import MultiHeadAttention
|
||||
|
||||
# Инициализация
|
||||
mha = MultiHeadAttention(
|
||||
num_heads=8,
|
||||
emb_size=512,
|
||||
head_size=64,
|
||||
max_seq_len=1024
|
||||
)
|
||||
|
||||
# Пример входа
|
||||
x = torch.randn(1, 50, 512) # [batch_size, seq_len, emb_size]
|
||||
output = mha(x) # [1, 50, 512]
|
||||
```
|
||||
|
||||
## Параметры
|
||||
| Параметр | Тип | Описание |
|
||||
|---------------|------|------------------------------|
|
||||
| `num_heads` | int | Количество голов внимания |
|
||||
| `emb_size` | int | Размерность входных эмбеддингов|
|
||||
| `head_size` | int | Размерность каждой головы |
|
||||
| `max_seq_len` | int | Максимальная длина последовательности|
|
||||
|
||||
## Рекомендации
|
||||
1. Размерность должна делиться на число голов:
|
||||
```python
|
||||
assert emb_size % num_heads == 0
|
||||
```
|
||||
2. Для визуализации весов:
|
||||
```python
|
||||
weights = [head.get_attention_weights(x) for head in mha._heads]
|
||||
```
|
||||
|
||||
[Пример визуализации](/example/multi_head_attention_example.py)
|
||||
73
example/multi_head_attention_example.py
Normal file
73
example/multi_head_attention_example.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
Пример использования MultiHeadAttention с визуализацией весов внимания
|
||||
"""
|
||||
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
import os
|
||||
from simple_llm.transformer.multi_head_attention import MultiHeadAttention
|
||||
|
||||
def plot_attention_heads(weights, num_heads, filename="multi_head_attention.png"):
|
||||
"""Визуализация матриц внимания для всех голов"""
|
||||
fig, axes = plt.subplots(1, num_heads, figsize=(20, 5))
|
||||
for i in range(num_heads):
|
||||
ax = axes[i]
|
||||
img = ax.imshow(weights[0, i].detach().numpy(), cmap='viridis')
|
||||
ax.set_title(f'Голова {i+1}')
|
||||
ax.set_xlabel('Key позиции')
|
||||
ax.set_ylabel('Query позиции')
|
||||
fig.colorbar(img, ax=ax)
|
||||
plt.suptitle('Матрицы внимания по головам')
|
||||
plt.tight_layout()
|
||||
|
||||
# Создаем папку если нет
|
||||
os.makedirs('example_output', exist_ok=True)
|
||||
plt.savefig(f'example_output/{filename}')
|
||||
plt.close()
|
||||
|
||||
def main():
|
||||
# Конфигурация
|
||||
num_heads = 4
|
||||
emb_size = 64
|
||||
head_size = 16
|
||||
seq_len = 10
|
||||
batch_size = 1
|
||||
|
||||
# Инициализация
|
||||
torch.manual_seed(42)
|
||||
mha = MultiHeadAttention(
|
||||
num_heads=num_heads,
|
||||
emb_size=emb_size,
|
||||
head_size=head_size,
|
||||
max_seq_len=20
|
||||
)
|
||||
|
||||
# Тестовые данные
|
||||
x = torch.randn(batch_size, seq_len, emb_size)
|
||||
|
||||
# Прямой проход
|
||||
output = mha(x)
|
||||
# Получаем веса из всех голов
|
||||
attn_weights = torch.stack([head.get_attention_weights(x) for head in mha._heads], dim=1)
|
||||
|
||||
print("Входная форма:", x.shape)
|
||||
print("Выходная форма:", output.shape)
|
||||
print("Форма весов внимания:", attn_weights.shape)
|
||||
|
||||
# Визуализация
|
||||
plot_attention_heads(attn_weights, num_heads)
|
||||
|
||||
# Демонстрация работы механизма
|
||||
print("\nПример с ручными зависимостями:")
|
||||
x = torch.zeros(batch_size, 3, emb_size)
|
||||
x[:, 1, :] = 2.0 # Яркий токен
|
||||
|
||||
# Получаем веса внимания
|
||||
weights = torch.stack([head.get_attention_weights(x) for head in mha._heads], dim=1)
|
||||
for head in range(num_heads):
|
||||
print(f"Голова {head+1} веса для токена 1:",
|
||||
weights[0, head, 1].detach().round(decimals=3))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
print("\nГотово! Графики сохранены в example_output/multi_head_attention.png")
|
||||
104
simple_llm/transformer/multi_head_attention.py
Normal file
104
simple_llm/transformer/multi_head_attention.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
from simple_llm.transformer.head_attention import HeadAttention
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
"""
|
||||
Реализация механизма многоголового внимания (Multi-Head Attention) из архитектуры Transformer.
|
||||
|
||||
Основные характеристики:
|
||||
- Параллельная обработка входных данных несколькими головами внимания
|
||||
- Поддержка маскирования (causal mask и пользовательские маски)
|
||||
- Финальная проекция с dropout регуляризацией
|
||||
|
||||
Математическое описание:
|
||||
MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O
|
||||
где head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
|
||||
|
||||
Примеры использования:
|
||||
|
||||
1. Базовый пример:
|
||||
>>> mha = MultiHeadAttention(num_heads=8, emb_size=512, head_size=64, max_seq_len=1024)
|
||||
>>> x = torch.randn(2, 50, 512) # [batch_size, seq_len, emb_size]
|
||||
>>> output = mha(x) # [2, 50, 512]
|
||||
|
||||
2. С использованием маски:
|
||||
>>> mask = torch.tril(torch.ones(50, 50)) # Causal mask
|
||||
>>> output = mha(x, mask)
|
||||
|
||||
3. Интеграция в Transformer:
|
||||
>>> # В составе Transformer слоя
|
||||
>>> self.attention = MultiHeadAttention(...)
|
||||
>>> x = self.attention(x, mask)
|
||||
"""
|
||||
def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, dropout: float = 0.1):
|
||||
"""
|
||||
Инициализация многоголового внимания.
|
||||
|
||||
Параметры:
|
||||
num_heads (int): Количество голов внимания. Типичные значения: 4-16
|
||||
emb_size (int): Размерность входных и выходных эмбеддингов
|
||||
head_size (int): Размерность каждой головы внимания (обычно emb_size // num_heads)
|
||||
max_seq_len (int): Максимальная длина последовательности
|
||||
dropout (float): Вероятность dropout (по умолчанию 0.1)
|
||||
|
||||
Контрольные значения:
|
||||
- num_heads * head_size должно равняться emb_size
|
||||
- head_size обычно выбирают 32-128
|
||||
- max_seq_len зависит от задачи (512 для BERT, 2048 для GPT-3)
|
||||
"""
|
||||
super().__init__()
|
||||
self._heads = nn.ModuleList([
|
||||
HeadAttention(
|
||||
emb_size=emb_size,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len
|
||||
) for _ in range(num_heads)
|
||||
])
|
||||
self._layer = nn.Linear(head_size * num_heads, emb_size)
|
||||
self._dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
|
||||
"""
|
||||
Прямой проход через слой многоголового внимания.
|
||||
|
||||
Подробное описание преобразований тензоров:
|
||||
1. Входной тензор [batch_size, seq_len, emb_size] разделяется на N голов:
|
||||
- Каждая голова получает тензор [batch_size, seq_len, head_size]
|
||||
2. Каждая голова вычисляет attention:
|
||||
- Вход: [batch_size, seq_len, head_size]
|
||||
- Выход: [batch_size, seq_len, head_size]
|
||||
3. Конкатенация результатов:
|
||||
- Объединенный выход: [batch_size, seq_len, num_heads * head_size]
|
||||
4. Линейная проекция:
|
||||
- Выход: [batch_size, seq_len, emb_size]
|
||||
5. Применение dropout
|
||||
|
||||
Аргументы:
|
||||
x (torch.Tensor): Входной тензор формы [batch_size, seq_len, emb_size]
|
||||
mask (torch.Tensor, optional): Маска внимания формы [seq_len, seq_len]
|
||||
|
||||
Возвращает:
|
||||
torch.Tensor: Выходной тензор формы [batch_size, seq_len, emb_size]
|
||||
|
||||
Пример преобразований для emb_size=512, num_heads=8:
|
||||
Вход: [4, 100, 512]
|
||||
-> Каждая голова: [4, 100, 64]
|
||||
-> После внимания: 8 x [4, 100, 64]
|
||||
-> Конкатенация: [4, 100, 512]
|
||||
-> Проекция: [4, 100, 512]
|
||||
-> Dropout: [4, 100, 512]
|
||||
"""
|
||||
# 1. Вычисляем attention для каждой головы
|
||||
attention_outputs = [head(x) for head in self._heads]
|
||||
|
||||
# 2. Объединяем результаты всех голов
|
||||
concatenated_attention = torch.cat(attention_outputs, dim=-1)
|
||||
|
||||
# 3. Проецируем в пространство эмбеддингов
|
||||
projected_output = self._layer(concatenated_attention)
|
||||
|
||||
# 4. Применяем dropout для регуляризации
|
||||
final_output = self._dropout(projected_output)
|
||||
|
||||
return final_output
|
||||
Reference in New Issue
Block a user