mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
115 lines
5.4 KiB
Python
115 lines
5.4 KiB
Python
import torch
|
||
from torch import nn
|
||
import torch.nn.functional as F
|
||
import math
|
||
from math import sqrt
|
||
|
||
class HeadAttention(nn.Module):
|
||
"""
|
||
Реализация одного головного механизма внимания из архитектуры Transformer.
|
||
Выполняет scaled dot-product attention с маскированием будущих позиций (causal attention).
|
||
|
||
Основной алгоритм:
|
||
1. Линейные преобразования входных данных в Q (query), K (key), V (value)
|
||
2. Вычисление scores = Q·K^T / sqrt(d_k)
|
||
3. Применение causal маски (заполнение -inf будущих позиций)
|
||
4. Softmax для получения весов внимания
|
||
5. Умножение весов на значения V
|
||
|
||
Пример использования:
|
||
>>> attention = HeadAttention(emb_size=64, head_size=32, max_seq_len=128)
|
||
>>> x = torch.randn(1, 10, 64) # [batch_size, seq_len, emb_size]
|
||
>>> output = attention(x) # [1, 10, 32]
|
||
|
||
Параметры:
|
||
emb_size (int): Размер входного эмбеддинга
|
||
head_size (int): Размерность выхода головы внимания
|
||
max_seq_len (int): Максимальная длина последовательности
|
||
|
||
Примечания:
|
||
- Использует нижнетреугольную маску для предотвращения "заглядывания в будущее"
|
||
- Автоматически адаптируется к разным версиям PyTorch
|
||
- Поддерживает batch-обработку входных данных
|
||
"""
|
||
def __init__(self, emb_size: int, head_size: int, max_seq_len: int):
|
||
super().__init__()
|
||
self._emb_size = emb_size
|
||
self._head_size = head_size
|
||
self._max_seq_len = max_seq_len
|
||
|
||
# Линейные преобразования для Q, K, V
|
||
self._k = nn.Linear(emb_size, head_size)
|
||
self._q = nn.Linear(emb_size, head_size)
|
||
self._v = nn.Linear(emb_size, head_size)
|
||
|
||
# Создание causal маски
|
||
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
|
||
self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())
|
||
|
||
def get_attention_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||
"""
|
||
Возвращает матрицу весов внимания без умножения на V.
|
||
|
||
Аргументы:
|
||
x (torch.Tensor): Входной тензор формы [batch_size, seq_len, emb_size]
|
||
|
||
Возвращает:
|
||
torch.Tensor: Матрица весов внимания формы [batch_size, seq_len, seq_len]
|
||
|
||
Пример:
|
||
>>> attention = HeadAttention(emb_size=64, head_size=32, max_seq_len=128)
|
||
>>> x = torch.randn(1, 10, 64)
|
||
>>> weights = attention.get_attention_weights(x) # [1, 10, 10]
|
||
"""
|
||
seq_len = x.shape[1]
|
||
if seq_len > self._max_seq_len:
|
||
raise ValueError(f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}")
|
||
|
||
# Вычисляем Q и K
|
||
k = self._k(x) # [B, T, hs]
|
||
q = self._q(x) # [B, T, hs]
|
||
|
||
# Вычисляем scores и применяем маску
|
||
scores = q @ k.transpose(-2, -1) / math.sqrt(self._head_size)
|
||
scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))
|
||
|
||
# Возвращаем нормализованные веса
|
||
return torch.softmax(scores, dim=-1)
|
||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||
"""
|
||
Прямой проход через слой внимания.
|
||
|
||
Аргументы:
|
||
x (torch.Tensor): Входной тензор формы [batch_size, seq_len, emb_size]
|
||
|
||
Возвращает:
|
||
torch.Tensor: Выходной тензор формы [batch_size, seq_len, head_size]
|
||
|
||
Исключения:
|
||
ValueError: Если длина последовательности превышает max_seq_len
|
||
|
||
Пример внутренних преобразований:
|
||
Для входа x.shape = [2, 5, 64]:
|
||
1. Q/K/V преобразования -> [2, 5, 32]
|
||
2. Scores = Q·K^T -> [2, 5, 5]
|
||
3. После маски и softmax -> [2, 5, 5]
|
||
4. Умножение на V -> [2, 5, 32]
|
||
"""
|
||
seq_len = x.shape[1]
|
||
if seq_len > self._max_seq_len:
|
||
raise ValueError(f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}")
|
||
|
||
# 1. Линейные преобразования
|
||
k = self._k(x) # [B, T, hs]
|
||
q = self._q(x) # [B, T, hs]
|
||
|
||
# 2. Вычисление scores
|
||
scores = q @ k.transpose(-2, -1) / sqrt(self._head_size)
|
||
|
||
# 3. Применение causal маски
|
||
scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))
|
||
|
||
# 4. Softmax и умножение на V
|
||
weights = F.softmax(scores, dim=-1)
|
||
return weights @ self._v(x) |