Files
simple-llm/simple_llm/transformer/head_attention.py

115 lines
5.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
from torch import nn
import torch.nn.functional as F
import math
from math import sqrt
class HeadAttention(nn.Module):
"""
Реализация одного головного механизма внимания из архитектуры Transformer.
Выполняет scaled dot-product attention с маскированием будущих позиций (causal attention).
Основной алгоритм:
1. Линейные преобразования входных данных в Q (query), K (key), V (value)
2. Вычисление scores = Q·K^T / sqrt(d_k)
3. Применение causal маски (заполнение -inf будущих позиций)
4. Softmax для получения весов внимания
5. Умножение весов на значения V
Пример использования:
>>> attention = HeadAttention(emb_size=64, head_size=32, max_seq_len=128)
>>> x = torch.randn(1, 10, 64) # [batch_size, seq_len, emb_size]
>>> output = attention(x) # [1, 10, 32]
Параметры:
emb_size (int): Размер входного эмбеддинга
head_size (int): Размерность выхода головы внимания
max_seq_len (int): Максимальная длина последовательности
Примечания:
- Использует нижнетреугольную маску для предотвращения "заглядывания в будущее"
- Автоматически адаптируется к разным версиям PyTorch
- Поддерживает batch-обработку входных данных
"""
def __init__(self, emb_size: int, head_size: int, max_seq_len: int):
super().__init__()
self._emb_size = emb_size
self._head_size = head_size
self._max_seq_len = max_seq_len
# Линейные преобразования для Q, K, V
self._k = nn.Linear(emb_size, head_size)
self._q = nn.Linear(emb_size, head_size)
self._v = nn.Linear(emb_size, head_size)
# Создание causal маски
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())
def get_attention_weights(self, x: torch.Tensor) -> torch.Tensor:
"""
Возвращает матрицу весов внимания без умножения на V.
Аргументы:
x (torch.Tensor): Входной тензор формы [batch_size, seq_len, emb_size]
Возвращает:
torch.Tensor: Матрица весов внимания формы [batch_size, seq_len, seq_len]
Пример:
>>> attention = HeadAttention(emb_size=64, head_size=32, max_seq_len=128)
>>> x = torch.randn(1, 10, 64)
>>> weights = attention.get_attention_weights(x) # [1, 10, 10]
"""
seq_len = x.shape[1]
if seq_len > self._max_seq_len:
raise ValueError(f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}")
# Вычисляем Q и K
k = self._k(x) # [B, T, hs]
q = self._q(x) # [B, T, hs]
# Вычисляем scores и применяем маску
scores = q @ k.transpose(-2, -1) / math.sqrt(self._head_size)
scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))
# Возвращаем нормализованные веса
return torch.softmax(scores, dim=-1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Прямой проход через слой внимания.
Аргументы:
x (torch.Tensor): Входной тензор формы [batch_size, seq_len, emb_size]
Возвращает:
torch.Tensor: Выходной тензор формы [batch_size, seq_len, head_size]
Исключения:
ValueError: Если длина последовательности превышает max_seq_len
Пример внутренних преобразований:
Для входа x.shape = [2, 5, 64]:
1. Q/K/V преобразования -> [2, 5, 32]
2. Scores = Q·K^T -> [2, 5, 5]
3. После маски и softmax -> [2, 5, 5]
4. Умножение на V -> [2, 5, 32]
"""
seq_len = x.shape[1]
if seq_len > self._max_seq_len:
raise ValueError(f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}")
# 1. Линейные преобразования
k = self._k(x) # [B, T, hs]
q = self._q(x) # [B, T, hs]
# 2. Вычисление scores
scores = q @ k.transpose(-2, -1) / sqrt(self._head_size)
# 3. Применение causal маски
scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))
# 4. Softmax и умножение на V
weights = F.softmax(scores, dim=-1)
return weights @ self._v(x)