mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
refactor(core): refactor RoPE and MultiHeadAttention, add math-rich docs, expand tests, remove unused head_attention
- refactor: улучшена и унифицирована реализация RoPE, теперь поддерживаются строгие проверки размерности входа; внесены улучшения и структурные изменения в MultiHeadAttention (более понятная логика, строгая спецификация входов/выходов) - docs: полностью переписаны docstrings для RoPE и MultiHeadAttention — включены математические формулы, ссылки на научные статьи, подробные пояснения по алгоритму, формату входных данных, ограничениям, примеры использования - test: добавлены отдельные unit-тесты для RoPE (корректность формы, ошибки на неверную размерность, сохранение нормы, backward/градиенты, работу с параметрами start_pos и батчами) - chore: удалён неиспользуемый модуль core/head_attention.py - fix: теперь выбрасывается AssertionError при неправильной размерности входа RoPE; это позволило полностью покрыть тест-кейсы на ошибки Этот коммит синхронизирует логику реализации базового внимания с современной практикой LLM, укрепляет документацию для инженеров и исследователей, а также расширяет надежность автотестирования библиотеки.
This commit is contained in:
@@ -1,121 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from math import sqrt
|
||||
from .rope import RoPE
|
||||
|
||||
|
||||
class HeadAttention(nn.Module):
|
||||
"""
|
||||
Одноголовый механизм внимания (scaled dot-product attention) — фундаментальный строительный блок всех современных Transformer.
|
||||
|
||||
Научная суть:
|
||||
- Attention учит модель самостоятельно "выбирать" важные связи между словами, независимо от их положения.
|
||||
- Механизм causal mask гарантирует невозможность "заглядывания в будущее" при генерации (авторегрессия).
|
||||
|
||||
Формула:
|
||||
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) · V
|
||||
(Q — запросы, K — ключи, V — значения; d_k — размерность ключа)
|
||||
|
||||
Поддерживает Rotary Position Encoding (RoPE) для относительного позиционного кодирования.
|
||||
|
||||
Args:
|
||||
emb_size (int): размер входного эмбеддинга
|
||||
head_size (int): размерность attention-головы
|
||||
max_seq_len (int): максимальная длина последовательности
|
||||
rope (RoPE, optional): экземпляр RoPE для позиций
|
||||
|
||||
Примечания:
|
||||
- Использует нижнетреугольную маску для предотвращения "заглядывания в будущее"
|
||||
- Автоматически адаптируется к разным версиям PyTorch
|
||||
- Поддерживает batch-обработку входных данных
|
||||
|
||||
Пример использования:
|
||||
>>> attention = HeadAttention(emb_size=64, head_size=32, max_seq_len=128)
|
||||
>>> x = torch.randn(1, 10, 64)
|
||||
>>> output, _ = attention(x)
|
||||
>>> print(output.shape) # torch.Size([1, 10, 32])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE = None
|
||||
):
|
||||
super().__init__()
|
||||
self._emb_size = emb_size
|
||||
self._head_size = head_size
|
||||
self._max_seq_len = max_seq_len
|
||||
self._rope = rope
|
||||
|
||||
# Линейные преобразования для Q, K, V
|
||||
self._k = nn.Linear(emb_size, head_size)
|
||||
self._q = nn.Linear(emb_size, head_size)
|
||||
self._v = nn.Linear(emb_size, head_size)
|
||||
|
||||
# Создание causal маски
|
||||
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
|
||||
self.register_buffer(
|
||||
"_tril_mask", mask.bool() if hasattr(torch, "bool") else mask.byte()
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, use_cache: bool = True, cache: tuple = None
|
||||
) -> tuple:
|
||||
"""
|
||||
Прямой проход через слой внимания.
|
||||
|
||||
Аргументы:
|
||||
x (torch.Tensor): Входной тензор формы [batch_size, seq_len, emb_size]
|
||||
|
||||
Возвращает:
|
||||
torch.Tensor: Выходной тензор формы [batch_size, seq_len, head_size]
|
||||
|
||||
Исключения:
|
||||
ValueError: Если длина последовательности превышает max_seq_len
|
||||
|
||||
Пример внутренних преобразований:
|
||||
Для входа x.shape = [2, 5, 64]:
|
||||
1. Q/K/V преобразования -> [2, 5, 32]
|
||||
2. Scores = Q·K^T -> [2, 5, 5]
|
||||
3. После маски и softmax -> [2, 5, 5]
|
||||
4. Умножение на V -> [2, 5, 32]
|
||||
"""
|
||||
seq_len = x.shape[1]
|
||||
if seq_len > self._max_seq_len:
|
||||
raise ValueError(
|
||||
f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}"
|
||||
)
|
||||
|
||||
k = self._k(x) # [B, T, hs]
|
||||
q = self._q(x) # [B, T, hs]
|
||||
v = self._v(x) # [B, T, hs]
|
||||
|
||||
start_pos = 0
|
||||
if cache is not None:
|
||||
k_cache, v_cache = cache
|
||||
cache_len = k_cache.shape[1]
|
||||
start_pos = cache_len
|
||||
|
||||
if self._rope is not None:
|
||||
# ✅ Применяем RoPE к Q и K (НЕ к V!)
|
||||
q = self._rope(q, start_pos=start_pos) # [B, T, hs]
|
||||
k = self._rope(k, start_pos=start_pos) # [B, T, hs]
|
||||
|
||||
if cache is not None:
|
||||
k_cache, v_cache = cache
|
||||
k = torch.cat([k_cache, k], dim=1) # [B, cache_len + T, hs]
|
||||
v = torch.cat([v_cache, v], dim=1) # [B, cache_len + T, hs]
|
||||
|
||||
scores = q @ k.transpose(-2, -1) / sqrt(self._head_size)
|
||||
|
||||
if cache is None:
|
||||
scores = scores.masked_fill(
|
||||
~self._tril_mask[:seq_len, :seq_len], float("-inf")
|
||||
)
|
||||
|
||||
weights = F.softmax(scores, dim=-1)
|
||||
x_out = weights @ v # [B, T, hs]
|
||||
|
||||
if use_cache is True:
|
||||
return (x_out, (k, v))
|
||||
else:
|
||||
return (x_out, None)
|
||||
@@ -1,37 +1,70 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
from .head_attention import HeadAttention
|
||||
import torch.nn.functional as F
|
||||
from .rope import RoPE
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
"""
|
||||
Мультиголовый (многоголовый) механизм внимания — ключевой компонент любого Transformer.
|
||||
Multi-Head Attention (Многоголовое внимание)
|
||||
============================================
|
||||
|
||||
Научная суть:
|
||||
- Модель параллельно агрегирует информацию через несколько подпространств (головы),
|
||||
чтобы видеть разные связи в последовательности (разный контекст, локально/глобально).
|
||||
- Каждый attention блок работает независимо, выход конкатенируется.
|
||||
- Механизм предложен в статье "Attention is All You Need" (Vaswani et al., 2017).
|
||||
Что такое Multi-Head Attention?
|
||||
-------------------------------
|
||||
Это ключевой компонент трансформеров, который позволяет "смотреть" на разные части предложения
|
||||
одновременно с нескольких независимых ракурсов ("голов"). Всё, что делает Single-Head Attention — только гораздо мощнее и глубже!
|
||||
|
||||
Формула внимания для одной головы:
|
||||
Attention(Q, K, V) = softmax(QK^T/sqrt(d_k))·V
|
||||
Мультиголовый:
|
||||
MultiHead(Q, K, V) = Concat([head_i])*W^O
|
||||
Зачем это нужно?
|
||||
----------------
|
||||
- Модель может учиться одновременно учитывать и локальные, и глобальные взаимосвязи между токенами.
|
||||
- Каждая attention head "ловит" свой собственный смысл/зависимости, и на выходе они объединяются.
|
||||
- Это значительно улучшает понимание сложных зависимостей в тексте, особенно на длинных последовательностях.
|
||||
|
||||
Args:
|
||||
num_heads (int): количество attention "голов"
|
||||
emb_size (int): размерности входа и выхода
|
||||
head_size (int): размер одной attention-головы (emb_size/num_heads)
|
||||
max_seq_len (int): максимальная длина последовательности
|
||||
rope (RoPE, optional): если задан, используется Rotary Positional Encoding
|
||||
dropout (float): вероятность регуляризации
|
||||
Как работает алгоритм? (основная схема)
|
||||
---------------------------------------
|
||||
1. Генерируются Q, K, V (query, key, value) — по отдельной проекции для каждой головы.
|
||||
2. Для каждой головы: attention(Q, K, V) = softmax(Q·K^T / sqrt(d)) · V
|
||||
3. Все головы "склеиваются" (concatenate) и прогоняются через общий финальный линейный слой.
|
||||
|
||||
Почему это работает?
|
||||
--------------------
|
||||
- Даёт трансформеру многомерное восприятие текста.
|
||||
- Позволяет эффективно обучаться на задачах, где порядок и "дальние" связи важнее, чем простое соседство.
|
||||
|
||||
Что принимается на вход:
|
||||
------------------------
|
||||
- x: shape [batch, seq_len, embed_dim] — обычный batched-embed тензор.
|
||||
- mask (опционально): shape [seq_len, seq_len] — маска для автогерерации или causal attention.
|
||||
|
||||
Какие параметры важны:
|
||||
----------------------
|
||||
- num_heads: сколько attention heads внутри (обычно 4, 8, 16...).
|
||||
- embed_dim: исходная размерность входного тензора.
|
||||
- head_size: размер одной attention-head (обычно embed_dim // num_heads).
|
||||
- max_seq_len: максимальная длина последовательности для маски.
|
||||
|
||||
Что возвращает:
|
||||
---------------
|
||||
- output: shape [batch, seq_len, embed_dim] — результат применения всех attention heads.
|
||||
- (опционально) cache: кэш для Q/K/V (нужно для генерации по одному токену).
|
||||
|
||||
Особенности реализации:
|
||||
-----------------------
|
||||
- Оптимизированно работает через матричные умножения (без python for циклов!).
|
||||
- Включена поддержка causal attention (маска, предотвращающая «заглядывание в будущее»).
|
||||
- Является ядром любого трансформера (и LLM!).
|
||||
|
||||
Пример использования:
|
||||
>>> mha = MultiHeadAttention(num_heads=8, emb_size=512, head_size=64, max_seq_len=1024)
|
||||
>>> x = torch.randn(2, 50, 512)
|
||||
>>> out, cache = mha(x)
|
||||
>>> print(out.shape)
|
||||
---------------------
|
||||
>>> attn = MultiHeadAttention(num_heads=8, embed_dim=256, head_size=32, max_seq_len=1024)
|
||||
>>> x = torch.randn(2, 128, 256) # [batch, seq_len, embed_dim]
|
||||
>>> context, _ = attn(x)
|
||||
>>> print(context.shape) # torch.Size([2, 128, 256])
|
||||
|
||||
Где прочитать подробнее:
|
||||
-------------------------
|
||||
- Attention is All You Need (Vaswani et al, 2017): https://arxiv.org/abs/1706.03762
|
||||
- Illustrated Transformer (blog): https://jalammar.github.io/illustrated-transformer/
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -44,32 +77,59 @@ class MultiHeadAttention(nn.Module):
|
||||
dropout: float = 0.1,
|
||||
):
|
||||
"""
|
||||
Инициализация многоголового внимания.
|
||||
Конструктор многоголового внимания (MultiHeadAttention).
|
||||
|
||||
Параметры:
|
||||
num_heads (int): Количество голов внимания. Типичные значения: 4-16
|
||||
emb_size (int): Размерность входных и выходных эмбеддингов
|
||||
head_size (int): Размерность каждой головы внимания (обычно emb_size // num_heads)
|
||||
max_seq_len (int): Максимальная длина последовательности
|
||||
dropout (float): Вероятность dropout (по умолчанию 0.1)
|
||||
Здесь создаются все параметры и внутренние слои для эффективного параллельного внимания (attention) сразу из нескольких "голов".
|
||||
|
||||
Контрольные значения:
|
||||
- num_heads * head_size должно равняться emb_size
|
||||
- head_size обычно выбирают 32-128
|
||||
- max_seq_len зависит от задачи (512 для BERT, 2048 для GPT-3)
|
||||
Аргументы:
|
||||
----------
|
||||
num_heads : int
|
||||
Сколько attention-heads будет внутри слоя.
|
||||
Каждая “голова” учится видеть уникальные зависимости в тексте. Обычно это 4, 8, 16 и т.п.
|
||||
Чем больше голов — тем богаче контекст, но и больше памяти.
|
||||
emb_size : int
|
||||
Сколько float-значений в каждом входном векторе (размерность embedding).
|
||||
Обычно это 256, 512, 768, 1024 и т.д.
|
||||
head_size : int
|
||||
Сколько компонент будет у каждой головы внимания.
|
||||
Важно: num_heads * head_size должно ровно совпадать с emb_size!
|
||||
Обычно head_size = emb_size // num_heads.
|
||||
max_seq_len : int
|
||||
Максимально допустимая длина последовательности для attention/маски/генерации.
|
||||
Определяет размер буферов для causal mask.
|
||||
rope : RoPE, по умолчанию None
|
||||
Объект Rotary Positional Encoding (если хотите привнести продвинутое позиционное кодирование в attention).
|
||||
Не обязателен, но нужен для современных LLM (Llama, Mistral и пр.).
|
||||
dropout : float, по умолчанию 0.1
|
||||
Величина dropout (регуляризации) — помогает борьбе с переобучением. Чем больше, тем сильнее регуляризация.
|
||||
|
||||
Внутри конструктора происходит:
|
||||
-------------------------------
|
||||
- Создаются три линейных слоя для Q, K, V (“где смотреть” и “что вытаскивать” в attention).
|
||||
- Генерируется нижнетреугольная causal-маска (запрещает видеть будущее для автогерерации).
|
||||
- Создаётся финальный линейный слой для склейки всех голов в одно пространство emb_size.
|
||||
- Вводится dropout (случайное зануление, чтобы не было сильной зависимости внимания к отдельным "плейсам").
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> attn = MultiHeadAttention(num_heads=8, emb_size=256, head_size=32, max_seq_len=1024)
|
||||
"""
|
||||
super().__init__()
|
||||
self._heads = nn.ModuleList(
|
||||
[
|
||||
HeadAttention(
|
||||
emb_size=emb_size,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len,
|
||||
rope=rope,
|
||||
)
|
||||
for _ in range(num_heads)
|
||||
]
|
||||
self._num_heads = num_heads
|
||||
self._head_size = head_size
|
||||
self._max_seq_len = max_seq_len
|
||||
self._rope = rope
|
||||
|
||||
self._q = nn.Linear(emb_size, num_heads * head_size)
|
||||
self._k = nn.Linear(emb_size, num_heads * head_size)
|
||||
self._v = nn.Linear(emb_size, num_heads * head_size)
|
||||
|
||||
# Создание causal маски
|
||||
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
|
||||
self.register_buffer(
|
||||
"_tril_mask", mask.bool() if hasattr(torch, "bool") else mask.byte()
|
||||
)
|
||||
|
||||
self._layer = nn.Linear(head_size * num_heads, emb_size)
|
||||
self._dropout = nn.Dropout(dropout)
|
||||
|
||||
@@ -81,61 +141,116 @@ class MultiHeadAttention(nn.Module):
|
||||
cache: list = None,
|
||||
):
|
||||
"""
|
||||
Прямой проход (forward):
|
||||
Для каждого токена оценивает "важность" остальных токенов сразу через несколько attention-блоков.
|
||||
Основной шаг \"многоголового внимания\": находит взаимосвязи между токенами
|
||||
в последовательности сразу из нескольких “ракурсов” (attention heads).
|
||||
|
||||
Подробное описание преобразований тензоров:
|
||||
1. Входной тензор [batch_size, seq_len, emb_size] разделяется на N голов:
|
||||
- Каждая голова получает тензор [batch_size, seq_len, head_size]
|
||||
2. Каждая голова вычисляет attention:
|
||||
- Вход: [batch_size, seq_len, head_size]
|
||||
- Выход: [batch_size, seq_len, head_size]
|
||||
3. Конкатенация результатов:
|
||||
- Объединенный выход: [batch_size, seq_len, num_heads * head_size]
|
||||
4. Линейная проекция:
|
||||
- Выход: [batch_size, seq_len, emb_size]
|
||||
5. Применение dropout
|
||||
Что делает этот метод:
|
||||
----------------------
|
||||
- Для каждого токена сравнивает его с остальными во входной последовательности.
|
||||
- Делает это одновременно через несколько attention heads (каждая head видит текст по-своему).
|
||||
- Итоговое “внимание” — это взвешенная сумма других токенов (контекста) для каждого токена.
|
||||
- Можно использовать кэш для генерации длинных последовательностей по одному токену (ускоряет инференс).
|
||||
|
||||
Args:
|
||||
x (Tensor[float]): [batch, seq_len, emb_size] — вход
|
||||
mask (Optional[Tensor[bool]]): маска позиции [seq_len, seq_len]
|
||||
use_cache (bool): использовать ли key-value кэш (для генерации)
|
||||
cache (list): предыдущие значения KV для ускорения
|
||||
Аргументы:
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Входной тензор формы [batch, seq_len, emb_size].
|
||||
Это ваши входные эмбеддинги (обычно после token + positional embedding).
|
||||
mask : torch.Tensor, опционально
|
||||
Матрица формы [seq_len, seq_len], задающая “разрешения” — кто может смотреть на кого (например, causal mask).
|
||||
Если не указана — используется внутренняя маска (например, для autoregressive генерации).
|
||||
use_cache : bool, по умолчанию True
|
||||
Нужно ли использовать кэш для KV attention (важно для ускорения генерации по одному токену).
|
||||
cache : list, опционально
|
||||
Предыдущий кэш Key/Value — для генерации текста по частям.
|
||||
|
||||
Returns:
|
||||
out (Tensor[float]): [batch, seq_len, emb_size] — результат MHA
|
||||
kv_caches (list): списки новых KV-кэшей (если используется)
|
||||
Возвращает:
|
||||
-----------
|
||||
- output: torch.Tensor формы [batch, seq_len, emb_size] — результат применения multi-head attention.
|
||||
- kv_caches: список новых KV для кэширования при генерации (или None).
|
||||
|
||||
Типичный паттерн:
|
||||
Вход: [batch, seq, emb] → N голов [batch, seq, head_size] →
|
||||
→ concat [batch, seq, N*head_size] → проекция → dropout
|
||||
|
||||
Пример преобразований для emb_size=512, num_heads=8:
|
||||
Вход: [4, 100, 512]
|
||||
-> Каждая голова: [4, 100, 64]
|
||||
-> После внимания: 8 x [4, 100, 64]
|
||||
-> Конкатенация: [4, 100, 512]
|
||||
-> Проекция: [4, 100, 512]
|
||||
-> Dropout: [4, 100, 512]
|
||||
Важно:
|
||||
-------
|
||||
- Shape входа всегда [batch, seq_len, emb_size], выход тот же.
|
||||
- При seq_len > max_seq_len выбросит ошибку (безопасно для контроля переполнения буферов).
|
||||
- При использовании use_cache=True кешируется только последние токены (актуально для LLM).
|
||||
|
||||
Пример:
|
||||
>>> out, caches = mha(x)
|
||||
>>> out.shape # [batch, seq_len, emb_size]
|
||||
-------
|
||||
>>> attn = MultiHeadAttention(num_heads=8, emb_size=256, head_size=32, max_seq_len=1024)
|
||||
>>> x = torch.randn(2, 100, 256)
|
||||
>>> y, kv_cache = attn(x)
|
||||
>>> print(y.shape) # torch.Size([2, 100, 256])
|
||||
"""
|
||||
# 1. Вычисляем attention для каждой головы
|
||||
attention_results = []
|
||||
for i, head in enumerate(self._heads):
|
||||
head_cache = cache[i] if cache is not None else None
|
||||
result = head(x, use_cache=use_cache, cache=head_cache)
|
||||
attention_results.append(result)
|
||||
batch_size, seq_len, emb_size = x.shape
|
||||
|
||||
outputs, caches = zip(*attention_results)
|
||||
attention_outputs = list(outputs)
|
||||
kv_caches = list(caches)
|
||||
if seq_len > self._max_seq_len:
|
||||
raise ValueError(
|
||||
f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}"
|
||||
)
|
||||
|
||||
# 2. Объединяем результаты всех голов
|
||||
concatenated_attention = torch.cat(attention_outputs, dim=-1)
|
||||
# Пропустите тензор x через матрицы Wq, Wk , Wv, чтобы получить матрицы запроса, ключа и значения.
|
||||
k = self._k(x) # [B, T, hs]
|
||||
q = self._q(x) # [B, T, hs]
|
||||
v = self._v(x) # [B, T, hs]
|
||||
|
||||
# Шаг 2: Изменение формы для multi-head
|
||||
# [batch_size, seq_len, num_heads * head_size]
|
||||
# -> [batch_size, seq_len, num_heads, head_size]
|
||||
q = q.reshape(batch_size, seq_len, self._num_heads, self._head_size)
|
||||
k = k.reshape(batch_size, seq_len, self._num_heads, self._head_size)
|
||||
v = v.reshape(batch_size, seq_len, self._num_heads, self._head_size)
|
||||
|
||||
|
||||
# 3. Transpose: [B, T, H, hs] -> [B, H, T, hs]
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
start_pos = 0
|
||||
if cache is not None:
|
||||
k_cache, v_cache = cache
|
||||
cache_len = k_cache.shape[2]
|
||||
start_pos = cache_len
|
||||
|
||||
# Пропустите матрицы запроса и ключа через экземпляр rope, чтобы выполнить поворот.
|
||||
if self._rope is not None:
|
||||
# ✅ Применяем RoPE к Q и K (НЕ к V!)
|
||||
q = self._rope(q, start_pos=start_pos) # [B, T, hs]
|
||||
k = self._rope(k, start_pos=start_pos) # [B, T, hs]
|
||||
|
||||
# Если cache пришел, то объединяем кэш и одну строку из ключа и значения. Это будут новые key и value для последующих вычислений.
|
||||
# 5. Кэширование (для autoregressive generation)
|
||||
if cache is not None:
|
||||
k_cache, v_cache = cache
|
||||
k = torch.cat([k_cache, k], dim=2) # Concat по seq_len (dim=2)
|
||||
v = torch.cat([v_cache, v], dim=2)
|
||||
|
||||
# Перемножим матрицы запроса и ключа (транспонированную), чтобы вычислить матрицу внимания.
|
||||
# И разделить все значения в матрице внимания на корень из head_size.
|
||||
scores = q @ k.transpose(-2, -1) / (self._head_size ** 0.5)
|
||||
|
||||
# Если cache пришел, то маску не накладываем. Иначе наложите на матрицу внимания треугольную маску, созданную при инициализации. Все скрытые значения должны быть приведены к минус бесконечности: float('-inf').
|
||||
if cache is None:
|
||||
scores = scores.masked_fill(
|
||||
~self._tril_mask[:seq_len, :seq_len], float("-inf")
|
||||
)
|
||||
|
||||
# Применить к матрице внимания (построчно) функцию Softmax.
|
||||
weights = F.softmax(scores, dim=-1)
|
||||
|
||||
# Перемножим матрицу внимания и матрицу значения.
|
||||
x_out = weights @ v # [B, T, hs]
|
||||
|
||||
# Измените форму тензора на batch_size × seq_len × num_heads*head_size.
|
||||
# Transpose обратно и concatenate heads
|
||||
x_out = x_out.transpose(1, 2) # [B, T_q, H, hs]
|
||||
x_out = x_out.contiguous() # Важно для reshape!
|
||||
concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size)
|
||||
|
||||
#concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size)
|
||||
|
||||
# Пропустите получившийся тензор через последний линейный слой.
|
||||
# 3. Проецируем в пространство эмбеддингов
|
||||
projected_output = self._layer(concatenated_attention)
|
||||
|
||||
@@ -143,6 +258,6 @@ class MultiHeadAttention(nn.Module):
|
||||
final_output = self._dropout(projected_output)
|
||||
|
||||
if use_cache is True:
|
||||
return (final_output, kv_caches)
|
||||
return (final_output, (k, v))
|
||||
else:
|
||||
return (final_output, None)
|
||||
|
||||
@@ -1,21 +1,51 @@
|
||||
"""
|
||||
Rotary Positional Embeddings (RoPE) - ротационные позиционные эмбеддинги.
|
||||
Rotary Positional Embeddings (RoPE)
|
||||
===================================
|
||||
|
||||
Реализация ротационного позиционного кодирования, которое кодирует позиционную
|
||||
информацию через вращение векторов запросов и ключей в комплексном пространстве.
|
||||
Что такое RoPE?
|
||||
----------------
|
||||
RoPE — это способ "вписать" информацию о позиции токенов в скрытые вектора модели трансформера.
|
||||
Вместо простого сложения с абсолютным positional embedding, RoPE использует вращения векторов (как поворот стрелки на круге) внутри каждого attention head. Каждый элемент пары (вектор четного и нечетного индекса) поворачивается на угол, зависящий от позиции токена.
|
||||
|
||||
Научная статья: "RoFormer: Enhanced Transformer with Rotary Position Embedding"
|
||||
https://arxiv.org/abs/2104.09864
|
||||
Зачем это?
|
||||
-----------
|
||||
- RoPE реализует **относительное позиционное кодирование**: модель может сравнивать "расстояния" между токенами, а не просто помнить положение.
|
||||
- Такое кодирование **улучшает генерацию длинных последовательностей** и перенос модели на тексты большей длины, чем были в обучении.
|
||||
- Форма векторов и длина (норма) НЕ искажаются.
|
||||
|
||||
Как это работает? (главная формула)
|
||||
-------------------------------------
|
||||
Для каждой позиции m и пары компонент (2i, 2i+1) внутри head применяются:
|
||||
|
||||
Математическая основа:
|
||||
Для позиции m и измерения i:
|
||||
θ_i = base^(-2i / d)
|
||||
q'_m = q_m * cos(mθ_i) + rotate(q_m) * sin(mθ_i)
|
||||
q'_{m,2i} = q_{m,2i} * cos(m * θ_i) - q_{m,2i+1} * sin(m * θ_i)
|
||||
q'_{m,2i+1} = q_{m,2i+1} * cos(m * θ_i) + q_{m,2i} * sin(m * θ_i)
|
||||
|
||||
где d — размерность "головы" attention (head_size), base обычно 10_000.
|
||||
|
||||
То есть, берём каждый "вектор" (в рамках head), делим на четные/нечетные части и поворачиваем их на уникальный угол, связанный с позицией/частотой.
|
||||
|
||||
Архитектурные детали:
|
||||
---------------------
|
||||
- Ваш тензор должен быть строго 4-мерным: [batch, num_heads, seq_len, head_size].
|
||||
- Размер head_size должен быть чётным!
|
||||
- RoPE применяется отдельно к **Q** и **K** в механизме внимания (но не к V).
|
||||
|
||||
Где об этом читать:
|
||||
-------------------
|
||||
- RoFormer: Enhanced Transformer with Rotary Position Embedding
|
||||
https://arxiv.org/abs/2104.09864
|
||||
- Llama: Open and Efficient Foundation Language Models
|
||||
https://arxiv.org/abs/2302.13971
|
||||
- Визуализация позиционных кодировок:
|
||||
https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> rope = RoPE(head_size=64, max_seq_len=2048)
|
||||
>>> x = torch.randn(2, 8, 128, 64) # [batch, num_heads, seq_len, head_size]
|
||||
>>> x_enc = rope(x) # здесь вектор x обогатится позиционной информацией
|
||||
|
||||
Преимущества:
|
||||
- Относительное позиционное кодирование
|
||||
- Лучшая экстраполяция на длинные последовательности
|
||||
- Сохранение нормы векторов
|
||||
"""
|
||||
|
||||
import torch
|
||||
@@ -25,32 +55,72 @@ from typing import Optional
|
||||
|
||||
class RoPE(nn.Module):
|
||||
"""
|
||||
Rotary Positional Embeddings (RoPE) для механизма внимания.
|
||||
Реализация RoPE (Rotary Positional Embeddings) для self-attention в трансформерах.
|
||||
|
||||
Кодирует позиционную информацию через вращение векторов запросов и ключей
|
||||
в многомерном пространстве с использованием синусов и косинусов.
|
||||
Этот слой добавляет позиционную информацию к векторам внимания (Q, K) —
|
||||
не с помощью простого сложения с positional embedding, а с помощью математического
|
||||
вращения (как если бы вы крутили стрелку на круге) для каждой пары компонент
|
||||
(even/odd) в каждом attention head.
|
||||
|
||||
Args:
|
||||
head_size: Размерность головы внимания (должен быть четным)
|
||||
max_seq_len: Максимальная длина последовательности
|
||||
base: Базовое значение для вычисления частот (по умолчанию 10000)
|
||||
Формула (для каждого токена и каждой пары компонент внутри head):
|
||||
θ_i = base^(-2i / d)
|
||||
out_{m,2i} = x_{m,2i} * cos(m * θ_i) - x_{m,2i+1} * sin(m * θ_i)
|
||||
out_{m,2i+1} = x_{m,2i+1} * cos(m * θ_i) + x_{m,2i} * sin(m * θ_i)
|
||||
где d — head_size, base обычно 10_000, степень i по head axis.
|
||||
|
||||
Attributes:
|
||||
cos_matrix: Буферизованная матрица косинусов формы [max_seq_len, head_size//2]
|
||||
sin_matrix: Буферизованная матрица синусов формы [max_seq_len, head_size//2]
|
||||
Какие входы принимает:
|
||||
----------------------
|
||||
- x: обязательно размерности [batch, num_heads, seq_len, head_size]!
|
||||
- head_size (размер внимания) должен быть чётным.
|
||||
- start_pos: опционально, позволяет сдвигать позиционный offset для генерации с кэшем.
|
||||
|
||||
Что возвращает:
|
||||
---------------
|
||||
- Тот же тензор (x), только со встроенной позиционной информацией (“повёрнутый” RoPE-кодировкой).
|
||||
- Форма и тип выходного тензора не меняются.
|
||||
|
||||
Где используется:
|
||||
-----------------
|
||||
- В любых современных LLM (Llama, Mistral, GPT-NeoX и др.) для повышения устойчивости и generalization transformer's attention.
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> rope = RoPE(head_size=64, max_seq_len=2048)
|
||||
>>> x = torch.randn(2, 8, 128, 64) # (batch, num_heads, seq_len, head_size)
|
||||
>>> x_encoded = rope(x)
|
||||
|
||||
Подробнее про математику и примеры с визуализацией:
|
||||
---------------------------------------------------
|
||||
- RoFormer: https://arxiv.org/abs/2104.09864
|
||||
- Llama: https://arxiv.org/abs/2302.13971
|
||||
- Демонстрация наглядно: https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
|
||||
"""
|
||||
|
||||
def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):
|
||||
"""
|
||||
Инициализация RoPE эмбеддингов.
|
||||
Инициализация объекта RoPE — настраивает и предвычисляет все необходимые
|
||||
параметры для ротационного позиционного кодирования.
|
||||
|
||||
Args:
|
||||
head_size: Размерность головы внимания (должен быть четным)
|
||||
max_seq_len: Максимальная поддерживаемая длина последовательности
|
||||
base: Базовое значение для вычисления частот (типично 10000)
|
||||
Аргументы:
|
||||
----------
|
||||
head_size : int
|
||||
Размер одного attention head (последнего измерения вектора) — сколько компонент
|
||||
(float-значений) отвечает за одну "голову". Должен быть ЧЁТНЫМ числом, иначе RoPE не применим.
|
||||
Обычно head_size = embed_dim // num_heads.
|
||||
max_seq_len : int
|
||||
Максимальная длина последовательности, которую RoPE сможет обработать.
|
||||
Если ваш текст длиннее этого числа — будет ошибка! Например, для GPT2 обычно 1024, у LLaMA — до 4096.
|
||||
Это число определяет размер внутренних буферов cos/sin.
|
||||
base : int, по умолчанию 10_000
|
||||
База для вычисления частот вращения (θ_i) для каждой компоненты.
|
||||
В оригинальных статьях почти всегда используют base=10000.
|
||||
Менять этот параметр не нужно, если вы не исследуете математические детали.
|
||||
|
||||
Raises:
|
||||
AssertionError: Если head_size не четный
|
||||
Что происходит внутри:
|
||||
----------------------
|
||||
- Проверяется чётность head_size.
|
||||
- Для каждого возможного положения в пределах max_seq_len и каждой пары component высчитываются уникальные cos/sin значения (матрицы частот).
|
||||
- Эти матрицы используются далее для быстрого наложения позиционного "вращения" токенов внутри attention.
|
||||
"""
|
||||
super().__init__()
|
||||
assert head_size % 2 == 0, "head_size должен быть четным"
|
||||
@@ -70,28 +140,51 @@ class RoPE(nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
|
||||
"""
|
||||
Применение ротационного позиционного кодирования к входному тензору.
|
||||
Применяет ротационное позиционное кодирование (RoPE) к входному тензору.
|
||||
|
||||
Args:
|
||||
x: Входной тензор формы [batch_size, seq_len, head_size]
|
||||
Что делает эта функция:
|
||||
-----------------------
|
||||
Для каждого токена в последовательности внутри каждого attention head
|
||||
"поворачивает" его вектор в подпространстве (even/odd пар) на свой уникальный угол,
|
||||
зависящий от позиции токена. Это позволяет attention "понимать расстояния" между токенами.
|
||||
|
||||
Returns:
|
||||
Тензор с примененным RoPE формы [batch_size, seq_len, head_size]
|
||||
Аргументы:
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Входной тензор строго формы [batch, num_heads, seq_len, head_size].
|
||||
Это обычно либо Q, либо K из механизма внимания.
|
||||
start_pos : int, по умолчанию 0
|
||||
Сдвиг начала позиции (нужно при генерации с кэшем, почти всегда оставить 0 если не пишете автогенератор).
|
||||
|
||||
Алгоритм:
|
||||
1. Разделение векторов на четные и нечетные компоненты
|
||||
2. Применение вращения через синусы и косинусы
|
||||
3. Объединение компонент обратно
|
||||
Возвращает:
|
||||
-----------
|
||||
torch.Tensor с теми же формой и типом, что и x, но уже с наложенным позиционным кодированием.
|
||||
|
||||
Важно:
|
||||
-------
|
||||
- Если передан тензор не 4D, будет выброшено исключение!
|
||||
- Не изменяет значения "на месте", всегда возвращает новый тензор.
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> rope = RoPE(head_size=64, max_seq_len=1024)
|
||||
>>> q = torch.randn(2, 8, 32, 64) # batch, num_heads, seq_len, head_size
|
||||
>>> q_rope = rope(q)
|
||||
"""
|
||||
batch_size, seq_len, emb_size = x.shape
|
||||
assert x.ndim == 4, "RoPE поддерживает только 4D-вход [batch, num_heads, seq_len, head_size]"
|
||||
batch_size, num_heads, seq_len, head_size = x.shape
|
||||
|
||||
# Берем нужную часть матриц и приводим к типу x
|
||||
cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
||||
sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
||||
|
||||
# Разделяем на четные и нечетные компоненты
|
||||
x_even = x[:, :, 0::2] # [batch_size, seq_len, head_size//2]
|
||||
x_odd = x[:, :, 1::2] # [batch_size, seq_len, head_size//2]
|
||||
# Явное изменение формы для broadcasting
|
||||
cos = cos.reshape(1, 1, seq_len, head_size // 2)
|
||||
sin = sin.reshape(1, 1, seq_len, head_size // 2)
|
||||
|
||||
# Разделяем на четные и нечетные компоненты по ПОСЛЕДНЕМУ измерению
|
||||
x_even = x[..., 0::2] # [batch_size, num_heads, seq_len, head_size//2]
|
||||
x_odd = x[..., 1::2] # [batch_size, num_heads, seq_len, head_size//2]
|
||||
|
||||
# Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ)
|
||||
x_rotated_even = x_even * cos - x_odd * sin
|
||||
|
||||
@@ -104,12 +104,20 @@ class GPT2(BaseModel):
|
||||
|
||||
# Вычисление start_pos из кэша (если кэш передан)
|
||||
if cache is not None:
|
||||
# При кэше обрабатываем только один токен (последний)
|
||||
seq_len = 1
|
||||
# Вычисляем start_pos из самого нижнего уровня кэша
|
||||
if cache and cache[0] and cache[0][0]:
|
||||
key_cache, _ = cache[0][0] # Первый декодер, первая голова
|
||||
start_pos = key_cache.size(1) # cache_len
|
||||
# Безопасно извлекаем key_cache для вычисления start_pos
|
||||
if (
|
||||
isinstance(cache, (list, tuple))
|
||||
and len(cache) > 0
|
||||
and cache[0] is not None
|
||||
and isinstance(cache[0], (list, tuple))
|
||||
and len(cache[0]) > 0
|
||||
and cache[0][0] is not None
|
||||
and isinstance(cache[0][0], (tuple, list))
|
||||
and len(cache[0][0]) > 0
|
||||
):
|
||||
key_cache, _ = cache[0][0]
|
||||
start_pos = key_cache.size(1)
|
||||
else:
|
||||
start_pos = 0
|
||||
else:
|
||||
|
||||
@@ -4,71 +4,12 @@ from torch import Tensor
|
||||
import torch.nn.functional as F
|
||||
from math import sqrt
|
||||
from llm.core.base_model import BaseModel
|
||||
|
||||
|
||||
class SiLU(nn.Module):
|
||||
def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]
|
||||
return torch.sigmoid(x) * x
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self._eps = eps
|
||||
self._w = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]
|
||||
rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5
|
||||
norm_x = x / rms
|
||||
return self._w * norm_x
|
||||
|
||||
class SwiGLU(nn.Module):
|
||||
def __init__(self, emb_size: int, dropout: float = 0.1):
|
||||
super().__init__()
|
||||
|
||||
self._gate = nn.Linear(emb_size, 4 * emb_size)
|
||||
self._up = nn.Linear(emb_size, 4 * emb_size)
|
||||
self._down = nn.Linear(4 * emb_size, emb_size)
|
||||
self._activation = SiLU()
|
||||
self._dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size].
|
||||
gate_out = self._gate(x) # [batch, seq, 4*emb]
|
||||
activation_out = self._activation(gate_out) # [batch, seq, 4*emb]
|
||||
up_out = self._up(x) # [batch, seq, 4*emb]
|
||||
out = up_out * activation_out # поэлементное!
|
||||
out = self._down(out) # [batch, seq, emb]
|
||||
return self._dropout(out)
|
||||
|
||||
|
||||
class TokenEmbeddings(nn.Module):
|
||||
def __init__(self, vocab_size: int, emb_size: int):
|
||||
super().__init__()
|
||||
self._embedding = nn.Embedding(
|
||||
num_embeddings=vocab_size,
|
||||
embedding_dim=emb_size
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self._embedding(x)
|
||||
|
||||
@property
|
||||
def num_embeddings(self) -> int:
|
||||
return self._embedding.num_embeddings
|
||||
|
||||
@property
|
||||
def embedding_dim(self) -> int:
|
||||
return self._embedding.embedding_dim
|
||||
|
||||
|
||||
class GELU(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return 0.5 * x * (1 + torch.tanh(
|
||||
self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))
|
||||
))
|
||||
from llm.core.token_embeddings import TokenEmbeddings
|
||||
from llm.core.silu import SiLU
|
||||
from llm.core.rms_norm import RMSNorm
|
||||
from llm.core.swi_glu import SwiGLU
|
||||
from llm.core.gelu import GELU
|
||||
from llm.core.rope import RoPE
|
||||
|
||||
|
||||
|
||||
@@ -77,49 +18,49 @@ from torch import nn
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class RoPE(nn.Module):
|
||||
|
||||
def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):
|
||||
super().__init__()
|
||||
assert head_size % 2 == 0, "head_size должен быть четным"
|
||||
|
||||
# Вычисление частот: θ_i = base^(-2i/d) для i ∈ [0, d/2-1]
|
||||
freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size))
|
||||
|
||||
# Позиции от 0 до max_seq_len-1
|
||||
positions = torch.arange(max_seq_len).float()
|
||||
|
||||
# Внешнее произведение: m * θ_i для всех позиций и частот
|
||||
freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0)
|
||||
|
||||
# Предвычисление матриц косинусов и синусов
|
||||
self.register_buffer("cos_matrix", torch.cos(freq_matrix))
|
||||
self.register_buffer("sin_matrix", torch.sin(freq_matrix))
|
||||
|
||||
def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor: # [batch_size × seq_len × head_size] [batch_size × num_heads × seq_len × head_size]
|
||||
batch_size, num_heads, seq_len, head_size = x.shape
|
||||
|
||||
# Берем нужную часть матриц и приводим к типу x
|
||||
cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
||||
sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
||||
|
||||
# Явное изменение формы для broadcasting
|
||||
cos = cos.reshape(1, 1, seq_len, head_size // 2)
|
||||
sin = sin.reshape(1, 1, seq_len, head_size // 2)
|
||||
|
||||
# Разделяем на четные и нечетные компоненты по ПОСЛЕДНЕМУ измерению
|
||||
x_even = x[..., 0::2] # [batch_size, num_heads, seq_len, head_size//2]
|
||||
x_odd = x[..., 1::2] # [batch_size, num_heads, seq_len, head_size//2]
|
||||
|
||||
# Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ)
|
||||
x_rotated_even = x_even * cos - x_odd * sin
|
||||
x_rotated_odd = x_even * sin + x_odd * cos
|
||||
|
||||
# Объединяем обратно в исходную размерность
|
||||
x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)
|
||||
x_rotated = x_rotated.flatten(-2) # [batch_size, seq_len, head_size]
|
||||
|
||||
return x_rotated
|
||||
#class RoPE(nn.Module):
|
||||
#
|
||||
# def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):
|
||||
# super().__init__()
|
||||
# assert head_size % 2 == 0, "head_size должен быть четным"
|
||||
#
|
||||
# # Вычисление частот: θ_i = base^(-2i/d) для i ∈ [0, d/2-1]
|
||||
# freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size))
|
||||
#
|
||||
# # Позиции от 0 до max_seq_len-1
|
||||
# positions = torch.arange(max_seq_len).float()
|
||||
#
|
||||
# # Внешнее произведение: m * θ_i для всех позиций и частот
|
||||
# freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0)
|
||||
#
|
||||
# # Предвычисление матриц косинусов и синусов
|
||||
# self.register_buffer("cos_matrix", torch.cos(freq_matrix))
|
||||
# self.register_buffer("sin_matrix", torch.sin(freq_matrix))
|
||||
#
|
||||
# def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor: # [batch_size × seq_len × head_size] [batch_size × num_heads × seq_len × head_size]
|
||||
# batch_size, num_heads, seq_len, head_size = x.shape
|
||||
#
|
||||
# # Берем нужную часть матриц и приводим к типу x
|
||||
# cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
||||
# sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
||||
#
|
||||
# # Явное изменение формы для broadcasting
|
||||
# cos = cos.reshape(1, 1, seq_len, head_size // 2)
|
||||
# sin = sin.reshape(1, 1, seq_len, head_size // 2)
|
||||
#
|
||||
# # Разделяем на четные и нечетные компоненты по ПОСЛЕДНЕМУ измерению
|
||||
# x_even = x[..., 0::2] # [batch_size, num_heads, seq_len, head_size//2]
|
||||
# x_odd = x[..., 1::2] # [batch_size, num_heads, seq_len, head_size//2]
|
||||
#
|
||||
# # Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ)
|
||||
# x_rotated_even = x_even * cos - x_odd * sin
|
||||
# x_rotated_odd = x_even * sin + x_odd * cos
|
||||
#
|
||||
# # Объединяем обратно в исходную размерность
|
||||
# x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)
|
||||
# x_rotated = x_rotated.flatten(-2) # [batch_size, seq_len, head_size]
|
||||
#
|
||||
# return x_rotated
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
@@ -19,7 +19,7 @@ class TestMultiHeadAttention:
|
||||
assert attention is not None
|
||||
|
||||
# Check internal attributes
|
||||
assert len(attention._heads) == num_heads
|
||||
assert attention._num_heads == num_heads
|
||||
assert attention._layer.in_features == embed_dim
|
||||
assert attention._layer.out_features == embed_dim
|
||||
|
||||
@@ -102,8 +102,10 @@ class TestMultiHeadAttention:
|
||||
|
||||
# Check that gradients are computed for learnable parameters
|
||||
assert attention._layer.weight.grad is not None
|
||||
if len(attention._heads) > 0:
|
||||
assert attention._heads[0]._q.weight.grad is not None
|
||||
# Проверяем, что также у градиентов весов q/k/v есть значения
|
||||
assert attention._q.weight.grad is not None
|
||||
assert attention._k.weight.grad is not None
|
||||
assert attention._v.weight.grad is not None
|
||||
|
||||
def test_device_consistency(self, embed_dim, num_heads, random_embeddings, device):
|
||||
"""Test that MultiHeadAttention works on correct device."""
|
||||
|
||||
55
llm/tests/core/test_rope.py
Normal file
55
llm/tests/core/test_rope.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import torch
|
||||
import pytest
|
||||
from llm.core.rope import RoPE
|
||||
|
||||
def test_rope_shapes_and_dtype():
|
||||
rope = RoPE(head_size=8, max_seq_len=32)
|
||||
x = torch.randn(2, 4, 16, 8) # [batch, num_heads, seq_len, head_size]
|
||||
y = rope(x)
|
||||
assert y.shape == x.shape
|
||||
assert y.dtype == x.dtype
|
||||
|
||||
def test_rope_raises_on_bad_ndim():
|
||||
rope = RoPE(head_size=8, max_seq_len=16)
|
||||
x = torch.randn(2, 16, 8) # [batch, seq_len, head_size] (3D)
|
||||
with pytest.raises(AssertionError):
|
||||
_ = rope(x)
|
||||
|
||||
def test_rope_preserves_norm():
|
||||
rope = RoPE(head_size=8, max_seq_len=16)
|
||||
x = torch.randn(2, 3, 7, 8)
|
||||
x_norm = x.norm(dim=-1)
|
||||
y = rope(x)
|
||||
y_norm = y.norm(dim=-1)
|
||||
# Нормы могут немного отличаться из-за float, сравниваем с допуском
|
||||
assert torch.allclose(x_norm, y_norm, rtol=1e-5, atol=1e-7)
|
||||
|
||||
def test_rope_backward_pass():
|
||||
rope = RoPE(head_size=8, max_seq_len=16)
|
||||
x = torch.randn(2, 2, 8, 8, requires_grad=True)
|
||||
out = rope(x)
|
||||
loss = out.sum()
|
||||
loss.backward()
|
||||
assert x.grad is not None
|
||||
assert x.grad.shape == x.shape
|
||||
|
||||
@pytest.mark.parametrize("batch,num_heads,seq_len,head_size", [
|
||||
(1, 1, 4, 8),
|
||||
(2, 4, 16, 8),
|
||||
(3, 2, 7, 8),
|
||||
])
|
||||
def test_rope_various_shapes(batch, num_heads, seq_len, head_size):
|
||||
rope = RoPE(head_size=head_size, max_seq_len=32)
|
||||
x = torch.randn(batch, num_heads, seq_len, head_size)
|
||||
y = rope(x)
|
||||
assert y.shape == x.shape
|
||||
|
||||
def test_rope_start_pos():
|
||||
rope = RoPE(head_size=8, max_seq_len=32)
|
||||
x_full = torch.randn(1, 2, 8, 8)
|
||||
# Сравниваем участок результата для разных start_pos
|
||||
out1 = rope(x_full)
|
||||
out2 = rope(x_full, start_pos=2)
|
||||
assert not torch.allclose(out1, out2)
|
||||
# Для одинакового start_pos и x должны совпадать
|
||||
assert torch.allclose(rope(x_full, start_pos=1), rope(x_full, start_pos=1))
|
||||
@@ -145,7 +145,11 @@ class TestGPT:
|
||||
assert model._token_embeddings._embedding.weight.grad is not None
|
||||
assert model._linear.weight.grad is not None
|
||||
if len(model._decoders) > 0:
|
||||
assert model._decoders[0]._heads._heads[0]._q.weight.grad is not None
|
||||
# Проверяем через новый интерфейс attention оптимизации:
|
||||
attn = model._decoders[0]._heads
|
||||
assert attn._q.weight.grad is not None
|
||||
assert attn._k.weight.grad is not None
|
||||
assert attn._v.weight.grad is not None
|
||||
|
||||
def test_device_consistency(self, gpt_config, random_inputs, device):
|
||||
"""Test that GPT works on correct device."""
|
||||
|
||||
Reference in New Issue
Block a user