mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
refactor(core): refactor RoPE and MultiHeadAttention, add math-rich docs, expand tests, remove unused head_attention
- refactor: улучшена и унифицирована реализация RoPE, теперь поддерживаются строгие проверки размерности входа; внесены улучшения и структурные изменения в MultiHeadAttention (более понятная логика, строгая спецификация входов/выходов) - docs: полностью переписаны docstrings для RoPE и MultiHeadAttention — включены математические формулы, ссылки на научные статьи, подробные пояснения по алгоритму, формату входных данных, ограничениям, примеры использования - test: добавлены отдельные unit-тесты для RoPE (корректность формы, ошибки на неверную размерность, сохранение нормы, backward/градиенты, работу с параметрами start_pos и батчами) - chore: удалён неиспользуемый модуль core/head_attention.py - fix: теперь выбрасывается AssertionError при неправильной размерности входа RoPE; это позволило полностью покрыть тест-кейсы на ошибки Этот коммит синхронизирует логику реализации базового внимания с современной практикой LLM, укрепляет документацию для инженеров и исследователей, а также расширяет надежность автотестирования библиотеки.
This commit is contained in:
@@ -1,121 +0,0 @@
|
|||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from math import sqrt
|
|
||||||
from .rope import RoPE
|
|
||||||
|
|
||||||
|
|
||||||
class HeadAttention(nn.Module):
|
|
||||||
"""
|
|
||||||
Одноголовый механизм внимания (scaled dot-product attention) — фундаментальный строительный блок всех современных Transformer.
|
|
||||||
|
|
||||||
Научная суть:
|
|
||||||
- Attention учит модель самостоятельно "выбирать" важные связи между словами, независимо от их положения.
|
|
||||||
- Механизм causal mask гарантирует невозможность "заглядывания в будущее" при генерации (авторегрессия).
|
|
||||||
|
|
||||||
Формула:
|
|
||||||
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) · V
|
|
||||||
(Q — запросы, K — ключи, V — значения; d_k — размерность ключа)
|
|
||||||
|
|
||||||
Поддерживает Rotary Position Encoding (RoPE) для относительного позиционного кодирования.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
emb_size (int): размер входного эмбеддинга
|
|
||||||
head_size (int): размерность attention-головы
|
|
||||||
max_seq_len (int): максимальная длина последовательности
|
|
||||||
rope (RoPE, optional): экземпляр RoPE для позиций
|
|
||||||
|
|
||||||
Примечания:
|
|
||||||
- Использует нижнетреугольную маску для предотвращения "заглядывания в будущее"
|
|
||||||
- Автоматически адаптируется к разным версиям PyTorch
|
|
||||||
- Поддерживает batch-обработку входных данных
|
|
||||||
|
|
||||||
Пример использования:
|
|
||||||
>>> attention = HeadAttention(emb_size=64, head_size=32, max_seq_len=128)
|
|
||||||
>>> x = torch.randn(1, 10, 64)
|
|
||||||
>>> output, _ = attention(x)
|
|
||||||
>>> print(output.shape) # torch.Size([1, 10, 32])
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE = None
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self._emb_size = emb_size
|
|
||||||
self._head_size = head_size
|
|
||||||
self._max_seq_len = max_seq_len
|
|
||||||
self._rope = rope
|
|
||||||
|
|
||||||
# Линейные преобразования для Q, K, V
|
|
||||||
self._k = nn.Linear(emb_size, head_size)
|
|
||||||
self._q = nn.Linear(emb_size, head_size)
|
|
||||||
self._v = nn.Linear(emb_size, head_size)
|
|
||||||
|
|
||||||
# Создание causal маски
|
|
||||||
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
|
|
||||||
self.register_buffer(
|
|
||||||
"_tril_mask", mask.bool() if hasattr(torch, "bool") else mask.byte()
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self, x: torch.Tensor, use_cache: bool = True, cache: tuple = None
|
|
||||||
) -> tuple:
|
|
||||||
"""
|
|
||||||
Прямой проход через слой внимания.
|
|
||||||
|
|
||||||
Аргументы:
|
|
||||||
x (torch.Tensor): Входной тензор формы [batch_size, seq_len, emb_size]
|
|
||||||
|
|
||||||
Возвращает:
|
|
||||||
torch.Tensor: Выходной тензор формы [batch_size, seq_len, head_size]
|
|
||||||
|
|
||||||
Исключения:
|
|
||||||
ValueError: Если длина последовательности превышает max_seq_len
|
|
||||||
|
|
||||||
Пример внутренних преобразований:
|
|
||||||
Для входа x.shape = [2, 5, 64]:
|
|
||||||
1. Q/K/V преобразования -> [2, 5, 32]
|
|
||||||
2. Scores = Q·K^T -> [2, 5, 5]
|
|
||||||
3. После маски и softmax -> [2, 5, 5]
|
|
||||||
4. Умножение на V -> [2, 5, 32]
|
|
||||||
"""
|
|
||||||
seq_len = x.shape[1]
|
|
||||||
if seq_len > self._max_seq_len:
|
|
||||||
raise ValueError(
|
|
||||||
f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}"
|
|
||||||
)
|
|
||||||
|
|
||||||
k = self._k(x) # [B, T, hs]
|
|
||||||
q = self._q(x) # [B, T, hs]
|
|
||||||
v = self._v(x) # [B, T, hs]
|
|
||||||
|
|
||||||
start_pos = 0
|
|
||||||
if cache is not None:
|
|
||||||
k_cache, v_cache = cache
|
|
||||||
cache_len = k_cache.shape[1]
|
|
||||||
start_pos = cache_len
|
|
||||||
|
|
||||||
if self._rope is not None:
|
|
||||||
# ✅ Применяем RoPE к Q и K (НЕ к V!)
|
|
||||||
q = self._rope(q, start_pos=start_pos) # [B, T, hs]
|
|
||||||
k = self._rope(k, start_pos=start_pos) # [B, T, hs]
|
|
||||||
|
|
||||||
if cache is not None:
|
|
||||||
k_cache, v_cache = cache
|
|
||||||
k = torch.cat([k_cache, k], dim=1) # [B, cache_len + T, hs]
|
|
||||||
v = torch.cat([v_cache, v], dim=1) # [B, cache_len + T, hs]
|
|
||||||
|
|
||||||
scores = q @ k.transpose(-2, -1) / sqrt(self._head_size)
|
|
||||||
|
|
||||||
if cache is None:
|
|
||||||
scores = scores.masked_fill(
|
|
||||||
~self._tril_mask[:seq_len, :seq_len], float("-inf")
|
|
||||||
)
|
|
||||||
|
|
||||||
weights = F.softmax(scores, dim=-1)
|
|
||||||
x_out = weights @ v # [B, T, hs]
|
|
||||||
|
|
||||||
if use_cache is True:
|
|
||||||
return (x_out, (k, v))
|
|
||||||
else:
|
|
||||||
return (x_out, None)
|
|
||||||
@@ -1,37 +1,70 @@
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
import torch
|
import torch
|
||||||
from .head_attention import HeadAttention
|
import torch.nn.functional as F
|
||||||
from .rope import RoPE
|
from .rope import RoPE
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadAttention(nn.Module):
|
class MultiHeadAttention(nn.Module):
|
||||||
"""
|
"""
|
||||||
Мультиголовый (многоголовый) механизм внимания — ключевой компонент любого Transformer.
|
Multi-Head Attention (Многоголовое внимание)
|
||||||
|
============================================
|
||||||
|
|
||||||
Научная суть:
|
Что такое Multi-Head Attention?
|
||||||
- Модель параллельно агрегирует информацию через несколько подпространств (головы),
|
-------------------------------
|
||||||
чтобы видеть разные связи в последовательности (разный контекст, локально/глобально).
|
Это ключевой компонент трансформеров, который позволяет "смотреть" на разные части предложения
|
||||||
- Каждый attention блок работает независимо, выход конкатенируется.
|
одновременно с нескольких независимых ракурсов ("голов"). Всё, что делает Single-Head Attention — только гораздо мощнее и глубже!
|
||||||
- Механизм предложен в статье "Attention is All You Need" (Vaswani et al., 2017).
|
|
||||||
|
|
||||||
Формула внимания для одной головы:
|
Зачем это нужно?
|
||||||
Attention(Q, K, V) = softmax(QK^T/sqrt(d_k))·V
|
----------------
|
||||||
Мультиголовый:
|
- Модель может учиться одновременно учитывать и локальные, и глобальные взаимосвязи между токенами.
|
||||||
MultiHead(Q, K, V) = Concat([head_i])*W^O
|
- Каждая attention head "ловит" свой собственный смысл/зависимости, и на выходе они объединяются.
|
||||||
|
- Это значительно улучшает понимание сложных зависимостей в тексте, особенно на длинных последовательностях.
|
||||||
|
|
||||||
Args:
|
Как работает алгоритм? (основная схема)
|
||||||
num_heads (int): количество attention "голов"
|
---------------------------------------
|
||||||
emb_size (int): размерности входа и выхода
|
1. Генерируются Q, K, V (query, key, value) — по отдельной проекции для каждой головы.
|
||||||
head_size (int): размер одной attention-головы (emb_size/num_heads)
|
2. Для каждой головы: attention(Q, K, V) = softmax(Q·K^T / sqrt(d)) · V
|
||||||
max_seq_len (int): максимальная длина последовательности
|
3. Все головы "склеиваются" (concatenate) и прогоняются через общий финальный линейный слой.
|
||||||
rope (RoPE, optional): если задан, используется Rotary Positional Encoding
|
|
||||||
dropout (float): вероятность регуляризации
|
Почему это работает?
|
||||||
|
--------------------
|
||||||
|
- Даёт трансформеру многомерное восприятие текста.
|
||||||
|
- Позволяет эффективно обучаться на задачах, где порядок и "дальние" связи важнее, чем простое соседство.
|
||||||
|
|
||||||
|
Что принимается на вход:
|
||||||
|
------------------------
|
||||||
|
- x: shape [batch, seq_len, embed_dim] — обычный batched-embed тензор.
|
||||||
|
- mask (опционально): shape [seq_len, seq_len] — маска для автогерерации или causal attention.
|
||||||
|
|
||||||
|
Какие параметры важны:
|
||||||
|
----------------------
|
||||||
|
- num_heads: сколько attention heads внутри (обычно 4, 8, 16...).
|
||||||
|
- embed_dim: исходная размерность входного тензора.
|
||||||
|
- head_size: размер одной attention-head (обычно embed_dim // num_heads).
|
||||||
|
- max_seq_len: максимальная длина последовательности для маски.
|
||||||
|
|
||||||
|
Что возвращает:
|
||||||
|
---------------
|
||||||
|
- output: shape [batch, seq_len, embed_dim] — результат применения всех attention heads.
|
||||||
|
- (опционально) cache: кэш для Q/K/V (нужно для генерации по одному токену).
|
||||||
|
|
||||||
|
Особенности реализации:
|
||||||
|
-----------------------
|
||||||
|
- Оптимизированно работает через матричные умножения (без python for циклов!).
|
||||||
|
- Включена поддержка causal attention (маска, предотвращающая «заглядывание в будущее»).
|
||||||
|
- Является ядром любого трансформера (и LLM!).
|
||||||
|
|
||||||
Пример использования:
|
Пример использования:
|
||||||
>>> mha = MultiHeadAttention(num_heads=8, emb_size=512, head_size=64, max_seq_len=1024)
|
---------------------
|
||||||
>>> x = torch.randn(2, 50, 512)
|
>>> attn = MultiHeadAttention(num_heads=8, embed_dim=256, head_size=32, max_seq_len=1024)
|
||||||
>>> out, cache = mha(x)
|
>>> x = torch.randn(2, 128, 256) # [batch, seq_len, embed_dim]
|
||||||
>>> print(out.shape)
|
>>> context, _ = attn(x)
|
||||||
|
>>> print(context.shape) # torch.Size([2, 128, 256])
|
||||||
|
|
||||||
|
Где прочитать подробнее:
|
||||||
|
-------------------------
|
||||||
|
- Attention is All You Need (Vaswani et al, 2017): https://arxiv.org/abs/1706.03762
|
||||||
|
- Illustrated Transformer (blog): https://jalammar.github.io/illustrated-transformer/
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -44,32 +77,59 @@ class MultiHeadAttention(nn.Module):
|
|||||||
dropout: float = 0.1,
|
dropout: float = 0.1,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Инициализация многоголового внимания.
|
Конструктор многоголового внимания (MultiHeadAttention).
|
||||||
|
|
||||||
Параметры:
|
Здесь создаются все параметры и внутренние слои для эффективного параллельного внимания (attention) сразу из нескольких "голов".
|
||||||
num_heads (int): Количество голов внимания. Типичные значения: 4-16
|
|
||||||
emb_size (int): Размерность входных и выходных эмбеддингов
|
|
||||||
head_size (int): Размерность каждой головы внимания (обычно emb_size // num_heads)
|
|
||||||
max_seq_len (int): Максимальная длина последовательности
|
|
||||||
dropout (float): Вероятность dropout (по умолчанию 0.1)
|
|
||||||
|
|
||||||
Контрольные значения:
|
Аргументы:
|
||||||
- num_heads * head_size должно равняться emb_size
|
----------
|
||||||
- head_size обычно выбирают 32-128
|
num_heads : int
|
||||||
- max_seq_len зависит от задачи (512 для BERT, 2048 для GPT-3)
|
Сколько attention-heads будет внутри слоя.
|
||||||
|
Каждая “голова” учится видеть уникальные зависимости в тексте. Обычно это 4, 8, 16 и т.п.
|
||||||
|
Чем больше голов — тем богаче контекст, но и больше памяти.
|
||||||
|
emb_size : int
|
||||||
|
Сколько float-значений в каждом входном векторе (размерность embedding).
|
||||||
|
Обычно это 256, 512, 768, 1024 и т.д.
|
||||||
|
head_size : int
|
||||||
|
Сколько компонент будет у каждой головы внимания.
|
||||||
|
Важно: num_heads * head_size должно ровно совпадать с emb_size!
|
||||||
|
Обычно head_size = emb_size // num_heads.
|
||||||
|
max_seq_len : int
|
||||||
|
Максимально допустимая длина последовательности для attention/маски/генерации.
|
||||||
|
Определяет размер буферов для causal mask.
|
||||||
|
rope : RoPE, по умолчанию None
|
||||||
|
Объект Rotary Positional Encoding (если хотите привнести продвинутое позиционное кодирование в attention).
|
||||||
|
Не обязателен, но нужен для современных LLM (Llama, Mistral и пр.).
|
||||||
|
dropout : float, по умолчанию 0.1
|
||||||
|
Величина dropout (регуляризации) — помогает борьбе с переобучением. Чем больше, тем сильнее регуляризация.
|
||||||
|
|
||||||
|
Внутри конструктора происходит:
|
||||||
|
-------------------------------
|
||||||
|
- Создаются три линейных слоя для Q, K, V (“где смотреть” и “что вытаскивать” в attention).
|
||||||
|
- Генерируется нижнетреугольная causal-маска (запрещает видеть будущее для автогерерации).
|
||||||
|
- Создаётся финальный линейный слой для склейки всех голов в одно пространство emb_size.
|
||||||
|
- Вводится dropout (случайное зануление, чтобы не было сильной зависимости внимания к отдельным "плейсам").
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> attn = MultiHeadAttention(num_heads=8, emb_size=256, head_size=32, max_seq_len=1024)
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._heads = nn.ModuleList(
|
self._num_heads = num_heads
|
||||||
[
|
self._head_size = head_size
|
||||||
HeadAttention(
|
self._max_seq_len = max_seq_len
|
||||||
emb_size=emb_size,
|
self._rope = rope
|
||||||
head_size=head_size,
|
|
||||||
max_seq_len=max_seq_len,
|
self._q = nn.Linear(emb_size, num_heads * head_size)
|
||||||
rope=rope,
|
self._k = nn.Linear(emb_size, num_heads * head_size)
|
||||||
)
|
self._v = nn.Linear(emb_size, num_heads * head_size)
|
||||||
for _ in range(num_heads)
|
|
||||||
]
|
# Создание causal маски
|
||||||
|
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
|
||||||
|
self.register_buffer(
|
||||||
|
"_tril_mask", mask.bool() if hasattr(torch, "bool") else mask.byte()
|
||||||
)
|
)
|
||||||
|
|
||||||
self._layer = nn.Linear(head_size * num_heads, emb_size)
|
self._layer = nn.Linear(head_size * num_heads, emb_size)
|
||||||
self._dropout = nn.Dropout(dropout)
|
self._dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
@@ -81,61 +141,116 @@ class MultiHeadAttention(nn.Module):
|
|||||||
cache: list = None,
|
cache: list = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Прямой проход (forward):
|
Основной шаг \"многоголового внимания\": находит взаимосвязи между токенами
|
||||||
Для каждого токена оценивает "важность" остальных токенов сразу через несколько attention-блоков.
|
в последовательности сразу из нескольких “ракурсов” (attention heads).
|
||||||
|
|
||||||
Подробное описание преобразований тензоров:
|
Что делает этот метод:
|
||||||
1. Входной тензор [batch_size, seq_len, emb_size] разделяется на N голов:
|
----------------------
|
||||||
- Каждая голова получает тензор [batch_size, seq_len, head_size]
|
- Для каждого токена сравнивает его с остальными во входной последовательности.
|
||||||
2. Каждая голова вычисляет attention:
|
- Делает это одновременно через несколько attention heads (каждая head видит текст по-своему).
|
||||||
- Вход: [batch_size, seq_len, head_size]
|
- Итоговое “внимание” — это взвешенная сумма других токенов (контекста) для каждого токена.
|
||||||
- Выход: [batch_size, seq_len, head_size]
|
- Можно использовать кэш для генерации длинных последовательностей по одному токену (ускоряет инференс).
|
||||||
3. Конкатенация результатов:
|
|
||||||
- Объединенный выход: [batch_size, seq_len, num_heads * head_size]
|
|
||||||
4. Линейная проекция:
|
|
||||||
- Выход: [batch_size, seq_len, emb_size]
|
|
||||||
5. Применение dropout
|
|
||||||
|
|
||||||
Args:
|
Аргументы:
|
||||||
x (Tensor[float]): [batch, seq_len, emb_size] — вход
|
----------
|
||||||
mask (Optional[Tensor[bool]]): маска позиции [seq_len, seq_len]
|
x : torch.Tensor
|
||||||
use_cache (bool): использовать ли key-value кэш (для генерации)
|
Входной тензор формы [batch, seq_len, emb_size].
|
||||||
cache (list): предыдущие значения KV для ускорения
|
Это ваши входные эмбеддинги (обычно после token + positional embedding).
|
||||||
|
mask : torch.Tensor, опционально
|
||||||
|
Матрица формы [seq_len, seq_len], задающая “разрешения” — кто может смотреть на кого (например, causal mask).
|
||||||
|
Если не указана — используется внутренняя маска (например, для autoregressive генерации).
|
||||||
|
use_cache : bool, по умолчанию True
|
||||||
|
Нужно ли использовать кэш для KV attention (важно для ускорения генерации по одному токену).
|
||||||
|
cache : list, опционально
|
||||||
|
Предыдущий кэш Key/Value — для генерации текста по частям.
|
||||||
|
|
||||||
Returns:
|
Возвращает:
|
||||||
out (Tensor[float]): [batch, seq_len, emb_size] — результат MHA
|
-----------
|
||||||
kv_caches (list): списки новых KV-кэшей (если используется)
|
- output: torch.Tensor формы [batch, seq_len, emb_size] — результат применения multi-head attention.
|
||||||
|
- kv_caches: список новых KV для кэширования при генерации (или None).
|
||||||
|
|
||||||
Типичный паттерн:
|
Важно:
|
||||||
Вход: [batch, seq, emb] → N голов [batch, seq, head_size] →
|
-------
|
||||||
→ concat [batch, seq, N*head_size] → проекция → dropout
|
- Shape входа всегда [batch, seq_len, emb_size], выход тот же.
|
||||||
|
- При seq_len > max_seq_len выбросит ошибку (безопасно для контроля переполнения буферов).
|
||||||
Пример преобразований для emb_size=512, num_heads=8:
|
- При использовании use_cache=True кешируется только последние токены (актуально для LLM).
|
||||||
Вход: [4, 100, 512]
|
|
||||||
-> Каждая голова: [4, 100, 64]
|
|
||||||
-> После внимания: 8 x [4, 100, 64]
|
|
||||||
-> Конкатенация: [4, 100, 512]
|
|
||||||
-> Проекция: [4, 100, 512]
|
|
||||||
-> Dropout: [4, 100, 512]
|
|
||||||
|
|
||||||
Пример:
|
Пример:
|
||||||
>>> out, caches = mha(x)
|
-------
|
||||||
>>> out.shape # [batch, seq_len, emb_size]
|
>>> attn = MultiHeadAttention(num_heads=8, emb_size=256, head_size=32, max_seq_len=1024)
|
||||||
|
>>> x = torch.randn(2, 100, 256)
|
||||||
|
>>> y, kv_cache = attn(x)
|
||||||
|
>>> print(y.shape) # torch.Size([2, 100, 256])
|
||||||
"""
|
"""
|
||||||
# 1. Вычисляем attention для каждой головы
|
batch_size, seq_len, emb_size = x.shape
|
||||||
attention_results = []
|
|
||||||
for i, head in enumerate(self._heads):
|
|
||||||
head_cache = cache[i] if cache is not None else None
|
|
||||||
result = head(x, use_cache=use_cache, cache=head_cache)
|
|
||||||
attention_results.append(result)
|
|
||||||
|
|
||||||
outputs, caches = zip(*attention_results)
|
if seq_len > self._max_seq_len:
|
||||||
attention_outputs = list(outputs)
|
raise ValueError(
|
||||||
kv_caches = list(caches)
|
f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}"
|
||||||
|
)
|
||||||
|
|
||||||
# 2. Объединяем результаты всех голов
|
# Пропустите тензор x через матрицы Wq, Wk , Wv, чтобы получить матрицы запроса, ключа и значения.
|
||||||
concatenated_attention = torch.cat(attention_outputs, dim=-1)
|
k = self._k(x) # [B, T, hs]
|
||||||
|
q = self._q(x) # [B, T, hs]
|
||||||
|
v = self._v(x) # [B, T, hs]
|
||||||
|
|
||||||
|
# Шаг 2: Изменение формы для multi-head
|
||||||
|
# [batch_size, seq_len, num_heads * head_size]
|
||||||
|
# -> [batch_size, seq_len, num_heads, head_size]
|
||||||
|
q = q.reshape(batch_size, seq_len, self._num_heads, self._head_size)
|
||||||
|
k = k.reshape(batch_size, seq_len, self._num_heads, self._head_size)
|
||||||
|
v = v.reshape(batch_size, seq_len, self._num_heads, self._head_size)
|
||||||
|
|
||||||
|
|
||||||
|
# 3. Transpose: [B, T, H, hs] -> [B, H, T, hs]
|
||||||
|
q = q.transpose(1, 2)
|
||||||
|
k = k.transpose(1, 2)
|
||||||
|
v = v.transpose(1, 2)
|
||||||
|
|
||||||
|
start_pos = 0
|
||||||
|
if cache is not None:
|
||||||
|
k_cache, v_cache = cache
|
||||||
|
cache_len = k_cache.shape[2]
|
||||||
|
start_pos = cache_len
|
||||||
|
|
||||||
|
# Пропустите матрицы запроса и ключа через экземпляр rope, чтобы выполнить поворот.
|
||||||
|
if self._rope is not None:
|
||||||
|
# ✅ Применяем RoPE к Q и K (НЕ к V!)
|
||||||
|
q = self._rope(q, start_pos=start_pos) # [B, T, hs]
|
||||||
|
k = self._rope(k, start_pos=start_pos) # [B, T, hs]
|
||||||
|
|
||||||
|
# Если cache пришел, то объединяем кэш и одну строку из ключа и значения. Это будут новые key и value для последующих вычислений.
|
||||||
|
# 5. Кэширование (для autoregressive generation)
|
||||||
|
if cache is not None:
|
||||||
|
k_cache, v_cache = cache
|
||||||
|
k = torch.cat([k_cache, k], dim=2) # Concat по seq_len (dim=2)
|
||||||
|
v = torch.cat([v_cache, v], dim=2)
|
||||||
|
|
||||||
|
# Перемножим матрицы запроса и ключа (транспонированную), чтобы вычислить матрицу внимания.
|
||||||
|
# И разделить все значения в матрице внимания на корень из head_size.
|
||||||
|
scores = q @ k.transpose(-2, -1) / (self._head_size ** 0.5)
|
||||||
|
|
||||||
|
# Если cache пришел, то маску не накладываем. Иначе наложите на матрицу внимания треугольную маску, созданную при инициализации. Все скрытые значения должны быть приведены к минус бесконечности: float('-inf').
|
||||||
|
if cache is None:
|
||||||
|
scores = scores.masked_fill(
|
||||||
|
~self._tril_mask[:seq_len, :seq_len], float("-inf")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Применить к матрице внимания (построчно) функцию Softmax.
|
||||||
|
weights = F.softmax(scores, dim=-1)
|
||||||
|
|
||||||
|
# Перемножим матрицу внимания и матрицу значения.
|
||||||
|
x_out = weights @ v # [B, T, hs]
|
||||||
|
|
||||||
|
# Измените форму тензора на batch_size × seq_len × num_heads*head_size.
|
||||||
|
# Transpose обратно и concatenate heads
|
||||||
|
x_out = x_out.transpose(1, 2) # [B, T_q, H, hs]
|
||||||
|
x_out = x_out.contiguous() # Важно для reshape!
|
||||||
|
concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size)
|
||||||
|
|
||||||
|
#concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size)
|
||||||
|
|
||||||
|
# Пропустите получившийся тензор через последний линейный слой.
|
||||||
# 3. Проецируем в пространство эмбеддингов
|
# 3. Проецируем в пространство эмбеддингов
|
||||||
projected_output = self._layer(concatenated_attention)
|
projected_output = self._layer(concatenated_attention)
|
||||||
|
|
||||||
@@ -143,6 +258,6 @@ class MultiHeadAttention(nn.Module):
|
|||||||
final_output = self._dropout(projected_output)
|
final_output = self._dropout(projected_output)
|
||||||
|
|
||||||
if use_cache is True:
|
if use_cache is True:
|
||||||
return (final_output, kv_caches)
|
return (final_output, (k, v))
|
||||||
else:
|
else:
|
||||||
return (final_output, None)
|
return (final_output, None)
|
||||||
|
|||||||
@@ -1,21 +1,51 @@
|
|||||||
"""
|
"""
|
||||||
Rotary Positional Embeddings (RoPE) - ротационные позиционные эмбеддинги.
|
Rotary Positional Embeddings (RoPE)
|
||||||
|
===================================
|
||||||
|
|
||||||
Реализация ротационного позиционного кодирования, которое кодирует позиционную
|
Что такое RoPE?
|
||||||
информацию через вращение векторов запросов и ключей в комплексном пространстве.
|
----------------
|
||||||
|
RoPE — это способ "вписать" информацию о позиции токенов в скрытые вектора модели трансформера.
|
||||||
|
Вместо простого сложения с абсолютным positional embedding, RoPE использует вращения векторов (как поворот стрелки на круге) внутри каждого attention head. Каждый элемент пары (вектор четного и нечетного индекса) поворачивается на угол, зависящий от позиции токена.
|
||||||
|
|
||||||
Научная статья: "RoFormer: Enhanced Transformer with Rotary Position Embedding"
|
Зачем это?
|
||||||
https://arxiv.org/abs/2104.09864
|
-----------
|
||||||
|
- RoPE реализует **относительное позиционное кодирование**: модель может сравнивать "расстояния" между токенами, а не просто помнить положение.
|
||||||
|
- Такое кодирование **улучшает генерацию длинных последовательностей** и перенос модели на тексты большей длины, чем были в обучении.
|
||||||
|
- Форма векторов и длина (норма) НЕ искажаются.
|
||||||
|
|
||||||
Математическая основа:
|
Как это работает? (главная формула)
|
||||||
Для позиции m и измерения i:
|
-------------------------------------
|
||||||
θ_i = base^(-2i/d)
|
Для каждой позиции m и пары компонент (2i, 2i+1) внутри head применяются:
|
||||||
q'_m = q_m * cos(mθ_i) + rotate(q_m) * sin(mθ_i)
|
|
||||||
|
θ_i = base^(-2i / d)
|
||||||
|
q'_{m,2i} = q_{m,2i} * cos(m * θ_i) - q_{m,2i+1} * sin(m * θ_i)
|
||||||
|
q'_{m,2i+1} = q_{m,2i+1} * cos(m * θ_i) + q_{m,2i} * sin(m * θ_i)
|
||||||
|
|
||||||
|
где d — размерность "головы" attention (head_size), base обычно 10_000.
|
||||||
|
|
||||||
|
То есть, берём каждый "вектор" (в рамках head), делим на четные/нечетные части и поворачиваем их на уникальный угол, связанный с позицией/частотой.
|
||||||
|
|
||||||
|
Архитектурные детали:
|
||||||
|
---------------------
|
||||||
|
- Ваш тензор должен быть строго 4-мерным: [batch, num_heads, seq_len, head_size].
|
||||||
|
- Размер head_size должен быть чётным!
|
||||||
|
- RoPE применяется отдельно к **Q** и **K** в механизме внимания (но не к V).
|
||||||
|
|
||||||
|
Где об этом читать:
|
||||||
|
-------------------
|
||||||
|
- RoFormer: Enhanced Transformer with Rotary Position Embedding
|
||||||
|
https://arxiv.org/abs/2104.09864
|
||||||
|
- Llama: Open and Efficient Foundation Language Models
|
||||||
|
https://arxiv.org/abs/2302.13971
|
||||||
|
- Визуализация позиционных кодировок:
|
||||||
|
https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> rope = RoPE(head_size=64, max_seq_len=2048)
|
||||||
|
>>> x = torch.randn(2, 8, 128, 64) # [batch, num_heads, seq_len, head_size]
|
||||||
|
>>> x_enc = rope(x) # здесь вектор x обогатится позиционной информацией
|
||||||
|
|
||||||
Преимущества:
|
|
||||||
- Относительное позиционное кодирование
|
|
||||||
- Лучшая экстраполяция на длинные последовательности
|
|
||||||
- Сохранение нормы векторов
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -25,32 +55,72 @@ from typing import Optional
|
|||||||
|
|
||||||
class RoPE(nn.Module):
|
class RoPE(nn.Module):
|
||||||
"""
|
"""
|
||||||
Rotary Positional Embeddings (RoPE) для механизма внимания.
|
Реализация RoPE (Rotary Positional Embeddings) для self-attention в трансформерах.
|
||||||
|
|
||||||
Кодирует позиционную информацию через вращение векторов запросов и ключей
|
Этот слой добавляет позиционную информацию к векторам внимания (Q, K) —
|
||||||
в многомерном пространстве с использованием синусов и косинусов.
|
не с помощью простого сложения с positional embedding, а с помощью математического
|
||||||
|
вращения (как если бы вы крутили стрелку на круге) для каждой пары компонент
|
||||||
|
(even/odd) в каждом attention head.
|
||||||
|
|
||||||
Args:
|
Формула (для каждого токена и каждой пары компонент внутри head):
|
||||||
head_size: Размерность головы внимания (должен быть четным)
|
θ_i = base^(-2i / d)
|
||||||
max_seq_len: Максимальная длина последовательности
|
out_{m,2i} = x_{m,2i} * cos(m * θ_i) - x_{m,2i+1} * sin(m * θ_i)
|
||||||
base: Базовое значение для вычисления частот (по умолчанию 10000)
|
out_{m,2i+1} = x_{m,2i+1} * cos(m * θ_i) + x_{m,2i} * sin(m * θ_i)
|
||||||
|
где d — head_size, base обычно 10_000, степень i по head axis.
|
||||||
|
|
||||||
Attributes:
|
Какие входы принимает:
|
||||||
cos_matrix: Буферизованная матрица косинусов формы [max_seq_len, head_size//2]
|
----------------------
|
||||||
sin_matrix: Буферизованная матрица синусов формы [max_seq_len, head_size//2]
|
- x: обязательно размерности [batch, num_heads, seq_len, head_size]!
|
||||||
|
- head_size (размер внимания) должен быть чётным.
|
||||||
|
- start_pos: опционально, позволяет сдвигать позиционный offset для генерации с кэшем.
|
||||||
|
|
||||||
|
Что возвращает:
|
||||||
|
---------------
|
||||||
|
- Тот же тензор (x), только со встроенной позиционной информацией (“повёрнутый” RoPE-кодировкой).
|
||||||
|
- Форма и тип выходного тензора не меняются.
|
||||||
|
|
||||||
|
Где используется:
|
||||||
|
-----------------
|
||||||
|
- В любых современных LLM (Llama, Mistral, GPT-NeoX и др.) для повышения устойчивости и generalization transformer's attention.
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> rope = RoPE(head_size=64, max_seq_len=2048)
|
||||||
|
>>> x = torch.randn(2, 8, 128, 64) # (batch, num_heads, seq_len, head_size)
|
||||||
|
>>> x_encoded = rope(x)
|
||||||
|
|
||||||
|
Подробнее про математику и примеры с визуализацией:
|
||||||
|
---------------------------------------------------
|
||||||
|
- RoFormer: https://arxiv.org/abs/2104.09864
|
||||||
|
- Llama: https://arxiv.org/abs/2302.13971
|
||||||
|
- Демонстрация наглядно: https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):
|
def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):
|
||||||
"""
|
"""
|
||||||
Инициализация RoPE эмбеддингов.
|
Инициализация объекта RoPE — настраивает и предвычисляет все необходимые
|
||||||
|
параметры для ротационного позиционного кодирования.
|
||||||
|
|
||||||
Args:
|
Аргументы:
|
||||||
head_size: Размерность головы внимания (должен быть четным)
|
----------
|
||||||
max_seq_len: Максимальная поддерживаемая длина последовательности
|
head_size : int
|
||||||
base: Базовое значение для вычисления частот (типично 10000)
|
Размер одного attention head (последнего измерения вектора) — сколько компонент
|
||||||
|
(float-значений) отвечает за одну "голову". Должен быть ЧЁТНЫМ числом, иначе RoPE не применим.
|
||||||
|
Обычно head_size = embed_dim // num_heads.
|
||||||
|
max_seq_len : int
|
||||||
|
Максимальная длина последовательности, которую RoPE сможет обработать.
|
||||||
|
Если ваш текст длиннее этого числа — будет ошибка! Например, для GPT2 обычно 1024, у LLaMA — до 4096.
|
||||||
|
Это число определяет размер внутренних буферов cos/sin.
|
||||||
|
base : int, по умолчанию 10_000
|
||||||
|
База для вычисления частот вращения (θ_i) для каждой компоненты.
|
||||||
|
В оригинальных статьях почти всегда используют base=10000.
|
||||||
|
Менять этот параметр не нужно, если вы не исследуете математические детали.
|
||||||
|
|
||||||
Raises:
|
Что происходит внутри:
|
||||||
AssertionError: Если head_size не четный
|
----------------------
|
||||||
|
- Проверяется чётность head_size.
|
||||||
|
- Для каждого возможного положения в пределах max_seq_len и каждой пары component высчитываются уникальные cos/sin значения (матрицы частот).
|
||||||
|
- Эти матрицы используются далее для быстрого наложения позиционного "вращения" токенов внутри attention.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert head_size % 2 == 0, "head_size должен быть четным"
|
assert head_size % 2 == 0, "head_size должен быть четным"
|
||||||
@@ -70,28 +140,51 @@ class RoPE(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
|
def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Применение ротационного позиционного кодирования к входному тензору.
|
Применяет ротационное позиционное кодирование (RoPE) к входному тензору.
|
||||||
|
|
||||||
Args:
|
Что делает эта функция:
|
||||||
x: Входной тензор формы [batch_size, seq_len, head_size]
|
-----------------------
|
||||||
|
Для каждого токена в последовательности внутри каждого attention head
|
||||||
|
"поворачивает" его вектор в подпространстве (even/odd пар) на свой уникальный угол,
|
||||||
|
зависящий от позиции токена. Это позволяет attention "понимать расстояния" между токенами.
|
||||||
|
|
||||||
Returns:
|
Аргументы:
|
||||||
Тензор с примененным RoPE формы [batch_size, seq_len, head_size]
|
----------
|
||||||
|
x : torch.Tensor
|
||||||
|
Входной тензор строго формы [batch, num_heads, seq_len, head_size].
|
||||||
|
Это обычно либо Q, либо K из механизма внимания.
|
||||||
|
start_pos : int, по умолчанию 0
|
||||||
|
Сдвиг начала позиции (нужно при генерации с кэшем, почти всегда оставить 0 если не пишете автогенератор).
|
||||||
|
|
||||||
Алгоритм:
|
Возвращает:
|
||||||
1. Разделение векторов на четные и нечетные компоненты
|
-----------
|
||||||
2. Применение вращения через синусы и косинусы
|
torch.Tensor с теми же формой и типом, что и x, но уже с наложенным позиционным кодированием.
|
||||||
3. Объединение компонент обратно
|
|
||||||
|
Важно:
|
||||||
|
-------
|
||||||
|
- Если передан тензор не 4D, будет выброшено исключение!
|
||||||
|
- Не изменяет значения "на месте", всегда возвращает новый тензор.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> rope = RoPE(head_size=64, max_seq_len=1024)
|
||||||
|
>>> q = torch.randn(2, 8, 32, 64) # batch, num_heads, seq_len, head_size
|
||||||
|
>>> q_rope = rope(q)
|
||||||
"""
|
"""
|
||||||
batch_size, seq_len, emb_size = x.shape
|
assert x.ndim == 4, "RoPE поддерживает только 4D-вход [batch, num_heads, seq_len, head_size]"
|
||||||
|
batch_size, num_heads, seq_len, head_size = x.shape
|
||||||
|
|
||||||
# Берем нужную часть матриц и приводим к типу x
|
# Берем нужную часть матриц и приводим к типу x
|
||||||
cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
||||||
sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
||||||
|
|
||||||
# Разделяем на четные и нечетные компоненты
|
# Явное изменение формы для broadcasting
|
||||||
x_even = x[:, :, 0::2] # [batch_size, seq_len, head_size//2]
|
cos = cos.reshape(1, 1, seq_len, head_size // 2)
|
||||||
x_odd = x[:, :, 1::2] # [batch_size, seq_len, head_size//2]
|
sin = sin.reshape(1, 1, seq_len, head_size // 2)
|
||||||
|
|
||||||
|
# Разделяем на четные и нечетные компоненты по ПОСЛЕДНЕМУ измерению
|
||||||
|
x_even = x[..., 0::2] # [batch_size, num_heads, seq_len, head_size//2]
|
||||||
|
x_odd = x[..., 1::2] # [batch_size, num_heads, seq_len, head_size//2]
|
||||||
|
|
||||||
# Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ)
|
# Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ)
|
||||||
x_rotated_even = x_even * cos - x_odd * sin
|
x_rotated_even = x_even * cos - x_odd * sin
|
||||||
|
|||||||
@@ -104,12 +104,20 @@ class GPT2(BaseModel):
|
|||||||
|
|
||||||
# Вычисление start_pos из кэша (если кэш передан)
|
# Вычисление start_pos из кэша (если кэш передан)
|
||||||
if cache is not None:
|
if cache is not None:
|
||||||
# При кэше обрабатываем только один токен (последний)
|
|
||||||
seq_len = 1
|
seq_len = 1
|
||||||
# Вычисляем start_pos из самого нижнего уровня кэша
|
# Безопасно извлекаем key_cache для вычисления start_pos
|
||||||
if cache and cache[0] and cache[0][0]:
|
if (
|
||||||
key_cache, _ = cache[0][0] # Первый декодер, первая голова
|
isinstance(cache, (list, tuple))
|
||||||
start_pos = key_cache.size(1) # cache_len
|
and len(cache) > 0
|
||||||
|
and cache[0] is not None
|
||||||
|
and isinstance(cache[0], (list, tuple))
|
||||||
|
and len(cache[0]) > 0
|
||||||
|
and cache[0][0] is not None
|
||||||
|
and isinstance(cache[0][0], (tuple, list))
|
||||||
|
and len(cache[0][0]) > 0
|
||||||
|
):
|
||||||
|
key_cache, _ = cache[0][0]
|
||||||
|
start_pos = key_cache.size(1)
|
||||||
else:
|
else:
|
||||||
start_pos = 0
|
start_pos = 0
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -4,71 +4,12 @@ from torch import Tensor
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from math import sqrt
|
from math import sqrt
|
||||||
from llm.core.base_model import BaseModel
|
from llm.core.base_model import BaseModel
|
||||||
|
from llm.core.token_embeddings import TokenEmbeddings
|
||||||
|
from llm.core.silu import SiLU
|
||||||
class SiLU(nn.Module):
|
from llm.core.rms_norm import RMSNorm
|
||||||
def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]
|
from llm.core.swi_glu import SwiGLU
|
||||||
return torch.sigmoid(x) * x
|
from llm.core.gelu import GELU
|
||||||
|
from llm.core.rope import RoPE
|
||||||
class RMSNorm(nn.Module):
|
|
||||||
def __init__(self, dim: int, eps: float = 1e-6):
|
|
||||||
super().__init__()
|
|
||||||
self._eps = eps
|
|
||||||
self._w = nn.Parameter(torch.ones(dim))
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]
|
|
||||||
rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5
|
|
||||||
norm_x = x / rms
|
|
||||||
return self._w * norm_x
|
|
||||||
|
|
||||||
class SwiGLU(nn.Module):
|
|
||||||
def __init__(self, emb_size: int, dropout: float = 0.1):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self._gate = nn.Linear(emb_size, 4 * emb_size)
|
|
||||||
self._up = nn.Linear(emb_size, 4 * emb_size)
|
|
||||||
self._down = nn.Linear(4 * emb_size, emb_size)
|
|
||||||
self._activation = SiLU()
|
|
||||||
self._dropout = nn.Dropout(dropout)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size].
|
|
||||||
gate_out = self._gate(x) # [batch, seq, 4*emb]
|
|
||||||
activation_out = self._activation(gate_out) # [batch, seq, 4*emb]
|
|
||||||
up_out = self._up(x) # [batch, seq, 4*emb]
|
|
||||||
out = up_out * activation_out # поэлементное!
|
|
||||||
out = self._down(out) # [batch, seq, emb]
|
|
||||||
return self._dropout(out)
|
|
||||||
|
|
||||||
|
|
||||||
class TokenEmbeddings(nn.Module):
|
|
||||||
def __init__(self, vocab_size: int, emb_size: int):
|
|
||||||
super().__init__()
|
|
||||||
self._embedding = nn.Embedding(
|
|
||||||
num_embeddings=vocab_size,
|
|
||||||
embedding_dim=emb_size
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
return self._embedding(x)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_embeddings(self) -> int:
|
|
||||||
return self._embedding.num_embeddings
|
|
||||||
|
|
||||||
@property
|
|
||||||
def embedding_dim(self) -> int:
|
|
||||||
return self._embedding.embedding_dim
|
|
||||||
|
|
||||||
|
|
||||||
class GELU(nn.Module):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
return 0.5 * x * (1 + torch.tanh(
|
|
||||||
self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))
|
|
||||||
))
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -77,49 +18,49 @@ from torch import nn
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
class RoPE(nn.Module):
|
#class RoPE(nn.Module):
|
||||||
|
#
|
||||||
def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):
|
# def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):
|
||||||
super().__init__()
|
# super().__init__()
|
||||||
assert head_size % 2 == 0, "head_size должен быть четным"
|
# assert head_size % 2 == 0, "head_size должен быть четным"
|
||||||
|
#
|
||||||
# Вычисление частот: θ_i = base^(-2i/d) для i ∈ [0, d/2-1]
|
# # Вычисление частот: θ_i = base^(-2i/d) для i ∈ [0, d/2-1]
|
||||||
freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size))
|
# freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size))
|
||||||
|
#
|
||||||
# Позиции от 0 до max_seq_len-1
|
# # Позиции от 0 до max_seq_len-1
|
||||||
positions = torch.arange(max_seq_len).float()
|
# positions = torch.arange(max_seq_len).float()
|
||||||
|
#
|
||||||
# Внешнее произведение: m * θ_i для всех позиций и частот
|
# # Внешнее произведение: m * θ_i для всех позиций и частот
|
||||||
freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0)
|
# freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0)
|
||||||
|
#
|
||||||
# Предвычисление матриц косинусов и синусов
|
# # Предвычисление матриц косинусов и синусов
|
||||||
self.register_buffer("cos_matrix", torch.cos(freq_matrix))
|
# self.register_buffer("cos_matrix", torch.cos(freq_matrix))
|
||||||
self.register_buffer("sin_matrix", torch.sin(freq_matrix))
|
# self.register_buffer("sin_matrix", torch.sin(freq_matrix))
|
||||||
|
#
|
||||||
def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor: # [batch_size × seq_len × head_size] [batch_size × num_heads × seq_len × head_size]
|
# def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor: # [batch_size × seq_len × head_size] [batch_size × num_heads × seq_len × head_size]
|
||||||
batch_size, num_heads, seq_len, head_size = x.shape
|
# batch_size, num_heads, seq_len, head_size = x.shape
|
||||||
|
#
|
||||||
# Берем нужную часть матриц и приводим к типу x
|
# # Берем нужную часть матриц и приводим к типу x
|
||||||
cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
# cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
||||||
sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
# sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
||||||
|
#
|
||||||
# Явное изменение формы для broadcasting
|
# # Явное изменение формы для broadcasting
|
||||||
cos = cos.reshape(1, 1, seq_len, head_size // 2)
|
# cos = cos.reshape(1, 1, seq_len, head_size // 2)
|
||||||
sin = sin.reshape(1, 1, seq_len, head_size // 2)
|
# sin = sin.reshape(1, 1, seq_len, head_size // 2)
|
||||||
|
#
|
||||||
# Разделяем на четные и нечетные компоненты по ПОСЛЕДНЕМУ измерению
|
# # Разделяем на четные и нечетные компоненты по ПОСЛЕДНЕМУ измерению
|
||||||
x_even = x[..., 0::2] # [batch_size, num_heads, seq_len, head_size//2]
|
# x_even = x[..., 0::2] # [batch_size, num_heads, seq_len, head_size//2]
|
||||||
x_odd = x[..., 1::2] # [batch_size, num_heads, seq_len, head_size//2]
|
# x_odd = x[..., 1::2] # [batch_size, num_heads, seq_len, head_size//2]
|
||||||
|
#
|
||||||
# Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ)
|
# # Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ)
|
||||||
x_rotated_even = x_even * cos - x_odd * sin
|
# x_rotated_even = x_even * cos - x_odd * sin
|
||||||
x_rotated_odd = x_even * sin + x_odd * cos
|
# x_rotated_odd = x_even * sin + x_odd * cos
|
||||||
|
#
|
||||||
# Объединяем обратно в исходную размерность
|
# # Объединяем обратно в исходную размерность
|
||||||
x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)
|
# x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)
|
||||||
x_rotated = x_rotated.flatten(-2) # [batch_size, seq_len, head_size]
|
# x_rotated = x_rotated.flatten(-2) # [batch_size, seq_len, head_size]
|
||||||
|
#
|
||||||
return x_rotated
|
# return x_rotated
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ class TestMultiHeadAttention:
|
|||||||
assert attention is not None
|
assert attention is not None
|
||||||
|
|
||||||
# Check internal attributes
|
# Check internal attributes
|
||||||
assert len(attention._heads) == num_heads
|
assert attention._num_heads == num_heads
|
||||||
assert attention._layer.in_features == embed_dim
|
assert attention._layer.in_features == embed_dim
|
||||||
assert attention._layer.out_features == embed_dim
|
assert attention._layer.out_features == embed_dim
|
||||||
|
|
||||||
@@ -102,8 +102,10 @@ class TestMultiHeadAttention:
|
|||||||
|
|
||||||
# Check that gradients are computed for learnable parameters
|
# Check that gradients are computed for learnable parameters
|
||||||
assert attention._layer.weight.grad is not None
|
assert attention._layer.weight.grad is not None
|
||||||
if len(attention._heads) > 0:
|
# Проверяем, что также у градиентов весов q/k/v есть значения
|
||||||
assert attention._heads[0]._q.weight.grad is not None
|
assert attention._q.weight.grad is not None
|
||||||
|
assert attention._k.weight.grad is not None
|
||||||
|
assert attention._v.weight.grad is not None
|
||||||
|
|
||||||
def test_device_consistency(self, embed_dim, num_heads, random_embeddings, device):
|
def test_device_consistency(self, embed_dim, num_heads, random_embeddings, device):
|
||||||
"""Test that MultiHeadAttention works on correct device."""
|
"""Test that MultiHeadAttention works on correct device."""
|
||||||
|
|||||||
55
llm/tests/core/test_rope.py
Normal file
55
llm/tests/core/test_rope.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.core.rope import RoPE
|
||||||
|
|
||||||
|
def test_rope_shapes_and_dtype():
|
||||||
|
rope = RoPE(head_size=8, max_seq_len=32)
|
||||||
|
x = torch.randn(2, 4, 16, 8) # [batch, num_heads, seq_len, head_size]
|
||||||
|
y = rope(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
|
assert y.dtype == x.dtype
|
||||||
|
|
||||||
|
def test_rope_raises_on_bad_ndim():
|
||||||
|
rope = RoPE(head_size=8, max_seq_len=16)
|
||||||
|
x = torch.randn(2, 16, 8) # [batch, seq_len, head_size] (3D)
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
_ = rope(x)
|
||||||
|
|
||||||
|
def test_rope_preserves_norm():
|
||||||
|
rope = RoPE(head_size=8, max_seq_len=16)
|
||||||
|
x = torch.randn(2, 3, 7, 8)
|
||||||
|
x_norm = x.norm(dim=-1)
|
||||||
|
y = rope(x)
|
||||||
|
y_norm = y.norm(dim=-1)
|
||||||
|
# Нормы могут немного отличаться из-за float, сравниваем с допуском
|
||||||
|
assert torch.allclose(x_norm, y_norm, rtol=1e-5, atol=1e-7)
|
||||||
|
|
||||||
|
def test_rope_backward_pass():
|
||||||
|
rope = RoPE(head_size=8, max_seq_len=16)
|
||||||
|
x = torch.randn(2, 2, 8, 8, requires_grad=True)
|
||||||
|
out = rope(x)
|
||||||
|
loss = out.sum()
|
||||||
|
loss.backward()
|
||||||
|
assert x.grad is not None
|
||||||
|
assert x.grad.shape == x.shape
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batch,num_heads,seq_len,head_size", [
|
||||||
|
(1, 1, 4, 8),
|
||||||
|
(2, 4, 16, 8),
|
||||||
|
(3, 2, 7, 8),
|
||||||
|
])
|
||||||
|
def test_rope_various_shapes(batch, num_heads, seq_len, head_size):
|
||||||
|
rope = RoPE(head_size=head_size, max_seq_len=32)
|
||||||
|
x = torch.randn(batch, num_heads, seq_len, head_size)
|
||||||
|
y = rope(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
|
|
||||||
|
def test_rope_start_pos():
|
||||||
|
rope = RoPE(head_size=8, max_seq_len=32)
|
||||||
|
x_full = torch.randn(1, 2, 8, 8)
|
||||||
|
# Сравниваем участок результата для разных start_pos
|
||||||
|
out1 = rope(x_full)
|
||||||
|
out2 = rope(x_full, start_pos=2)
|
||||||
|
assert not torch.allclose(out1, out2)
|
||||||
|
# Для одинакового start_pos и x должны совпадать
|
||||||
|
assert torch.allclose(rope(x_full, start_pos=1), rope(x_full, start_pos=1))
|
||||||
@@ -145,7 +145,11 @@ class TestGPT:
|
|||||||
assert model._token_embeddings._embedding.weight.grad is not None
|
assert model._token_embeddings._embedding.weight.grad is not None
|
||||||
assert model._linear.weight.grad is not None
|
assert model._linear.weight.grad is not None
|
||||||
if len(model._decoders) > 0:
|
if len(model._decoders) > 0:
|
||||||
assert model._decoders[0]._heads._heads[0]._q.weight.grad is not None
|
# Проверяем через новый интерфейс attention оптимизации:
|
||||||
|
attn = model._decoders[0]._heads
|
||||||
|
assert attn._q.weight.grad is not None
|
||||||
|
assert attn._k.weight.grad is not None
|
||||||
|
assert attn._v.weight.grad is not None
|
||||||
|
|
||||||
def test_device_consistency(self, gpt_config, random_inputs, device):
|
def test_device_consistency(self, gpt_config, random_inputs, device):
|
||||||
"""Test that GPT works on correct device."""
|
"""Test that GPT works on correct device."""
|
||||||
|
|||||||
Reference in New Issue
Block a user