mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-24 13:48:01 +00:00
84 lines
4.0 KiB
Python
84 lines
4.0 KiB
Python
|
|
import torch
|
|||
|
|
from torch import nn
|
|||
|
|
import torch.nn.functional as F
|
|||
|
|
from math import sqrt
|
|||
|
|
|
|||
|
|
class HeadAttention(nn.Module):
|
|||
|
|
"""
|
|||
|
|
Реализация одного головного механизма внимания из архитектуры Transformer.
|
|||
|
|
Выполняет scaled dot-product attention с маскированием будущих позиций (causal attention).
|
|||
|
|
|
|||
|
|
Основной алгоритм:
|
|||
|
|
1. Линейные преобразования входных данных в Q (query), K (key), V (value)
|
|||
|
|
2. Вычисление scores = Q·K^T / sqrt(d_k)
|
|||
|
|
3. Применение causal маски (заполнение -inf будущих позиций)
|
|||
|
|
4. Softmax для получения весов внимания
|
|||
|
|
5. Умножение весов на значения V
|
|||
|
|
|
|||
|
|
Пример использования:
|
|||
|
|
>>> attention = HeadAttention(emb_size=64, head_size=32, max_seq_len=128)
|
|||
|
|
>>> x = torch.randn(1, 10, 64) # [batch_size, seq_len, emb_size]
|
|||
|
|
>>> output = attention(x) # [1, 10, 32]
|
|||
|
|
|
|||
|
|
Параметры:
|
|||
|
|
emb_size (int): Размер входного эмбеддинга
|
|||
|
|
head_size (int): Размерность выхода головы внимания
|
|||
|
|
max_seq_len (int): Максимальная длина последовательности
|
|||
|
|
|
|||
|
|
Примечания:
|
|||
|
|
- Использует нижнетреугольную маску для предотвращения "заглядывания в будущее"
|
|||
|
|
- Автоматически адаптируется к разным версиям PyTorch
|
|||
|
|
- Поддерживает batch-обработку входных данных
|
|||
|
|
"""
|
|||
|
|
def __init__(self, emb_size: int, head_size: int, max_seq_len: int):
|
|||
|
|
super().__init__()
|
|||
|
|
self._emb_size = emb_size
|
|||
|
|
self._head_size = head_size
|
|||
|
|
self._max_seq_len = max_seq_len
|
|||
|
|
|
|||
|
|
# Линейные преобразования для Q, K, V
|
|||
|
|
self._k = nn.Linear(emb_size, head_size)
|
|||
|
|
self._q = nn.Linear(emb_size, head_size)
|
|||
|
|
self._v = nn.Linear(emb_size, head_size)
|
|||
|
|
|
|||
|
|
# Создание causal маски
|
|||
|
|
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
|
|||
|
|
self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())
|
|||
|
|
|
|||
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|||
|
|
"""
|
|||
|
|
Прямой проход через слой внимания.
|
|||
|
|
|
|||
|
|
Аргументы:
|
|||
|
|
x (torch.Tensor): Входной тензор формы [batch_size, seq_len, emb_size]
|
|||
|
|
|
|||
|
|
Возвращает:
|
|||
|
|
torch.Tensor: Выходной тензор формы [batch_size, seq_len, head_size]
|
|||
|
|
|
|||
|
|
Исключения:
|
|||
|
|
ValueError: Если длина последовательности превышает max_seq_len
|
|||
|
|
|
|||
|
|
Пример внутренних преобразований:
|
|||
|
|
Для входа x.shape = [2, 5, 64]:
|
|||
|
|
1. Q/K/V преобразования -> [2, 5, 32]
|
|||
|
|
2. Scores = Q·K^T -> [2, 5, 5]
|
|||
|
|
3. После маски и softmax -> [2, 5, 5]
|
|||
|
|
4. Умножение на V -> [2, 5, 32]
|
|||
|
|
"""
|
|||
|
|
seq_len = x.shape[1]
|
|||
|
|
if seq_len > self._max_seq_len:
|
|||
|
|
raise ValueError(f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}")
|
|||
|
|
|
|||
|
|
# 1. Линейные преобразования
|
|||
|
|
k = self._k(x) # [B, T, hs]
|
|||
|
|
q = self._q(x) # [B, T, hs]
|
|||
|
|
|
|||
|
|
# 2. Вычисление scores
|
|||
|
|
scores = q @ k.transpose(-2, -1) / sqrt(self._head_size)
|
|||
|
|
|
|||
|
|
# 3. Применение causal маски
|
|||
|
|
scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))
|
|||
|
|
|
|||
|
|
# 4. Softmax и умножение на V
|
|||
|
|
weights = F.softmax(scores, dim=-1)
|
|||
|
|
return weights @ self._v(x)
|