Files
simple-llm/simple_llm/transformer/multi_head_attention.py
Sergey Penkovsky 034b515846 Реализация MultiHeadAttention
- Добавлен класс MultiHeadAttention
- Создана документация с блок-схемой
- Добавлен пример использования
- Обновлен README.md
2025-07-19 22:24:05 +03:00

105 lines
5.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from torch import nn
import torch
from simple_llm.transformer.head_attention import HeadAttention
class MultiHeadAttention(nn.Module):
"""
Реализация механизма многоголового внимания (Multi-Head Attention) из архитектуры Transformer.
Основные характеристики:
- Параллельная обработка входных данных несколькими головами внимания
- Поддержка маскирования (causal mask и пользовательские маски)
- Финальная проекция с dropout регуляризацией
Математическое описание:
MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O
где head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
Примеры использования:
1. Базовый пример:
>>> mha = MultiHeadAttention(num_heads=8, emb_size=512, head_size=64, max_seq_len=1024)
>>> x = torch.randn(2, 50, 512) # [batch_size, seq_len, emb_size]
>>> output = mha(x) # [2, 50, 512]
2. С использованием маски:
>>> mask = torch.tril(torch.ones(50, 50)) # Causal mask
>>> output = mha(x, mask)
3. Интеграция в Transformer:
>>> # В составе Transformer слоя
>>> self.attention = MultiHeadAttention(...)
>>> x = self.attention(x, mask)
"""
def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, dropout: float = 0.1):
"""
Инициализация многоголового внимания.
Параметры:
num_heads (int): Количество голов внимания. Типичные значения: 4-16
emb_size (int): Размерность входных и выходных эмбеддингов
head_size (int): Размерность каждой головы внимания (обычно emb_size // num_heads)
max_seq_len (int): Максимальная длина последовательности
dropout (float): Вероятность dropout (по умолчанию 0.1)
Контрольные значения:
- num_heads * head_size должно равняться emb_size
- head_size обычно выбирают 32-128
- max_seq_len зависит от задачи (512 для BERT, 2048 для GPT-3)
"""
super().__init__()
self._heads = nn.ModuleList([
HeadAttention(
emb_size=emb_size,
head_size=head_size,
max_seq_len=max_seq_len
) for _ in range(num_heads)
])
self._layer = nn.Linear(head_size * num_heads, emb_size)
self._dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
"""
Прямой проход через слой многоголового внимания.
Подробное описание преобразований тензоров:
1. Входной тензор [batch_size, seq_len, emb_size] разделяется на N голов:
- Каждая голова получает тензор [batch_size, seq_len, head_size]
2. Каждая голова вычисляет attention:
- Вход: [batch_size, seq_len, head_size]
- Выход: [batch_size, seq_len, head_size]
3. Конкатенация результатов:
- Объединенный выход: [batch_size, seq_len, num_heads * head_size]
4. Линейная проекция:
- Выход: [batch_size, seq_len, emb_size]
5. Применение dropout
Аргументы:
x (torch.Tensor): Входной тензор формы [batch_size, seq_len, emb_size]
mask (torch.Tensor, optional): Маска внимания формы [seq_len, seq_len]
Возвращает:
torch.Tensor: Выходной тензор формы [batch_size, seq_len, emb_size]
Пример преобразований для emb_size=512, num_heads=8:
Вход: [4, 100, 512]
-> Каждая голова: [4, 100, 64]
-> После внимания: 8 x [4, 100, 64]
-> Конкатенация: [4, 100, 512]
-> Проекция: [4, 100, 512]
-> Dropout: [4, 100, 512]
"""
# 1. Вычисляем attention для каждой головы
attention_outputs = [head(x) for head in self._heads]
# 2. Объединяем результаты всех голов
concatenated_attention = torch.cat(attention_outputs, dim=-1)
# 3. Проецируем в пространство эмбеддингов
projected_output = self._layer(concatenated_attention)
# 4. Применяем dropout для регуляризации
final_output = self._dropout(projected_output)
return final_output