mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-24 05:26:02 +00:00
Добавление механизма внимания HeadAttention
- Реализация одного головного внимания из Transformer - Полная документация на русском языке - Пример использования с визуализацией - Обновление README с ссылками
This commit is contained in:
84
simple_llm/transformer/head_attention.py
Normal file
84
simple_llm/transformer/head_attention.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from math import sqrt
|
||||
|
||||
class HeadAttention(nn.Module):
|
||||
"""
|
||||
Реализация одного головного механизма внимания из архитектуры Transformer.
|
||||
Выполняет scaled dot-product attention с маскированием будущих позиций (causal attention).
|
||||
|
||||
Основной алгоритм:
|
||||
1. Линейные преобразования входных данных в Q (query), K (key), V (value)
|
||||
2. Вычисление scores = Q·K^T / sqrt(d_k)
|
||||
3. Применение causal маски (заполнение -inf будущих позиций)
|
||||
4. Softmax для получения весов внимания
|
||||
5. Умножение весов на значения V
|
||||
|
||||
Пример использования:
|
||||
>>> attention = HeadAttention(emb_size=64, head_size=32, max_seq_len=128)
|
||||
>>> x = torch.randn(1, 10, 64) # [batch_size, seq_len, emb_size]
|
||||
>>> output = attention(x) # [1, 10, 32]
|
||||
|
||||
Параметры:
|
||||
emb_size (int): Размер входного эмбеддинга
|
||||
head_size (int): Размерность выхода головы внимания
|
||||
max_seq_len (int): Максимальная длина последовательности
|
||||
|
||||
Примечания:
|
||||
- Использует нижнетреугольную маску для предотвращения "заглядывания в будущее"
|
||||
- Автоматически адаптируется к разным версиям PyTorch
|
||||
- Поддерживает batch-обработку входных данных
|
||||
"""
|
||||
def __init__(self, emb_size: int, head_size: int, max_seq_len: int):
|
||||
super().__init__()
|
||||
self._emb_size = emb_size
|
||||
self._head_size = head_size
|
||||
self._max_seq_len = max_seq_len
|
||||
|
||||
# Линейные преобразования для Q, K, V
|
||||
self._k = nn.Linear(emb_size, head_size)
|
||||
self._q = nn.Linear(emb_size, head_size)
|
||||
self._v = nn.Linear(emb_size, head_size)
|
||||
|
||||
# Создание causal маски
|
||||
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
|
||||
self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Прямой проход через слой внимания.
|
||||
|
||||
Аргументы:
|
||||
x (torch.Tensor): Входной тензор формы [batch_size, seq_len, emb_size]
|
||||
|
||||
Возвращает:
|
||||
torch.Tensor: Выходной тензор формы [batch_size, seq_len, head_size]
|
||||
|
||||
Исключения:
|
||||
ValueError: Если длина последовательности превышает max_seq_len
|
||||
|
||||
Пример внутренних преобразований:
|
||||
Для входа x.shape = [2, 5, 64]:
|
||||
1. Q/K/V преобразования -> [2, 5, 32]
|
||||
2. Scores = Q·K^T -> [2, 5, 5]
|
||||
3. После маски и softmax -> [2, 5, 5]
|
||||
4. Умножение на V -> [2, 5, 32]
|
||||
"""
|
||||
seq_len = x.shape[1]
|
||||
if seq_len > self._max_seq_len:
|
||||
raise ValueError(f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}")
|
||||
|
||||
# 1. Линейные преобразования
|
||||
k = self._k(x) # [B, T, hs]
|
||||
q = self._q(x) # [B, T, hs]
|
||||
|
||||
# 2. Вычисление scores
|
||||
scores = q @ k.transpose(-2, -1) / sqrt(self._head_size)
|
||||
|
||||
# 3. Применение causal маски
|
||||
scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))
|
||||
|
||||
# 4. Softmax и умножение на V
|
||||
weights = F.softmax(scores, dim=-1)
|
||||
return weights @ self._v(x)
|
||||
Reference in New Issue
Block a user