mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-24 05:26:02 +00:00
Добавление тестов для MultiHeadAttention + финальные правки
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from math import sqrt
|
||||
|
||||
class HeadAttention(nn.Module):
|
||||
@@ -45,6 +46,36 @@ class HeadAttention(nn.Module):
|
||||
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
|
||||
self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())
|
||||
|
||||
def get_attention_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Возвращает матрицу весов внимания без умножения на V.
|
||||
|
||||
Аргументы:
|
||||
x (torch.Tensor): Входной тензор формы [batch_size, seq_len, emb_size]
|
||||
|
||||
Возвращает:
|
||||
torch.Tensor: Матрица весов внимания формы [batch_size, seq_len, seq_len]
|
||||
|
||||
Пример:
|
||||
>>> attention = HeadAttention(emb_size=64, head_size=32, max_seq_len=128)
|
||||
>>> x = torch.randn(1, 10, 64)
|
||||
>>> weights = attention.get_attention_weights(x) # [1, 10, 10]
|
||||
"""
|
||||
seq_len = x.shape[1]
|
||||
if seq_len > self._max_seq_len:
|
||||
raise ValueError(f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}")
|
||||
|
||||
# Вычисляем Q и K
|
||||
k = self._k(x) # [B, T, hs]
|
||||
q = self._q(x) # [B, T, hs]
|
||||
|
||||
# Вычисляем scores и применяем маску
|
||||
scores = q @ k.transpose(-2, -1) / math.sqrt(self._head_size)
|
||||
scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))
|
||||
|
||||
# Возвращаем нормализованные веса
|
||||
return torch.softmax(scores, dim=-1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Прямой проход через слой внимания.
|
||||
|
||||
Reference in New Issue
Block a user