mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
- Добавлен класс MultiHeadAttention - Создана документация с блок-схемой - Добавлен пример использования - Обновлен README.md
2.7 KiB
2.7 KiB
MultiHeadAttention - Многоголовый механизм внимания
Назначение
Модуль реализует ключевой компонент архитектуры Transformer, который:
- Параллельно вычисляет несколько типов внимания
- Позволяет модели фокусироваться на разных аспектах данных
- Улучшает качество в задачах:
- Машинного перевода
- Генерации текста
- Классификации последовательностей
Алгоритм работы
flowchart TD
A[Вход: x] --> B[Линейные проекции Q,K,V]
B --> C[Разделение на num_heads частей]
C --> D[Параллельные вычисления внимания]
D --> E[Конкатенация результатов]
E --> F[Финальная проекция]
F --> G[Выход]
-
Инициализация проекций:
self._q = nn.Linear(emb_size, num_heads * head_size) self._k = nn.Linear(emb_size, num_heads * head_size) self._v = nn.Linear(emb_size, num_heads * head_size) -
Разделение на головы:
q = q.view(batch_size, seq_len, num_heads, head_size) -
Вычисление внимания:
scores = (q @ k.transpose(-2, -1)) / math.sqrt(head_size)
Пример использования
from simple_llm.transformer import MultiHeadAttention
# Инициализация
mha = MultiHeadAttention(
num_heads=8,
emb_size=512,
head_size=64,
max_seq_len=1024
)
# Пример входа
x = torch.randn(1, 50, 512) # [batch_size, seq_len, emb_size]
output = mha(x) # [1, 50, 512]
Параметры
| Параметр | Тип | Описание |
|---|---|---|
num_heads |
int | Количество голов внимания |
emb_size |
int | Размерность входных эмбеддингов |
head_size |
int | Размерность каждой головы |
max_seq_len |
int | Максимальная длина последовательности |
Рекомендации
- Размерность должна делиться на число голов:
assert emb_size % num_heads == 0 - Для визуализации весов:
weights = [head.get_attention_weights(x) for head in mha._heads]