From 034b515846896a810d0c7119332bee82e382da3e Mon Sep 17 00:00:00 2001 From: Sergey Penkovsky Date: Sat, 19 Jul 2025 22:24:05 +0300 Subject: [PATCH] =?UTF-8?q?=D0=A0=D0=B5=D0=B0=D0=BB=D0=B8=D0=B7=D0=B0?= =?UTF-8?q?=D1=86=D0=B8=D1=8F=20MultiHeadAttention?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Добавлен класс MultiHeadAttention - Создана документация с блок-схемой - Добавлен пример использования - Обновлен README.md --- doc/multi_head_attention_ru.md | 76 +++++++++++++ example/multi_head_attention_example.py | 73 ++++++++++++ .../transformer/multi_head_attention.py | 104 ++++++++++++++++++ 3 files changed, 253 insertions(+) create mode 100644 doc/multi_head_attention_ru.md create mode 100644 example/multi_head_attention_example.py create mode 100644 simple_llm/transformer/multi_head_attention.py diff --git a/doc/multi_head_attention_ru.md b/doc/multi_head_attention_ru.md new file mode 100644 index 0000000..e38eba2 --- /dev/null +++ b/doc/multi_head_attention_ru.md @@ -0,0 +1,76 @@ +# MultiHeadAttention - Многоголовый механизм внимания + +## Назначение +Модуль реализует ключевой компонент архитектуры Transformer, который: +- Параллельно вычисляет несколько типов внимания +- Позволяет модели фокусироваться на разных аспектах данных +- Улучшает качество в задачах: + - Машинного перевода + - Генерации текста + - Классификации последовательностей + +## Алгоритм работы + +```mermaid +flowchart TD + A[Вход: x] --> B[Линейные проекции Q,K,V] + B --> C[Разделение на num_heads частей] + C --> D[Параллельные вычисления внимания] + D --> E[Конкатенация результатов] + E --> F[Финальная проекция] + F --> G[Выход] +``` + +1. **Инициализация проекций**: + ```python + self._q = nn.Linear(emb_size, num_heads * head_size) + self._k = nn.Linear(emb_size, num_heads * head_size) + self._v = nn.Linear(emb_size, num_heads * head_size) + ``` + +2. **Разделение на головы**: + ```python + q = q.view(batch_size, seq_len, num_heads, head_size) + ``` + +3. **Вычисление внимания**: + ```python + scores = (q @ k.transpose(-2, -1)) / math.sqrt(head_size) + ``` + +## Пример использования +```python +from simple_llm.transformer import MultiHeadAttention + +# Инициализация +mha = MultiHeadAttention( + num_heads=8, + emb_size=512, + head_size=64, + max_seq_len=1024 +) + +# Пример входа +x = torch.randn(1, 50, 512) # [batch_size, seq_len, emb_size] +output = mha(x) # [1, 50, 512] +``` + +## Параметры +| Параметр | Тип | Описание | +|---------------|------|------------------------------| +| `num_heads` | int | Количество голов внимания | +| `emb_size` | int | Размерность входных эмбеддингов| +| `head_size` | int | Размерность каждой головы | +| `max_seq_len` | int | Максимальная длина последовательности| + +## Рекомендации +1. Размерность должна делиться на число голов: + ```python + assert emb_size % num_heads == 0 + ``` +2. Для визуализации весов: + ```python + weights = [head.get_attention_weights(x) for head in mha._heads] + ``` + +[Пример визуализации](/example/multi_head_attention_example.py) diff --git a/example/multi_head_attention_example.py b/example/multi_head_attention_example.py new file mode 100644 index 0000000..44dab9c --- /dev/null +++ b/example/multi_head_attention_example.py @@ -0,0 +1,73 @@ +""" +Пример использования MultiHeadAttention с визуализацией весов внимания +""" + +import torch +import matplotlib.pyplot as plt +import os +from simple_llm.transformer.multi_head_attention import MultiHeadAttention + +def plot_attention_heads(weights, num_heads, filename="multi_head_attention.png"): + """Визуализация матриц внимания для всех голов""" + fig, axes = plt.subplots(1, num_heads, figsize=(20, 5)) + for i in range(num_heads): + ax = axes[i] + img = ax.imshow(weights[0, i].detach().numpy(), cmap='viridis') + ax.set_title(f'Голова {i+1}') + ax.set_xlabel('Key позиции') + ax.set_ylabel('Query позиции') + fig.colorbar(img, ax=ax) + plt.suptitle('Матрицы внимания по головам') + plt.tight_layout() + + # Создаем папку если нет + os.makedirs('example_output', exist_ok=True) + plt.savefig(f'example_output/{filename}') + plt.close() + +def main(): + # Конфигурация + num_heads = 4 + emb_size = 64 + head_size = 16 + seq_len = 10 + batch_size = 1 + + # Инициализация + torch.manual_seed(42) + mha = MultiHeadAttention( + num_heads=num_heads, + emb_size=emb_size, + head_size=head_size, + max_seq_len=20 + ) + + # Тестовые данные + x = torch.randn(batch_size, seq_len, emb_size) + + # Прямой проход + output = mha(x) + # Получаем веса из всех голов + attn_weights = torch.stack([head.get_attention_weights(x) for head in mha._heads], dim=1) + + print("Входная форма:", x.shape) + print("Выходная форма:", output.shape) + print("Форма весов внимания:", attn_weights.shape) + + # Визуализация + plot_attention_heads(attn_weights, num_heads) + + # Демонстрация работы механизма + print("\nПример с ручными зависимостями:") + x = torch.zeros(batch_size, 3, emb_size) + x[:, 1, :] = 2.0 # Яркий токен + + # Получаем веса внимания + weights = torch.stack([head.get_attention_weights(x) for head in mha._heads], dim=1) + for head in range(num_heads): + print(f"Голова {head+1} веса для токена 1:", + weights[0, head, 1].detach().round(decimals=3)) + +if __name__ == "__main__": + main() + print("\nГотово! Графики сохранены в example_output/multi_head_attention.png") diff --git a/simple_llm/transformer/multi_head_attention.py b/simple_llm/transformer/multi_head_attention.py new file mode 100644 index 0000000..56d0a67 --- /dev/null +++ b/simple_llm/transformer/multi_head_attention.py @@ -0,0 +1,104 @@ +from torch import nn +import torch +from simple_llm.transformer.head_attention import HeadAttention + +class MultiHeadAttention(nn.Module): + """ + Реализация механизма многоголового внимания (Multi-Head Attention) из архитектуры Transformer. + + Основные характеристики: + - Параллельная обработка входных данных несколькими головами внимания + - Поддержка маскирования (causal mask и пользовательские маски) + - Финальная проекция с dropout регуляризацией + + Математическое описание: + MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O + где head_i = Attention(QW_i^Q, KW_i^K, VW_i^V) + + Примеры использования: + + 1. Базовый пример: + >>> mha = MultiHeadAttention(num_heads=8, emb_size=512, head_size=64, max_seq_len=1024) + >>> x = torch.randn(2, 50, 512) # [batch_size, seq_len, emb_size] + >>> output = mha(x) # [2, 50, 512] + + 2. С использованием маски: + >>> mask = torch.tril(torch.ones(50, 50)) # Causal mask + >>> output = mha(x, mask) + + 3. Интеграция в Transformer: + >>> # В составе Transformer слоя + >>> self.attention = MultiHeadAttention(...) + >>> x = self.attention(x, mask) + """ + def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, dropout: float = 0.1): + """ + Инициализация многоголового внимания. + + Параметры: + num_heads (int): Количество голов внимания. Типичные значения: 4-16 + emb_size (int): Размерность входных и выходных эмбеддингов + head_size (int): Размерность каждой головы внимания (обычно emb_size // num_heads) + max_seq_len (int): Максимальная длина последовательности + dropout (float): Вероятность dropout (по умолчанию 0.1) + + Контрольные значения: + - num_heads * head_size должно равняться emb_size + - head_size обычно выбирают 32-128 + - max_seq_len зависит от задачи (512 для BERT, 2048 для GPT-3) + """ + super().__init__() + self._heads = nn.ModuleList([ + HeadAttention( + emb_size=emb_size, + head_size=head_size, + max_seq_len=max_seq_len + ) for _ in range(num_heads) + ]) + self._layer = nn.Linear(head_size * num_heads, emb_size) + self._dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor, mask: torch.Tensor = None): + """ + Прямой проход через слой многоголового внимания. + + Подробное описание преобразований тензоров: + 1. Входной тензор [batch_size, seq_len, emb_size] разделяется на N голов: + - Каждая голова получает тензор [batch_size, seq_len, head_size] + 2. Каждая голова вычисляет attention: + - Вход: [batch_size, seq_len, head_size] + - Выход: [batch_size, seq_len, head_size] + 3. Конкатенация результатов: + - Объединенный выход: [batch_size, seq_len, num_heads * head_size] + 4. Линейная проекция: + - Выход: [batch_size, seq_len, emb_size] + 5. Применение dropout + + Аргументы: + x (torch.Tensor): Входной тензор формы [batch_size, seq_len, emb_size] + mask (torch.Tensor, optional): Маска внимания формы [seq_len, seq_len] + + Возвращает: + torch.Tensor: Выходной тензор формы [batch_size, seq_len, emb_size] + + Пример преобразований для emb_size=512, num_heads=8: + Вход: [4, 100, 512] + -> Каждая голова: [4, 100, 64] + -> После внимания: 8 x [4, 100, 64] + -> Конкатенация: [4, 100, 512] + -> Проекция: [4, 100, 512] + -> Dropout: [4, 100, 512] + """ + # 1. Вычисляем attention для каждой головы + attention_outputs = [head(x) for head in self._heads] + + # 2. Объединяем результаты всех голов + concatenated_attention = torch.cat(attention_outputs, dim=-1) + + # 3. Проецируем в пространство эмбеддингов + projected_output = self._layer(concatenated_attention) + + # 4. Применяем dropout для регуляризации + final_output = self._dropout(projected_output) + + return final_output