refactor(core): refactor RoPE and MultiHeadAttention, add math-rich docs, expand tests, remove unused head_attention

- refactor: улучшена и унифицирована реализация RoPE, теперь поддерживаются строгие проверки размерности входа; внесены улучшения и структурные изменения в MultiHeadAttention (более понятная логика, строгая спецификация входов/выходов)
- docs: полностью переписаны docstrings для RoPE и MultiHeadAttention — включены математические формулы, ссылки на научные статьи, подробные пояснения по алгоритму, формату входных данных, ограничениям, примеры использования
- test: добавлены отдельные unit-тесты для RoPE (корректность формы, ошибки на неверную размерность, сохранение нормы, backward/градиенты, работу с параметрами start_pos и батчами)
- chore: удалён неиспользуемый модуль core/head_attention.py
- fix: теперь выбрасывается AssertionError при неправильной размерности входа RoPE; это позволило полностью покрыть тест-кейсы на ошибки

Этот коммит синхронизирует логику реализации базового внимания с современной практикой LLM, укрепляет документацию для инженеров и исследователей, а также расширяет надежность автотестирования библиотеки.
This commit is contained in:
Sergey Penkovsky
2025-10-15 10:59:56 +03:00
parent ec0d2bd8d0
commit d10044e4a7
8 changed files with 468 additions and 371 deletions

View File

@@ -1,121 +0,0 @@
import torch
from torch import nn
import torch.nn.functional as F
from math import sqrt
from .rope import RoPE
class HeadAttention(nn.Module):
"""
Одноголовый механизм внимания (scaled dot-product attention) — фундаментальный строительный блок всех современных Transformer.
Научная суть:
- Attention учит модель самостоятельно "выбирать" важные связи между словами, независимо от их положения.
- Механизм causal mask гарантирует невозможность "заглядывания в будущее" при генерации (авторегрессия).
Формула:
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) · V
(Q — запросы, K — ключи, V — значения; d_k — размерность ключа)
Поддерживает Rotary Position Encoding (RoPE) для относительного позиционного кодирования.
Args:
emb_size (int): размер входного эмбеддинга
head_size (int): размерность attention-головы
max_seq_len (int): максимальная длина последовательности
rope (RoPE, optional): экземпляр RoPE для позиций
Примечания:
- Использует нижнетреугольную маску для предотвращения "заглядывания в будущее"
- Автоматически адаптируется к разным версиям PyTorch
- Поддерживает batch-обработку входных данных
Пример использования:
>>> attention = HeadAttention(emb_size=64, head_size=32, max_seq_len=128)
>>> x = torch.randn(1, 10, 64)
>>> output, _ = attention(x)
>>> print(output.shape) # torch.Size([1, 10, 32])
"""
def __init__(
self, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE = None
):
super().__init__()
self._emb_size = emb_size
self._head_size = head_size
self._max_seq_len = max_seq_len
self._rope = rope
# Линейные преобразования для Q, K, V
self._k = nn.Linear(emb_size, head_size)
self._q = nn.Linear(emb_size, head_size)
self._v = nn.Linear(emb_size, head_size)
# Создание causal маски
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
self.register_buffer(
"_tril_mask", mask.bool() if hasattr(torch, "bool") else mask.byte()
)
def forward(
self, x: torch.Tensor, use_cache: bool = True, cache: tuple = None
) -> tuple:
"""
Прямой проход через слой внимания.
Аргументы:
x (torch.Tensor): Входной тензор формы [batch_size, seq_len, emb_size]
Возвращает:
torch.Tensor: Выходной тензор формы [batch_size, seq_len, head_size]
Исключения:
ValueError: Если длина последовательности превышает max_seq_len
Пример внутренних преобразований:
Для входа x.shape = [2, 5, 64]:
1. Q/K/V преобразования -> [2, 5, 32]
2. Scores = Q·K^T -> [2, 5, 5]
3. После маски и softmax -> [2, 5, 5]
4. Умножение на V -> [2, 5, 32]
"""
seq_len = x.shape[1]
if seq_len > self._max_seq_len:
raise ValueError(
f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}"
)
k = self._k(x) # [B, T, hs]
q = self._q(x) # [B, T, hs]
v = self._v(x) # [B, T, hs]
start_pos = 0
if cache is not None:
k_cache, v_cache = cache
cache_len = k_cache.shape[1]
start_pos = cache_len
if self._rope is not None:
# ✅ Применяем RoPE к Q и K (НЕ к V!)
q = self._rope(q, start_pos=start_pos) # [B, T, hs]
k = self._rope(k, start_pos=start_pos) # [B, T, hs]
if cache is not None:
k_cache, v_cache = cache
k = torch.cat([k_cache, k], dim=1) # [B, cache_len + T, hs]
v = torch.cat([v_cache, v], dim=1) # [B, cache_len + T, hs]
scores = q @ k.transpose(-2, -1) / sqrt(self._head_size)
if cache is None:
scores = scores.masked_fill(
~self._tril_mask[:seq_len, :seq_len], float("-inf")
)
weights = F.softmax(scores, dim=-1)
x_out = weights @ v # [B, T, hs]
if use_cache is True:
return (x_out, (k, v))
else:
return (x_out, None)

View File

@@ -1,37 +1,70 @@
from torch import nn
import torch
from .head_attention import HeadAttention
import torch.nn.functional as F
from .rope import RoPE
class MultiHeadAttention(nn.Module):
"""
Мультиголовый (многоголовый) механизм внимания — ключевой компонент любого Transformer.
Multi-Head Attention (Многоголовое внимание)
============================================
Научная суть:
- Модель параллельно агрегирует информацию через несколько подпространств (головы),
чтобы видеть разные связи в последовательности (разный контекст, локально/глобально).
- Каждый attention блок работает независимо, выход конкатенируется.
- Механизм предложен в статье "Attention is All You Need" (Vaswani et al., 2017).
Что такое Multi-Head Attention?
-------------------------------
Это ключевой компонент трансформеров, который позволяет "смотреть" на разные части предложения
одновременно с нескольких независимых ракурсов ("голов"). Всё, что делает Single-Head Attention — только гораздо мощнее и глубже!
Формула внимания для одной головы:
Attention(Q, K, V) = softmax(QK^T/sqrt(d_k))·V
Мультиголовый:
MultiHead(Q, K, V) = Concat([head_i])*W^O
Зачем это нужно?
----------------
- Модель может учиться одновременно учитывать и локальные, и глобальные взаимосвязи между токенами.
- Каждая attention head "ловит" свой собственный смысл/зависимости, и на выходе они объединяются.
- Это значительно улучшает понимание сложных зависимостей в тексте, особенно на длинных последовательностях.
Args:
num_heads (int): количество attention "голов"
emb_size (int): размерности входа и выхода
head_size (int): размер одной attention-головы (emb_size/num_heads)
max_seq_len (int): максимальная длина последовательности
rope (RoPE, optional): если задан, используется Rotary Positional Encoding
dropout (float): вероятность регуляризации
Как работает алгоритм? (основная схема)
---------------------------------------
1. Генерируются Q, K, V (query, key, value) — по отдельной проекции для каждой головы.
2. Для каждой головы: attention(Q, K, V) = softmax(Q·K^T / sqrt(d)) · V
3. Все головы "склеиваются" (concatenate) и прогоняются через общий финальный линейный слой.
Почему это работает?
--------------------
- Даёт трансформеру многомерное восприятие текста.
- Позволяет эффективно обучаться на задачах, где порядок и "дальние" связи важнее, чем простое соседство.
Что принимается на вход:
------------------------
- x: shape [batch, seq_len, embed_dim] — обычный batched-embed тензор.
- mask (опционально): shape [seq_len, seq_len] — маска для автогерерации или causal attention.
Какие параметры важны:
----------------------
- num_heads: сколько attention heads внутри (обычно 4, 8, 16...).
- embed_dim: исходная размерность входного тензора.
- head_size: размер одной attention-head (обычно embed_dim // num_heads).
- max_seq_len: максимальная длина последовательности для маски.
Что возвращает:
---------------
- output: shape [batch, seq_len, embed_dim] — результат применения всех attention heads.
- (опционально) cache: кэш для Q/K/V (нужно для генерации по одному токену).
Особенности реализации:
-----------------------
- Оптимизированно работает через матричные умножения (без python for циклов!).
- Включена поддержка causal attention (маска, предотвращающая «заглядывание в будущее»).
- Является ядром любого трансформера (и LLM!).
Пример использования:
>>> mha = MultiHeadAttention(num_heads=8, emb_size=512, head_size=64, max_seq_len=1024)
>>> x = torch.randn(2, 50, 512)
>>> out, cache = mha(x)
>>> print(out.shape)
---------------------
>>> attn = MultiHeadAttention(num_heads=8, embed_dim=256, head_size=32, max_seq_len=1024)
>>> x = torch.randn(2, 128, 256) # [batch, seq_len, embed_dim]
>>> context, _ = attn(x)
>>> print(context.shape) # torch.Size([2, 128, 256])
Где прочитать подробнее:
-------------------------
- Attention is All You Need (Vaswani et al, 2017): https://arxiv.org/abs/1706.03762
- Illustrated Transformer (blog): https://jalammar.github.io/illustrated-transformer/
"""
def __init__(
@@ -44,32 +77,59 @@ class MultiHeadAttention(nn.Module):
dropout: float = 0.1,
):
"""
Инициализация многоголового внимания.
Конструктор многоголового внимания (MultiHeadAttention).
Параметры:
num_heads (int): Количество голов внимания. Типичные значения: 4-16
emb_size (int): Размерность входных и выходных эмбеддингов
head_size (int): Размерность каждой головы внимания (обычно emb_size // num_heads)
max_seq_len (int): Максимальная длина последовательности
dropout (float): Вероятность dropout (по умолчанию 0.1)
Здесь создаются все параметры и внутренние слои для эффективного параллельного внимания (attention) сразу из нескольких "голов".
Контрольные значения:
- num_heads * head_size должно равняться emb_size
- head_size обычно выбирают 32-128
- max_seq_len зависит от задачи (512 для BERT, 2048 для GPT-3)
Аргументы:
----------
num_heads : int
Сколько attention-heads будет внутри слоя.
Каждая “голова” учится видеть уникальные зависимости в тексте. Обычно это 4, 8, 16 и т.п.
Чем больше голов — тем богаче контекст, но и больше памяти.
emb_size : int
Сколько float-значений в каждом входном векторе (размерность embedding).
Обычно это 256, 512, 768, 1024 и т.д.
head_size : int
Сколько компонент будет у каждой головы внимания.
Важно: num_heads * head_size должно ровно совпадать с emb_size!
Обычно head_size = emb_size // num_heads.
max_seq_len : int
Максимально допустимая длина последовательности для attention/маски/генерации.
Определяет размер буферов для causal mask.
rope : RoPE, по умолчанию None
Объект Rotary Positional Encoding (если хотите привнести продвинутое позиционное кодирование в attention).
Не обязателен, но нужен для современных LLM (Llama, Mistral и пр.).
dropout : float, по умолчанию 0.1
Величина dropout (регуляризации) — помогает борьбе с переобучением. Чем больше, тем сильнее регуляризация.
Внутри конструктора происходит:
-------------------------------
- Создаются три линейных слоя для Q, K, V (“где смотреть” и “что вытаскивать” в attention).
- Генерируется нижнетреугольная causal-маска (запрещает видеть будущее для автогерерации).
- Создаётся финальный линейный слой для склейки всех голов в одно пространство emb_size.
- Вводится dropout (случайное зануление, чтобы не было сильной зависимости внимания к отдельным "плейсам").
Пример:
-------
>>> attn = MultiHeadAttention(num_heads=8, emb_size=256, head_size=32, max_seq_len=1024)
"""
super().__init__()
self._heads = nn.ModuleList(
[
HeadAttention(
emb_size=emb_size,
head_size=head_size,
max_seq_len=max_seq_len,
rope=rope,
)
for _ in range(num_heads)
]
self._num_heads = num_heads
self._head_size = head_size
self._max_seq_len = max_seq_len
self._rope = rope
self._q = nn.Linear(emb_size, num_heads * head_size)
self._k = nn.Linear(emb_size, num_heads * head_size)
self._v = nn.Linear(emb_size, num_heads * head_size)
# Создание causal маски
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
self.register_buffer(
"_tril_mask", mask.bool() if hasattr(torch, "bool") else mask.byte()
)
self._layer = nn.Linear(head_size * num_heads, emb_size)
self._dropout = nn.Dropout(dropout)
@@ -81,61 +141,116 @@ class MultiHeadAttention(nn.Module):
cache: list = None,
):
"""
Прямой проход (forward):
Для каждого токена оценивает "важность" остальных токенов сразу через несколько attention-блоков.
Основной шаг \"многоголового внимания\": находит взаимосвязи между токенами
в последовательности сразу из нескольких “ракурсов” (attention heads).
Подробное описание преобразований тензоров:
1. Входной тензор [batch_size, seq_len, emb_size] разделяется на N голов:
- Каждая голова получает тензор [batch_size, seq_len, head_size]
2. Каждая голова вычисляет attention:
- Вход: [batch_size, seq_len, head_size]
- Выход: [batch_size, seq_len, head_size]
3. Конкатенация результатов:
- Объединенный выход: [batch_size, seq_len, num_heads * head_size]
4. Линейная проекция:
- Выход: [batch_size, seq_len, emb_size]
5. Применение dropout
Что делает этот метод:
----------------------
- Для каждого токена сравнивает его с остальными во входной последовательности.
- Делает это одновременно через несколько attention heads (каждая head видит текст по-своему).
- Итоговое “внимание” — это взвешенная сумма других токенов (контекста) для каждого токена.
- Можно использовать кэш для генерации длинных последовательностей по одному токену (ускоряет инференс).
Args:
x (Tensor[float]): [batch, seq_len, emb_size] — вход
mask (Optional[Tensor[bool]]): маска позиции [seq_len, seq_len]
use_cache (bool): использовать ли key-value кэш (для генерации)
cache (list): предыдущие значения KV для ускорения
Аргументы:
----------
x : torch.Tensor
Входной тензор формы [batch, seq_len, emb_size].
Это ваши входные эмбеддинги (обычно после token + positional embedding).
mask : torch.Tensor, опционально
Матрица формы [seq_len, seq_len], задающая “разрешения” — кто может смотреть на кого (например, causal mask).
Если не указана — используется внутренняя маска (например, для autoregressive генерации).
use_cache : bool, по умолчанию True
Нужно ли использовать кэш для KV attention (важно для ускорения генерации по одному токену).
cache : list, опционально
Предыдущий кэш Key/Value — для генерации текста по частям.
Returns:
out (Tensor[float]): [batch, seq_len, emb_size] — результат MHA
kv_caches (list): списки новых KV-кэшей (если используется)
Возвращает:
-----------
- output: torch.Tensor формы [batch, seq_len, emb_size] — результат применения multi-head attention.
- kv_caches: список новых KV для кэширования при генерации (или None).
Типичный паттерн:
Вход: [batch, seq, emb] → N голов [batch, seq, head_size] →
→ concat [batch, seq, N*head_size] → проекция → dropout
Пример преобразований для emb_size=512, num_heads=8:
Вход: [4, 100, 512]
-> Каждая голова: [4, 100, 64]
-> После внимания: 8 x [4, 100, 64]
-> Конкатенация: [4, 100, 512]
-> Проекция: [4, 100, 512]
-> Dropout: [4, 100, 512]
Важно:
-------
- Shape входа всегда [batch, seq_len, emb_size], выход тот же.
- При seq_len > max_seq_len выбросит ошибку (безопасно для контроля переполнения буферов).
- При использовании use_cache=True кешируется только последние токены (актуально для LLM).
Пример:
>>> out, caches = mha(x)
>>> out.shape # [batch, seq_len, emb_size]
-------
>>> attn = MultiHeadAttention(num_heads=8, emb_size=256, head_size=32, max_seq_len=1024)
>>> x = torch.randn(2, 100, 256)
>>> y, kv_cache = attn(x)
>>> print(y.shape) # torch.Size([2, 100, 256])
"""
# 1. Вычисляем attention для каждой головы
attention_results = []
for i, head in enumerate(self._heads):
head_cache = cache[i] if cache is not None else None
result = head(x, use_cache=use_cache, cache=head_cache)
attention_results.append(result)
batch_size, seq_len, emb_size = x.shape
outputs, caches = zip(*attention_results)
attention_outputs = list(outputs)
kv_caches = list(caches)
if seq_len > self._max_seq_len:
raise ValueError(
f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}"
)
# 2. Объединяем результаты всех голов
concatenated_attention = torch.cat(attention_outputs, dim=-1)
# Пропустите тензор x через матрицы Wq, Wk , Wv, чтобы получить матрицы запроса, ключа и значения.
k = self._k(x) # [B, T, hs]
q = self._q(x) # [B, T, hs]
v = self._v(x) # [B, T, hs]
# Шаг 2: Изменение формы для multi-head
# [batch_size, seq_len, num_heads * head_size]
# -> [batch_size, seq_len, num_heads, head_size]
q = q.reshape(batch_size, seq_len, self._num_heads, self._head_size)
k = k.reshape(batch_size, seq_len, self._num_heads, self._head_size)
v = v.reshape(batch_size, seq_len, self._num_heads, self._head_size)
# 3. Transpose: [B, T, H, hs] -> [B, H, T, hs]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
start_pos = 0
if cache is not None:
k_cache, v_cache = cache
cache_len = k_cache.shape[2]
start_pos = cache_len
# Пропустите матрицы запроса и ключа через экземпляр rope, чтобы выполнить поворот.
if self._rope is not None:
# ✅ Применяем RoPE к Q и K (НЕ к V!)
q = self._rope(q, start_pos=start_pos) # [B, T, hs]
k = self._rope(k, start_pos=start_pos) # [B, T, hs]
# Если cache пришел, то объединяем кэш и одну строку из ключа и значения. Это будут новые key и value для последующих вычислений.
# 5. Кэширование (для autoregressive generation)
if cache is not None:
k_cache, v_cache = cache
k = torch.cat([k_cache, k], dim=2) # Concat по seq_len (dim=2)
v = torch.cat([v_cache, v], dim=2)
# Перемножим матрицы запроса и ключа (транспонированную), чтобы вычислить матрицу внимания.
# И разделить все значения в матрице внимания на корень из head_size.
scores = q @ k.transpose(-2, -1) / (self._head_size ** 0.5)
# Если cache пришел, то маску не накладываем. Иначе наложите на матрицу внимания треугольную маску, созданную при инициализации. Все скрытые значения должны быть приведены к минус бесконечности: float('-inf').
if cache is None:
scores = scores.masked_fill(
~self._tril_mask[:seq_len, :seq_len], float("-inf")
)
# Применить к матрице внимания (построчно) функцию Softmax.
weights = F.softmax(scores, dim=-1)
# Перемножим матрицу внимания и матрицу значения.
x_out = weights @ v # [B, T, hs]
# Измените форму тензора на batch_size × seq_len × num_heads*head_size.
# Transpose обратно и concatenate heads
x_out = x_out.transpose(1, 2) # [B, T_q, H, hs]
x_out = x_out.contiguous() # Важно для reshape!
concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size)
#concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size)
# Пропустите получившийся тензор через последний линейный слой.
# 3. Проецируем в пространство эмбеддингов
projected_output = self._layer(concatenated_attention)
@@ -143,6 +258,6 @@ class MultiHeadAttention(nn.Module):
final_output = self._dropout(projected_output)
if use_cache is True:
return (final_output, kv_caches)
return (final_output, (k, v))
else:
return (final_output, None)

View File

@@ -1,21 +1,51 @@
"""
Rotary Positional Embeddings (RoPE) - ротационные позиционные эмбеддинги.
Rotary Positional Embeddings (RoPE)
===================================
Реализация ротационного позиционного кодирования, которое кодирует позиционную
информацию через вращение векторов запросов и ключей в комплексном пространстве.
Что такое RoPE?
----------------
RoPE — это способ "вписать" информацию о позиции токенов в скрытые вектора модели трансформера.
Вместо простого сложения с абсолютным positional embedding, RoPE использует вращения векторов (как поворот стрелки на круге) внутри каждого attention head. Каждый элемент пары (вектор четного и нечетного индекса) поворачивается на угол, зависящий от позиции токена.
Научная статья: "RoFormer: Enhanced Transformer with Rotary Position Embedding"
https://arxiv.org/abs/2104.09864
Зачем это?
-----------
- RoPE реализует **относительное позиционное кодирование**: модель может сравнивать "расстояния" между токенами, а не просто помнить положение.
- Такое кодирование **улучшает генерацию длинных последовательностей** и перенос модели на тексты большей длины, чем были в обучении.
- Форма векторов и длина (норма) НЕ искажаются.
Математическая основа:
Для позиции m и измерения i:
θ_i = base^(-2i/d)
q'_m = q_m * cos(mθ_i) + rotate(q_m) * sin(mθ_i)
Как это работает? (главная формула)
-------------------------------------
Для каждой позиции m и пары компонент (2i, 2i+1) внутри head применяются:
θ_i = base^(-2i / d)
q'_{m,2i} = q_{m,2i} * cos(m * θ_i) - q_{m,2i+1} * sin(m * θ_i)
q'_{m,2i+1} = q_{m,2i+1} * cos(m * θ_i) + q_{m,2i} * sin(m * θ_i)
где d — размерность "головы" attention (head_size), base обычно 10_000.
То есть, берём каждый "вектор" (в рамках head), делим на четные/нечетные части и поворачиваем их на уникальный угол, связанный с позицией/частотой.
Архитектурные детали:
---------------------
- Ваш тензор должен быть строго 4-мерным: [batch, num_heads, seq_len, head_size].
- Размер head_size должен быть чётным!
- RoPE применяется отдельно к **Q** и **K** в механизме внимания (но не к V).
Где об этом читать:
-------------------
- RoFormer: Enhanced Transformer with Rotary Position Embedding
https://arxiv.org/abs/2104.09864
- Llama: Open and Efficient Foundation Language Models
https://arxiv.org/abs/2302.13971
- Визуализация позиционных кодировок:
https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
Пример использования:
---------------------
>>> rope = RoPE(head_size=64, max_seq_len=2048)
>>> x = torch.randn(2, 8, 128, 64) # [batch, num_heads, seq_len, head_size]
>>> x_enc = rope(x) # здесь вектор x обогатится позиционной информацией
Преимущества:
- Относительное позиционное кодирование
- Лучшая экстраполяция на длинные последовательности
- Сохранение нормы векторов
"""
import torch
@@ -25,32 +55,72 @@ from typing import Optional
class RoPE(nn.Module):
"""
Rotary Positional Embeddings (RoPE) для механизма внимания.
Реализация RoPE (Rotary Positional Embeddings) для self-attention в трансформерах.
Кодирует позиционную информацию через вращение векторов запросов и ключей
в многомерном пространстве с использованием синусов и косинусов.
Этот слой добавляет позиционную информацию к векторам внимания (Q, K) —
не с помощью простого сложения с positional embedding, а с помощью математического
вращения (как если бы вы крутили стрелку на круге) для каждой пары компонент
(even/odd) в каждом attention head.
Args:
head_size: Размерность головы внимания (должен быть четным)
max_seq_len: Максимальная длина последовательности
base: Базовое значение для вычисления частот (по умолчанию 10000)
Формула (для каждого токена и каждой пары компонент внутри head):
θ_i = base^(-2i / d)
out_{m,2i} = x_{m,2i} * cos(m * θ_i) - x_{m,2i+1} * sin(m * θ_i)
out_{m,2i+1} = x_{m,2i+1} * cos(m * θ_i) + x_{m,2i} * sin(m * θ_i)
где d — head_size, base обычно 10_000, степень i по head axis.
Attributes:
cos_matrix: Буферизованная матрица косинусов формы [max_seq_len, head_size//2]
sin_matrix: Буферизованная матрица синусов формы [max_seq_len, head_size//2]
Какие входы принимает:
----------------------
- x: обязательно размерности [batch, num_heads, seq_len, head_size]!
- head_size (размер внимания) должен быть чётным.
- start_pos: опционально, позволяет сдвигать позиционный offset для генерации с кэшем.
Что возвращает:
---------------
- Тот же тензор (x), только со встроенной позиционной информацией (“повёрнутый” RoPE-кодировкой).
- Форма и тип выходного тензора не меняются.
Где используется:
-----------------
- В любых современных LLM (Llama, Mistral, GPT-NeoX и др.) для повышения устойчивости и generalization transformer's attention.
Пример использования:
---------------------
>>> rope = RoPE(head_size=64, max_seq_len=2048)
>>> x = torch.randn(2, 8, 128, 64) # (batch, num_heads, seq_len, head_size)
>>> x_encoded = rope(x)
Подробнее про математику и примеры с визуализацией:
---------------------------------------------------
- RoFormer: https://arxiv.org/abs/2104.09864
- Llama: https://arxiv.org/abs/2302.13971
- Демонстрация наглядно: https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
"""
def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):
"""
Инициализация RoPE эмбеддингов.
Инициализация объекта RoPE — настраивает и предвычисляет все необходимые
параметры для ротационного позиционного кодирования.
Args:
head_size: Размерность головы внимания (должен быть четным)
max_seq_len: Максимальная поддерживаемая длина последовательности
base: Базовое значение для вычисления частот (типично 10000)
Аргументы:
----------
head_size : int
Размер одного attention head (последнего измерения вектора) — сколько компонент
(float-значений) отвечает за одну "голову". Должен быть ЧЁТНЫМ числом, иначе RoPE не применим.
Обычно head_size = embed_dim // num_heads.
max_seq_len : int
Максимальная длина последовательности, которую RoPE сможет обработать.
Если ваш текст длиннее этого числа — будет ошибка! Например, для GPT2 обычно 1024, у LLaMA — до 4096.
Это число определяет размер внутренних буферов cos/sin.
base : int, по умолчанию 10_000
База для вычисления частот вращения (θ_i) для каждой компоненты.
В оригинальных статьях почти всегда используют base=10000.
Менять этот параметр не нужно, если вы не исследуете математические детали.
Raises:
AssertionError: Если head_size не четный
Что происходит внутри:
----------------------
- Проверяется чётность head_size.
- Для каждого возможного положения в пределах max_seq_len и каждой пары component высчитываются уникальные cos/sin значения (матрицы частот).
- Эти матрицы используются далее для быстрого наложения позиционного "вращения" токенов внутри attention.
"""
super().__init__()
assert head_size % 2 == 0, "head_size должен быть четным"
@@ -70,28 +140,51 @@ class RoPE(nn.Module):
def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
"""
Применение ротационного позиционного кодирования к входному тензору.
Применяет ротационное позиционное кодирование (RoPE) к входному тензору.
Args:
x: Входной тензор формы [batch_size, seq_len, head_size]
Что делает эта функция:
-----------------------
Для каждого токена в последовательности внутри каждого attention head
"поворачивает" его вектор в подпространстве (even/odd пар) на свой уникальный угол,
зависящий от позиции токена. Это позволяет attention "понимать расстояния" между токенами.
Returns:
Тензор с примененным RoPE формы [batch_size, seq_len, head_size]
Аргументы:
----------
x : torch.Tensor
Входной тензор строго формы [batch, num_heads, seq_len, head_size].
Это обычно либо Q, либо K из механизма внимания.
start_pos : int, по умолчанию 0
Сдвиг начала позиции (нужно при генерации с кэшем, почти всегда оставить 0 если не пишете автогенератор).
Алгоритм:
1. Разделение векторов на четные и нечетные компоненты
2. Применение вращения через синусы и косинусы
3. Объединение компонент обратно
Возвращает:
-----------
torch.Tensor с теми же формой и типом, что и x, но уже с наложенным позиционным кодированием.
Важно:
-------
- Если передан тензор не 4D, будет выброшено исключение!
- Не изменяет значения "на месте", всегда возвращает новый тензор.
Пример:
-------
>>> rope = RoPE(head_size=64, max_seq_len=1024)
>>> q = torch.randn(2, 8, 32, 64) # batch, num_heads, seq_len, head_size
>>> q_rope = rope(q)
"""
batch_size, seq_len, emb_size = x.shape
assert x.ndim == 4, "RoPE поддерживает только 4D-вход [batch, num_heads, seq_len, head_size]"
batch_size, num_heads, seq_len, head_size = x.shape
# Берем нужную часть матриц и приводим к типу x
cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
# Разделяем на четные и нечетные компоненты
x_even = x[:, :, 0::2] # [batch_size, seq_len, head_size//2]
x_odd = x[:, :, 1::2] # [batch_size, seq_len, head_size//2]
# Явное изменение формы для broadcasting
cos = cos.reshape(1, 1, seq_len, head_size // 2)
sin = sin.reshape(1, 1, seq_len, head_size // 2)
# Разделяем на четные и нечетные компоненты по ПОСЛЕДНЕМУ измерению
x_even = x[..., 0::2] # [batch_size, num_heads, seq_len, head_size//2]
x_odd = x[..., 1::2] # [batch_size, num_heads, seq_len, head_size//2]
# Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ)
x_rotated_even = x_even * cos - x_odd * sin

View File

@@ -104,12 +104,20 @@ class GPT2(BaseModel):
# Вычисление start_pos из кэша (если кэш передан)
if cache is not None:
# При кэше обрабатываем только один токен (последний)
seq_len = 1
# Вычисляем start_pos из самого нижнего уровня кэша
if cache and cache[0] and cache[0][0]:
key_cache, _ = cache[0][0] # Первый декодер, первая голова
start_pos = key_cache.size(1) # cache_len
# Безопасно извлекаем key_cache для вычисления start_pos
if (
isinstance(cache, (list, tuple))
and len(cache) > 0
and cache[0] is not None
and isinstance(cache[0], (list, tuple))
and len(cache[0]) > 0
and cache[0][0] is not None
and isinstance(cache[0][0], (tuple, list))
and len(cache[0][0]) > 0
):
key_cache, _ = cache[0][0]
start_pos = key_cache.size(1)
else:
start_pos = 0
else:

View File

@@ -4,71 +4,12 @@ from torch import Tensor
import torch.nn.functional as F
from math import sqrt
from llm.core.base_model import BaseModel
class SiLU(nn.Module):
def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]
return torch.sigmoid(x) * x
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self._eps = eps
self._w = nn.Parameter(torch.ones(dim))
def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]
rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5
norm_x = x / rms
return self._w * norm_x
class SwiGLU(nn.Module):
def __init__(self, emb_size: int, dropout: float = 0.1):
super().__init__()
self._gate = nn.Linear(emb_size, 4 * emb_size)
self._up = nn.Linear(emb_size, 4 * emb_size)
self._down = nn.Linear(4 * emb_size, emb_size)
self._activation = SiLU()
self._dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size].
gate_out = self._gate(x) # [batch, seq, 4*emb]
activation_out = self._activation(gate_out) # [batch, seq, 4*emb]
up_out = self._up(x) # [batch, seq, 4*emb]
out = up_out * activation_out # поэлементное!
out = self._down(out) # [batch, seq, emb]
return self._dropout(out)
class TokenEmbeddings(nn.Module):
def __init__(self, vocab_size: int, emb_size: int):
super().__init__()
self._embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=emb_size
)
def forward(self, x: Tensor) -> Tensor:
return self._embedding(x)
@property
def num_embeddings(self) -> int:
return self._embedding.num_embeddings
@property
def embedding_dim(self) -> int:
return self._embedding.embedding_dim
class GELU(nn.Module):
def __init__(self):
super().__init__()
self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return 0.5 * x * (1 + torch.tanh(
self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))
))
from llm.core.token_embeddings import TokenEmbeddings
from llm.core.silu import SiLU
from llm.core.rms_norm import RMSNorm
from llm.core.swi_glu import SwiGLU
from llm.core.gelu import GELU
from llm.core.rope import RoPE
@@ -77,49 +18,49 @@ from torch import nn
from typing import Optional
class RoPE(nn.Module):
def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):
super().__init__()
assert head_size % 2 == 0, "head_size должен быть четным"
# Вычисление частот: θ_i = base^(-2i/d) для i ∈ [0, d/2-1]
freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size))
# Позиции от 0 до max_seq_len-1
positions = torch.arange(max_seq_len).float()
# Внешнее произведение: m * θ_i для всех позиций и частот
freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0)
# Предвычисление матриц косинусов и синусов
self.register_buffer("cos_matrix", torch.cos(freq_matrix))
self.register_buffer("sin_matrix", torch.sin(freq_matrix))
def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor: # [batch_size × seq_len × head_size] [batch_size × num_heads × seq_len × head_size]
batch_size, num_heads, seq_len, head_size = x.shape
# Берем нужную часть матриц и приводим к типу x
cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
# Явное изменение формы для broadcasting
cos = cos.reshape(1, 1, seq_len, head_size // 2)
sin = sin.reshape(1, 1, seq_len, head_size // 2)
# Разделяем на четные и нечетные компоненты по ПОСЛЕДНЕМУ измерению
x_even = x[..., 0::2] # [batch_size, num_heads, seq_len, head_size//2]
x_odd = x[..., 1::2] # [batch_size, num_heads, seq_len, head_size//2]
# Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ)
x_rotated_even = x_even * cos - x_odd * sin
x_rotated_odd = x_even * sin + x_odd * cos
# Объединяем обратно в исходную размерность
x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)
x_rotated = x_rotated.flatten(-2) # [batch_size, seq_len, head_size]
return x_rotated
#class RoPE(nn.Module):
#
# def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):
# super().__init__()
# assert head_size % 2 == 0, "head_size должен быть четным"
#
# # Вычисление частот: θ_i = base^(-2i/d) для i ∈ [0, d/2-1]
# freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size))
#
# # Позиции от 0 до max_seq_len-1
# positions = torch.arange(max_seq_len).float()
#
# # Внешнее произведение: m * θ_i для всех позиций и частот
# freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0)
#
# # Предвычисление матриц косинусов и синусов
# self.register_buffer("cos_matrix", torch.cos(freq_matrix))
# self.register_buffer("sin_matrix", torch.sin(freq_matrix))
#
# def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor: # [batch_size × seq_len × head_size] [batch_size × num_heads × seq_len × head_size]
# batch_size, num_heads, seq_len, head_size = x.shape
#
# # Берем нужную часть матриц и приводим к типу x
# cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
# sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
#
# # Явное изменение формы для broadcasting
# cos = cos.reshape(1, 1, seq_len, head_size // 2)
# sin = sin.reshape(1, 1, seq_len, head_size // 2)
#
# # Разделяем на четные и нечетные компоненты по ПОСЛЕДНЕМУ измерению
# x_even = x[..., 0::2] # [batch_size, num_heads, seq_len, head_size//2]
# x_odd = x[..., 1::2] # [batch_size, num_heads, seq_len, head_size//2]
#
# # Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ)
# x_rotated_even = x_even * cos - x_odd * sin
# x_rotated_odd = x_even * sin + x_odd * cos
#
# # Объединяем обратно в исходную размерность
# x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)
# x_rotated = x_rotated.flatten(-2) # [batch_size, seq_len, head_size]
#
# return x_rotated
import torch

View File

@@ -19,7 +19,7 @@ class TestMultiHeadAttention:
assert attention is not None
# Check internal attributes
assert len(attention._heads) == num_heads
assert attention._num_heads == num_heads
assert attention._layer.in_features == embed_dim
assert attention._layer.out_features == embed_dim
@@ -102,8 +102,10 @@ class TestMultiHeadAttention:
# Check that gradients are computed for learnable parameters
assert attention._layer.weight.grad is not None
if len(attention._heads) > 0:
assert attention._heads[0]._q.weight.grad is not None
# Проверяем, что также у градиентов весов q/k/v есть значения
assert attention._q.weight.grad is not None
assert attention._k.weight.grad is not None
assert attention._v.weight.grad is not None
def test_device_consistency(self, embed_dim, num_heads, random_embeddings, device):
"""Test that MultiHeadAttention works on correct device."""

View File

@@ -0,0 +1,55 @@
import torch
import pytest
from llm.core.rope import RoPE
def test_rope_shapes_and_dtype():
rope = RoPE(head_size=8, max_seq_len=32)
x = torch.randn(2, 4, 16, 8) # [batch, num_heads, seq_len, head_size]
y = rope(x)
assert y.shape == x.shape
assert y.dtype == x.dtype
def test_rope_raises_on_bad_ndim():
rope = RoPE(head_size=8, max_seq_len=16)
x = torch.randn(2, 16, 8) # [batch, seq_len, head_size] (3D)
with pytest.raises(AssertionError):
_ = rope(x)
def test_rope_preserves_norm():
rope = RoPE(head_size=8, max_seq_len=16)
x = torch.randn(2, 3, 7, 8)
x_norm = x.norm(dim=-1)
y = rope(x)
y_norm = y.norm(dim=-1)
# Нормы могут немного отличаться из-за float, сравниваем с допуском
assert torch.allclose(x_norm, y_norm, rtol=1e-5, atol=1e-7)
def test_rope_backward_pass():
rope = RoPE(head_size=8, max_seq_len=16)
x = torch.randn(2, 2, 8, 8, requires_grad=True)
out = rope(x)
loss = out.sum()
loss.backward()
assert x.grad is not None
assert x.grad.shape == x.shape
@pytest.mark.parametrize("batch,num_heads,seq_len,head_size", [
(1, 1, 4, 8),
(2, 4, 16, 8),
(3, 2, 7, 8),
])
def test_rope_various_shapes(batch, num_heads, seq_len, head_size):
rope = RoPE(head_size=head_size, max_seq_len=32)
x = torch.randn(batch, num_heads, seq_len, head_size)
y = rope(x)
assert y.shape == x.shape
def test_rope_start_pos():
rope = RoPE(head_size=8, max_seq_len=32)
x_full = torch.randn(1, 2, 8, 8)
# Сравниваем участок результата для разных start_pos
out1 = rope(x_full)
out2 = rope(x_full, start_pos=2)
assert not torch.allclose(out1, out2)
# Для одинакового start_pos и x должны совпадать
assert torch.allclose(rope(x_full, start_pos=1), rope(x_full, start_pos=1))

View File

@@ -145,7 +145,11 @@ class TestGPT:
assert model._token_embeddings._embedding.weight.grad is not None
assert model._linear.weight.grad is not None
if len(model._decoders) > 0:
assert model._decoders[0]._heads._heads[0]._q.weight.grad is not None
# Проверяем через новый интерфейс attention оптимизации:
attn = model._decoders[0]._heads
assert attn._q.weight.grad is not None
assert attn._k.weight.grad is not None
assert attn._v.weight.grad is not None
def test_device_consistency(self, gpt_config, random_inputs, device):
"""Test that GPT works on correct device."""