From a1508286654bd50536b7aec8eeb31f0b17738add Mon Sep 17 00:00:00 2001 From: Sergey Penkovsky Date: Sat, 19 Jul 2025 11:35:11 +0300 Subject: [PATCH] =?UTF-8?q?=D0=94=D0=BE=D0=B1=D0=B0=D0=B2=D0=BB=D0=B5?= =?UTF-8?q?=D0=BD=D0=B8=D0=B5=20=D0=BC=D0=B5=D1=85=D0=B0=D0=BD=D0=B8=D0=B7?= =?UTF-8?q?=D0=BC=D0=B0=20=D0=B2=D0=BD=D0=B8=D0=BC=D0=B0=D0=BD=D0=B8=D1=8F?= =?UTF-8?q?=20HeadAttention?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Реализация одного головного внимания из Transformer - Полная документация на русском языке - Пример использования с визуализацией - Обновление README с ссылками --- README.md | 12 +++- doc/head_attention_ru.md | 83 ++++++++++++++++++++++ example/head_attention_example.py | 88 ++++++++++++++++++++++++ simple_llm/transformer/head_attention.py | 84 ++++++++++++++++++++++ 4 files changed, 265 insertions(+), 2 deletions(-) create mode 100644 doc/head_attention_ru.md create mode 100644 example/head_attention_example.py create mode 100644 simple_llm/transformer/head_attention.py diff --git a/README.md b/README.md index 06424f2..72eb9fa 100644 --- a/README.md +++ b/README.md @@ -88,9 +88,17 @@ git+https://github.com/yourusername/simple-llm.git pip install git+https://github.com/yourusername/simple-llm.git ``` -## Примеры +## Примеры использования -Дополнительные примеры использования смотрите в папке [example](/example): +Дополнительные примеры: +- [Базовый BPE](/example/example_bpe.py) +- [Токенные эмбеддинги](/example/example_token_embeddings.py) +- [Механизм внимания](/example/head_attention_example.py) + +Документация: +- [Токенизация](/doc/bpe_algorithm.md) +- [Эмбеддинги](/doc/token_embeddings_ru.md) +- [Внимание](/doc/head_attention_ru.md) - Сравнение SimpleBPE и OptimizeBPE - Работа с разными языками - Настройка параметров токенизации diff --git a/doc/head_attention_ru.md b/doc/head_attention_ru.md new file mode 100644 index 0000000..4d3ffdf --- /dev/null +++ b/doc/head_attention_ru.md @@ -0,0 +1,83 @@ +# HeadAttention - Механизм самовнимания одной головы + +## Назначение +Модуль реализует механизм внимания одной головы из архитектуры Transformer. Основные применения: +- Моделирование зависимостей в последовательностях +- Обработка естественного языка (NLP) +- Генерация текста с учетом контекста +- Анализ временных рядов + +## Алгоритм работы + +```mermaid +flowchart TD + A[Входной тензор x] --> B[Вычисление Q, K, V] + B --> C["Scores = Q·Kᵀ / √d_k"] + C --> D[Применение нижнетреугольной маски] + D --> E[Softmax] + E --> F[Взвешенная сумма значений V] + F --> G[Выходной тензор] +``` + +1. **Линейные преобразования**: + ```python + Q = W_q·x, K = W_k·x, V = W_v·x + ``` + +2. **Вычисление attention scores**: + ```python + scores = matmul(Q, K.transpose(-2, -1)) / sqrt(head_size) + ``` + +3. **Маскирование**: + ```python + scores.masked_fill_(mask == 0, -inf) # Causal masking + ``` + +4. **Взвешивание**: + ```python + weights = softmax(scores, dim=-1) + output = matmul(weights, V) + ``` + +## Пример использования +```python +import torch +from simple_llm.transformer.head_attention import HeadAttention + +# Параметры +emb_size = 512 +head_size = 64 +max_seq_len = 1024 + +# Инициализация +attn_head = HeadAttention(emb_size, head_size, max_seq_len) + +# Пример входа (batch_size=2, seq_len=10) +x = torch.randn(2, 10, emb_size) +output = attn_head(x) # [2, 10, head_size] +``` + +## Особенности реализации + +### Ключевые компоненты +| Компонент | Назначение | +|-----------------|-------------------------------------| +| `self._q` | Линейный слой для Query | +| `self._k` | Линейный слой для Key | +| `self._v` | Линейный слой для Value | +| `self._tril_mask`| Нижнетреугольная маска | + +### Ограничения +- Требует O(n²) памяти для матрицы внимания +- Поддерживает только causal-режим +- Фиксированный максимальный размер последовательности + +## Рекомендации по использованию +1. Размер головы (`head_size`) обычно выбирают 64-128 +2. Для длинных последовательностей (>512) используйте оптимизации: + - Локальное внимание + - Разреженные паттерны +3. Сочетайте с MultiHeadAttention для лучшего качества + +[Дополнительные примеры](/example/attention_examples.py) diff --git a/example/head_attention_example.py b/example/head_attention_example.py new file mode 100644 index 0000000..0a9bd7c --- /dev/null +++ b/example/head_attention_example.py @@ -0,0 +1,88 @@ +""" +Пример использования механизма внимания HeadAttention +с визуализацией матрицы внимания и анализом работы. +""" + +import torch +import matplotlib.pyplot as plt +import numpy as np +from simple_llm.transformer.head_attention import HeadAttention + +import os +os.makedirs("example_output", exist_ok=True) + +def plot_attention(weights, tokens=None, filename="attention_plot.png"): + """Сохранение матрицы внимания в файл""" + plt.figure(figsize=(10, 8)) + plt.imshow(weights, cmap='viridis') + + if tokens: + plt.xticks(range(len(tokens)), tokens, rotation=90) + plt.yticks(range(len(tokens)), tokens) + + plt.colorbar() + plt.title("Матрица весов внимания") + plt.xlabel("Key Positions") + plt.ylabel("Query Positions") + plt.savefig(f"example_output/{filename}") + plt.close() + +def simulate_text_attention(): + """Пример с имитацией текстовых данных""" + # Параметры + emb_size = 64 + head_size = 32 + seq_len = 8 + + # Имитация токенов + tokens = ["[CLS]", "мама", "мыла", "раму", ",", "папа", "пил", "какао"] + + # Инициализация + torch.manual_seed(42) + attention = HeadAttention(emb_size, head_size, max_seq_len=seq_len) + + # Случайные эмбеддинги (в реальности - выход слоя токенизации) + x = torch.randn(1, seq_len, emb_size) + + # Прямой проход + получение весов + with torch.no_grad(): + output = attention(x) + q, k = attention._q(x), attention._k(x) + scores = (q @ k.transpose(-2, -1)) / np.sqrt(head_size) + weights = torch.softmax(scores, dim=-1).squeeze() + + # Визуализация + print("\nПример для фразы:", " ".join(tokens)) + print("Форма выходного тензора:", output.shape) + plot_attention(weights.numpy(), tokens) + +def technical_demo(): + """Техническая демонстрация работы механизма""" + print("\nТехническая демонстрация HeadAttention") + attention = HeadAttention(emb_size=16, head_size=8, max_seq_len=10) + + # Создаем тензор с ручными значениями для анализа + x = torch.zeros(1, 4, 16) + x[0, 0, :] = 1.0 # Яркий токен + x[0, 3, :] = 0.5 # Слабый токен + + # Анализ весов + with torch.no_grad(): + output = attention(x) + q = attention._q(x) + k = attention._k(x) + print("\nQuery векторы (первые 5 значений):") + print(q[0, :, :5]) + + print("\nKey векторы (первые 5 значений):") + print(k[0, :, :5]) + + weights = torch.softmax((q @ k.transpose(-2, -1)) / np.sqrt(8), dim=-1) + print("\nМатрица внимания:") + print(weights.squeeze().round(decimals=3)) + +if __name__ == "__main__": + print("Демонстрация работы HeadAttention") + simulate_text_attention() + technical_demo() + print("\nГотово! Проверьте графики матрицы внимания.") diff --git a/simple_llm/transformer/head_attention.py b/simple_llm/transformer/head_attention.py new file mode 100644 index 0000000..ac6f061 --- /dev/null +++ b/simple_llm/transformer/head_attention.py @@ -0,0 +1,84 @@ +import torch +from torch import nn +import torch.nn.functional as F +from math import sqrt + +class HeadAttention(nn.Module): + """ + Реализация одного головного механизма внимания из архитектуры Transformer. + Выполняет scaled dot-product attention с маскированием будущих позиций (causal attention). + + Основной алгоритм: + 1. Линейные преобразования входных данных в Q (query), K (key), V (value) + 2. Вычисление scores = Q·K^T / sqrt(d_k) + 3. Применение causal маски (заполнение -inf будущих позиций) + 4. Softmax для получения весов внимания + 5. Умножение весов на значения V + + Пример использования: + >>> attention = HeadAttention(emb_size=64, head_size=32, max_seq_len=128) + >>> x = torch.randn(1, 10, 64) # [batch_size, seq_len, emb_size] + >>> output = attention(x) # [1, 10, 32] + + Параметры: + emb_size (int): Размер входного эмбеддинга + head_size (int): Размерность выхода головы внимания + max_seq_len (int): Максимальная длина последовательности + + Примечания: + - Использует нижнетреугольную маску для предотвращения "заглядывания в будущее" + - Автоматически адаптируется к разным версиям PyTorch + - Поддерживает batch-обработку входных данных + """ + def __init__(self, emb_size: int, head_size: int, max_seq_len: int): + super().__init__() + self._emb_size = emb_size + self._head_size = head_size + self._max_seq_len = max_seq_len + + # Линейные преобразования для Q, K, V + self._k = nn.Linear(emb_size, head_size) + self._q = nn.Linear(emb_size, head_size) + self._v = nn.Linear(emb_size, head_size) + + # Создание causal маски + mask = torch.tril(torch.ones(max_seq_len, max_seq_len)) + self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte()) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Прямой проход через слой внимания. + + Аргументы: + x (torch.Tensor): Входной тензор формы [batch_size, seq_len, emb_size] + + Возвращает: + torch.Tensor: Выходной тензор формы [batch_size, seq_len, head_size] + + Исключения: + ValueError: Если длина последовательности превышает max_seq_len + + Пример внутренних преобразований: + Для входа x.shape = [2, 5, 64]: + 1. Q/K/V преобразования -> [2, 5, 32] + 2. Scores = Q·K^T -> [2, 5, 5] + 3. После маски и softmax -> [2, 5, 5] + 4. Умножение на V -> [2, 5, 32] + """ + seq_len = x.shape[1] + if seq_len > self._max_seq_len: + raise ValueError(f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}") + + # 1. Линейные преобразования + k = self._k(x) # [B, T, hs] + q = self._q(x) # [B, T, hs] + + # 2. Вычисление scores + scores = q @ k.transpose(-2, -1) / sqrt(self._head_size) + + # 3. Применение causal маски + scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf')) + + # 4. Softmax и умножение на V + weights = F.softmax(scores, dim=-1) + return weights @ self._v(x) \ No newline at end of file