From d10044e4a74f64ee77518dd1d71324a04f5ee3c9 Mon Sep 17 00:00:00 2001 From: Sergey Penkovsky Date: Wed, 15 Oct 2025 10:59:56 +0300 Subject: [PATCH] refactor(core): refactor RoPE and MultiHeadAttention, add math-rich docs, expand tests, remove unused head_attention MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - refactor: улучшена и унифицирована реализация RoPE, теперь поддерживаются строгие проверки размерности входа; внесены улучшения и структурные изменения в MultiHeadAttention (более понятная логика, строгая спецификация входов/выходов) - docs: полностью переписаны docstrings для RoPE и MultiHeadAttention — включены математические формулы, ссылки на научные статьи, подробные пояснения по алгоритму, формату входных данных, ограничениям, примеры использования - test: добавлены отдельные unit-тесты для RoPE (корректность формы, ошибки на неверную размерность, сохранение нормы, backward/градиенты, работу с параметрами start_pos и батчами) - chore: удалён неиспользуемый модуль core/head_attention.py - fix: теперь выбрасывается AssertionError при неправильной размерности входа RoPE; это позволило полностью покрыть тест-кейсы на ошибки Этот коммит синхронизирует логику реализации базового внимания с современной практикой LLM, укрепляет документацию для инженеров и исследователей, а также расширяет надежность автотестирования библиотеки. --- llm/src/llm/core/head_attention.py | 121 -------- llm/src/llm/core/multi_head_attention.py | 293 ++++++++++++++------ llm/src/llm/core/rope.py | 181 +++++++++--- llm/src/llm/models/gpt/gpt2.py | 18 +- llm/src/llm/models/mistral/mistral.py | 157 ++++------- llm/tests/core/test_multi_head_attention.py | 8 +- llm/tests/core/test_rope.py | 55 ++++ llm/tests/models/test_gpt.py | 6 +- 8 files changed, 468 insertions(+), 371 deletions(-) delete mode 100644 llm/src/llm/core/head_attention.py create mode 100644 llm/tests/core/test_rope.py diff --git a/llm/src/llm/core/head_attention.py b/llm/src/llm/core/head_attention.py deleted file mode 100644 index 194c706..0000000 --- a/llm/src/llm/core/head_attention.py +++ /dev/null @@ -1,121 +0,0 @@ -import torch -from torch import nn -import torch.nn.functional as F -from math import sqrt -from .rope import RoPE - - -class HeadAttention(nn.Module): - """ - Одноголовый механизм внимания (scaled dot-product attention) — фундаментальный строительный блок всех современных Transformer. - - Научная суть: - - Attention учит модель самостоятельно "выбирать" важные связи между словами, независимо от их положения. - - Механизм causal mask гарантирует невозможность "заглядывания в будущее" при генерации (авторегрессия). - - Формула: - Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) · V - (Q — запросы, K — ключи, V — значения; d_k — размерность ключа) - - Поддерживает Rotary Position Encoding (RoPE) для относительного позиционного кодирования. - - Args: - emb_size (int): размер входного эмбеддинга - head_size (int): размерность attention-головы - max_seq_len (int): максимальная длина последовательности - rope (RoPE, optional): экземпляр RoPE для позиций - - Примечания: - - Использует нижнетреугольную маску для предотвращения "заглядывания в будущее" - - Автоматически адаптируется к разным версиям PyTorch - - Поддерживает batch-обработку входных данных - - Пример использования: - >>> attention = HeadAttention(emb_size=64, head_size=32, max_seq_len=128) - >>> x = torch.randn(1, 10, 64) - >>> output, _ = attention(x) - >>> print(output.shape) # torch.Size([1, 10, 32]) - """ - - def __init__( - self, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE = None - ): - super().__init__() - self._emb_size = emb_size - self._head_size = head_size - self._max_seq_len = max_seq_len - self._rope = rope - - # Линейные преобразования для Q, K, V - self._k = nn.Linear(emb_size, head_size) - self._q = nn.Linear(emb_size, head_size) - self._v = nn.Linear(emb_size, head_size) - - # Создание causal маски - mask = torch.tril(torch.ones(max_seq_len, max_seq_len)) - self.register_buffer( - "_tril_mask", mask.bool() if hasattr(torch, "bool") else mask.byte() - ) - - def forward( - self, x: torch.Tensor, use_cache: bool = True, cache: tuple = None - ) -> tuple: - """ - Прямой проход через слой внимания. - - Аргументы: - x (torch.Tensor): Входной тензор формы [batch_size, seq_len, emb_size] - - Возвращает: - torch.Tensor: Выходной тензор формы [batch_size, seq_len, head_size] - - Исключения: - ValueError: Если длина последовательности превышает max_seq_len - - Пример внутренних преобразований: - Для входа x.shape = [2, 5, 64]: - 1. Q/K/V преобразования -> [2, 5, 32] - 2. Scores = Q·K^T -> [2, 5, 5] - 3. После маски и softmax -> [2, 5, 5] - 4. Умножение на V -> [2, 5, 32] - """ - seq_len = x.shape[1] - if seq_len > self._max_seq_len: - raise ValueError( - f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}" - ) - - k = self._k(x) # [B, T, hs] - q = self._q(x) # [B, T, hs] - v = self._v(x) # [B, T, hs] - - start_pos = 0 - if cache is not None: - k_cache, v_cache = cache - cache_len = k_cache.shape[1] - start_pos = cache_len - - if self._rope is not None: - # ✅ Применяем RoPE к Q и K (НЕ к V!) - q = self._rope(q, start_pos=start_pos) # [B, T, hs] - k = self._rope(k, start_pos=start_pos) # [B, T, hs] - - if cache is not None: - k_cache, v_cache = cache - k = torch.cat([k_cache, k], dim=1) # [B, cache_len + T, hs] - v = torch.cat([v_cache, v], dim=1) # [B, cache_len + T, hs] - - scores = q @ k.transpose(-2, -1) / sqrt(self._head_size) - - if cache is None: - scores = scores.masked_fill( - ~self._tril_mask[:seq_len, :seq_len], float("-inf") - ) - - weights = F.softmax(scores, dim=-1) - x_out = weights @ v # [B, T, hs] - - if use_cache is True: - return (x_out, (k, v)) - else: - return (x_out, None) diff --git a/llm/src/llm/core/multi_head_attention.py b/llm/src/llm/core/multi_head_attention.py index 21fc03f..c69392e 100644 --- a/llm/src/llm/core/multi_head_attention.py +++ b/llm/src/llm/core/multi_head_attention.py @@ -1,37 +1,70 @@ from torch import nn import torch -from .head_attention import HeadAttention +import torch.nn.functional as F from .rope import RoPE class MultiHeadAttention(nn.Module): """ - Мультиголовый (многоголовый) механизм внимания — ключевой компонент любого Transformer. + Multi-Head Attention (Многоголовое внимание) + ============================================ - Научная суть: - - Модель параллельно агрегирует информацию через несколько подпространств (головы), - чтобы видеть разные связи в последовательности (разный контекст, локально/глобально). - - Каждый attention блок работает независимо, выход конкатенируется. - - Механизм предложен в статье "Attention is All You Need" (Vaswani et al., 2017). + Что такое Multi-Head Attention? + ------------------------------- + Это ключевой компонент трансформеров, который позволяет "смотреть" на разные части предложения + одновременно с нескольких независимых ракурсов ("голов"). Всё, что делает Single-Head Attention — только гораздо мощнее и глубже! - Формула внимания для одной головы: - Attention(Q, K, V) = softmax(QK^T/sqrt(d_k))·V - Мультиголовый: - MultiHead(Q, K, V) = Concat([head_i])*W^O + Зачем это нужно? + ---------------- + - Модель может учиться одновременно учитывать и локальные, и глобальные взаимосвязи между токенами. + - Каждая attention head "ловит" свой собственный смысл/зависимости, и на выходе они объединяются. + - Это значительно улучшает понимание сложных зависимостей в тексте, особенно на длинных последовательностях. - Args: - num_heads (int): количество attention "голов" - emb_size (int): размерности входа и выхода - head_size (int): размер одной attention-головы (emb_size/num_heads) - max_seq_len (int): максимальная длина последовательности - rope (RoPE, optional): если задан, используется Rotary Positional Encoding - dropout (float): вероятность регуляризации + Как работает алгоритм? (основная схема) + --------------------------------------- + 1. Генерируются Q, K, V (query, key, value) — по отдельной проекции для каждой головы. + 2. Для каждой головы: attention(Q, K, V) = softmax(Q·K^T / sqrt(d)) · V + 3. Все головы "склеиваются" (concatenate) и прогоняются через общий финальный линейный слой. + + Почему это работает? + -------------------- + - Даёт трансформеру многомерное восприятие текста. + - Позволяет эффективно обучаться на задачах, где порядок и "дальние" связи важнее, чем простое соседство. + + Что принимается на вход: + ------------------------ + - x: shape [batch, seq_len, embed_dim] — обычный batched-embed тензор. + - mask (опционально): shape [seq_len, seq_len] — маска для автогерерации или causal attention. + + Какие параметры важны: + ---------------------- + - num_heads: сколько attention heads внутри (обычно 4, 8, 16...). + - embed_dim: исходная размерность входного тензора. + - head_size: размер одной attention-head (обычно embed_dim // num_heads). + - max_seq_len: максимальная длина последовательности для маски. + + Что возвращает: + --------------- + - output: shape [batch, seq_len, embed_dim] — результат применения всех attention heads. + - (опционально) cache: кэш для Q/K/V (нужно для генерации по одному токену). + + Особенности реализации: + ----------------------- + - Оптимизированно работает через матричные умножения (без python for циклов!). + - Включена поддержка causal attention (маска, предотвращающая «заглядывание в будущее»). + - Является ядром любого трансформера (и LLM!). Пример использования: - >>> mha = MultiHeadAttention(num_heads=8, emb_size=512, head_size=64, max_seq_len=1024) - >>> x = torch.randn(2, 50, 512) - >>> out, cache = mha(x) - >>> print(out.shape) + --------------------- + >>> attn = MultiHeadAttention(num_heads=8, embed_dim=256, head_size=32, max_seq_len=1024) + >>> x = torch.randn(2, 128, 256) # [batch, seq_len, embed_dim] + >>> context, _ = attn(x) + >>> print(context.shape) # torch.Size([2, 128, 256]) + + Где прочитать подробнее: + ------------------------- + - Attention is All You Need (Vaswani et al, 2017): https://arxiv.org/abs/1706.03762 + - Illustrated Transformer (blog): https://jalammar.github.io/illustrated-transformer/ """ def __init__( @@ -44,32 +77,59 @@ class MultiHeadAttention(nn.Module): dropout: float = 0.1, ): """ - Инициализация многоголового внимания. + Конструктор многоголового внимания (MultiHeadAttention). - Параметры: - num_heads (int): Количество голов внимания. Типичные значения: 4-16 - emb_size (int): Размерность входных и выходных эмбеддингов - head_size (int): Размерность каждой головы внимания (обычно emb_size // num_heads) - max_seq_len (int): Максимальная длина последовательности - dropout (float): Вероятность dropout (по умолчанию 0.1) + Здесь создаются все параметры и внутренние слои для эффективного параллельного внимания (attention) сразу из нескольких "голов". - Контрольные значения: - - num_heads * head_size должно равняться emb_size - - head_size обычно выбирают 32-128 - - max_seq_len зависит от задачи (512 для BERT, 2048 для GPT-3) + Аргументы: + ---------- + num_heads : int + Сколько attention-heads будет внутри слоя. + Каждая “голова” учится видеть уникальные зависимости в тексте. Обычно это 4, 8, 16 и т.п. + Чем больше голов — тем богаче контекст, но и больше памяти. + emb_size : int + Сколько float-значений в каждом входном векторе (размерность embedding). + Обычно это 256, 512, 768, 1024 и т.д. + head_size : int + Сколько компонент будет у каждой головы внимания. + Важно: num_heads * head_size должно ровно совпадать с emb_size! + Обычно head_size = emb_size // num_heads. + max_seq_len : int + Максимально допустимая длина последовательности для attention/маски/генерации. + Определяет размер буферов для causal mask. + rope : RoPE, по умолчанию None + Объект Rotary Positional Encoding (если хотите привнести продвинутое позиционное кодирование в attention). + Не обязателен, но нужен для современных LLM (Llama, Mistral и пр.). + dropout : float, по умолчанию 0.1 + Величина dropout (регуляризации) — помогает борьбе с переобучением. Чем больше, тем сильнее регуляризация. + + Внутри конструктора происходит: + ------------------------------- + - Создаются три линейных слоя для Q, K, V (“где смотреть” и “что вытаскивать” в attention). + - Генерируется нижнетреугольная causal-маска (запрещает видеть будущее для автогерерации). + - Создаётся финальный линейный слой для склейки всех голов в одно пространство emb_size. + - Вводится dropout (случайное зануление, чтобы не было сильной зависимости внимания к отдельным "плейсам"). + + Пример: + ------- + >>> attn = MultiHeadAttention(num_heads=8, emb_size=256, head_size=32, max_seq_len=1024) """ super().__init__() - self._heads = nn.ModuleList( - [ - HeadAttention( - emb_size=emb_size, - head_size=head_size, - max_seq_len=max_seq_len, - rope=rope, - ) - for _ in range(num_heads) - ] + self._num_heads = num_heads + self._head_size = head_size + self._max_seq_len = max_seq_len + self._rope = rope + + self._q = nn.Linear(emb_size, num_heads * head_size) + self._k = nn.Linear(emb_size, num_heads * head_size) + self._v = nn.Linear(emb_size, num_heads * head_size) + + # Создание causal маски + mask = torch.tril(torch.ones(max_seq_len, max_seq_len)) + self.register_buffer( + "_tril_mask", mask.bool() if hasattr(torch, "bool") else mask.byte() ) + self._layer = nn.Linear(head_size * num_heads, emb_size) self._dropout = nn.Dropout(dropout) @@ -81,61 +141,116 @@ class MultiHeadAttention(nn.Module): cache: list = None, ): """ - Прямой проход (forward): - Для каждого токена оценивает "важность" остальных токенов сразу через несколько attention-блоков. + Основной шаг \"многоголового внимания\": находит взаимосвязи между токенами + в последовательности сразу из нескольких “ракурсов” (attention heads). - Подробное описание преобразований тензоров: - 1. Входной тензор [batch_size, seq_len, emb_size] разделяется на N голов: - - Каждая голова получает тензор [batch_size, seq_len, head_size] - 2. Каждая голова вычисляет attention: - - Вход: [batch_size, seq_len, head_size] - - Выход: [batch_size, seq_len, head_size] - 3. Конкатенация результатов: - - Объединенный выход: [batch_size, seq_len, num_heads * head_size] - 4. Линейная проекция: - - Выход: [batch_size, seq_len, emb_size] - 5. Применение dropout + Что делает этот метод: + ---------------------- + - Для каждого токена сравнивает его с остальными во входной последовательности. + - Делает это одновременно через несколько attention heads (каждая head видит текст по-своему). + - Итоговое “внимание” — это взвешенная сумма других токенов (контекста) для каждого токена. + - Можно использовать кэш для генерации длинных последовательностей по одному токену (ускоряет инференс). - Args: - x (Tensor[float]): [batch, seq_len, emb_size] — вход - mask (Optional[Tensor[bool]]): маска позиции [seq_len, seq_len] - use_cache (bool): использовать ли key-value кэш (для генерации) - cache (list): предыдущие значения KV для ускорения + Аргументы: + ---------- + x : torch.Tensor + Входной тензор формы [batch, seq_len, emb_size]. + Это ваши входные эмбеддинги (обычно после token + positional embedding). + mask : torch.Tensor, опционально + Матрица формы [seq_len, seq_len], задающая “разрешения” — кто может смотреть на кого (например, causal mask). + Если не указана — используется внутренняя маска (например, для autoregressive генерации). + use_cache : bool, по умолчанию True + Нужно ли использовать кэш для KV attention (важно для ускорения генерации по одному токену). + cache : list, опционально + Предыдущий кэш Key/Value — для генерации текста по частям. - Returns: - out (Tensor[float]): [batch, seq_len, emb_size] — результат MHA - kv_caches (list): списки новых KV-кэшей (если используется) + Возвращает: + ----------- + - output: torch.Tensor формы [batch, seq_len, emb_size] — результат применения multi-head attention. + - kv_caches: список новых KV для кэширования при генерации (или None). - Типичный паттерн: - Вход: [batch, seq, emb] → N голов [batch, seq, head_size] → - → concat [batch, seq, N*head_size] → проекция → dropout - - Пример преобразований для emb_size=512, num_heads=8: - Вход: [4, 100, 512] - -> Каждая голова: [4, 100, 64] - -> После внимания: 8 x [4, 100, 64] - -> Конкатенация: [4, 100, 512] - -> Проекция: [4, 100, 512] - -> Dropout: [4, 100, 512] + Важно: + ------- + - Shape входа всегда [batch, seq_len, emb_size], выход тот же. + - При seq_len > max_seq_len выбросит ошибку (безопасно для контроля переполнения буферов). + - При использовании use_cache=True кешируется только последние токены (актуально для LLM). Пример: - >>> out, caches = mha(x) - >>> out.shape # [batch, seq_len, emb_size] + ------- + >>> attn = MultiHeadAttention(num_heads=8, emb_size=256, head_size=32, max_seq_len=1024) + >>> x = torch.randn(2, 100, 256) + >>> y, kv_cache = attn(x) + >>> print(y.shape) # torch.Size([2, 100, 256]) """ - # 1. Вычисляем attention для каждой головы - attention_results = [] - for i, head in enumerate(self._heads): - head_cache = cache[i] if cache is not None else None - result = head(x, use_cache=use_cache, cache=head_cache) - attention_results.append(result) + batch_size, seq_len, emb_size = x.shape - outputs, caches = zip(*attention_results) - attention_outputs = list(outputs) - kv_caches = list(caches) + if seq_len > self._max_seq_len: + raise ValueError( + f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}" + ) - # 2. Объединяем результаты всех голов - concatenated_attention = torch.cat(attention_outputs, dim=-1) + # Пропустите тензор x через матрицы Wq, Wk , Wv, чтобы получить матрицы запроса, ключа и значения. + k = self._k(x) # [B, T, hs] + q = self._q(x) # [B, T, hs] + v = self._v(x) # [B, T, hs] + # Шаг 2: Изменение формы для multi-head + # [batch_size, seq_len, num_heads * head_size] + # -> [batch_size, seq_len, num_heads, head_size] + q = q.reshape(batch_size, seq_len, self._num_heads, self._head_size) + k = k.reshape(batch_size, seq_len, self._num_heads, self._head_size) + v = v.reshape(batch_size, seq_len, self._num_heads, self._head_size) + + + # 3. Transpose: [B, T, H, hs] -> [B, H, T, hs] + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + start_pos = 0 + if cache is not None: + k_cache, v_cache = cache + cache_len = k_cache.shape[2] + start_pos = cache_len + + # Пропустите матрицы запроса и ключа через экземпляр rope, чтобы выполнить поворот. + if self._rope is not None: + # ✅ Применяем RoPE к Q и K (НЕ к V!) + q = self._rope(q, start_pos=start_pos) # [B, T, hs] + k = self._rope(k, start_pos=start_pos) # [B, T, hs] + + # Если cache пришел, то объединяем кэш и одну строку из ключа и значения. Это будут новые key и value для последующих вычислений. + # 5. Кэширование (для autoregressive generation) + if cache is not None: + k_cache, v_cache = cache + k = torch.cat([k_cache, k], dim=2) # Concat по seq_len (dim=2) + v = torch.cat([v_cache, v], dim=2) + + # Перемножим матрицы запроса и ключа (транспонированную), чтобы вычислить матрицу внимания. + # И разделить все значения в матрице внимания на корень из head_size. + scores = q @ k.transpose(-2, -1) / (self._head_size ** 0.5) + + # Если cache пришел, то маску не накладываем. Иначе наложите на матрицу внимания треугольную маску, созданную при инициализации. Все скрытые значения должны быть приведены к минус бесконечности: float('-inf'). + if cache is None: + scores = scores.masked_fill( + ~self._tril_mask[:seq_len, :seq_len], float("-inf") + ) + + # Применить к матрице внимания (построчно) функцию Softmax. + weights = F.softmax(scores, dim=-1) + + # Перемножим матрицу внимания и матрицу значения. + x_out = weights @ v # [B, T, hs] + + # Измените форму тензора на batch_size × seq_len × num_heads*head_size. + # Transpose обратно и concatenate heads + x_out = x_out.transpose(1, 2) # [B, T_q, H, hs] + x_out = x_out.contiguous() # Важно для reshape! + concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size) + + #concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size) + + # Пропустите получившийся тензор через последний линейный слой. # 3. Проецируем в пространство эмбеддингов projected_output = self._layer(concatenated_attention) @@ -143,6 +258,6 @@ class MultiHeadAttention(nn.Module): final_output = self._dropout(projected_output) if use_cache is True: - return (final_output, kv_caches) + return (final_output, (k, v)) else: return (final_output, None) diff --git a/llm/src/llm/core/rope.py b/llm/src/llm/core/rope.py index 7a8801c..70059c2 100644 --- a/llm/src/llm/core/rope.py +++ b/llm/src/llm/core/rope.py @@ -1,21 +1,51 @@ """ -Rotary Positional Embeddings (RoPE) - ротационные позиционные эмбеддинги. +Rotary Positional Embeddings (RoPE) +=================================== -Реализация ротационного позиционного кодирования, которое кодирует позиционную -информацию через вращение векторов запросов и ключей в комплексном пространстве. +Что такое RoPE? +---------------- +RoPE — это способ "вписать" информацию о позиции токенов в скрытые вектора модели трансформера. +Вместо простого сложения с абсолютным positional embedding, RoPE использует вращения векторов (как поворот стрелки на круге) внутри каждого attention head. Каждый элемент пары (вектор четного и нечетного индекса) поворачивается на угол, зависящий от позиции токена. -Научная статья: "RoFormer: Enhanced Transformer with Rotary Position Embedding" -https://arxiv.org/abs/2104.09864 +Зачем это? +----------- +- RoPE реализует **относительное позиционное кодирование**: модель может сравнивать "расстояния" между токенами, а не просто помнить положение. +- Такое кодирование **улучшает генерацию длинных последовательностей** и перенос модели на тексты большей длины, чем были в обучении. +- Форма векторов и длина (норма) НЕ искажаются. -Математическая основа: -Для позиции m и измерения i: -θ_i = base^(-2i/d) -q'_m = q_m * cos(mθ_i) + rotate(q_m) * sin(mθ_i) +Как это работает? (главная формула) +------------------------------------- +Для каждой позиции m и пары компонент (2i, 2i+1) внутри head применяются: + + θ_i = base^(-2i / d) + q'_{m,2i} = q_{m,2i} * cos(m * θ_i) - q_{m,2i+1} * sin(m * θ_i) + q'_{m,2i+1} = q_{m,2i+1} * cos(m * θ_i) + q_{m,2i} * sin(m * θ_i) + +где d — размерность "головы" attention (head_size), base обычно 10_000. + +То есть, берём каждый "вектор" (в рамках head), делим на четные/нечетные части и поворачиваем их на уникальный угол, связанный с позицией/частотой. + +Архитектурные детали: +--------------------- +- Ваш тензор должен быть строго 4-мерным: [batch, num_heads, seq_len, head_size]. +- Размер head_size должен быть чётным! +- RoPE применяется отдельно к **Q** и **K** в механизме внимания (но не к V). + +Где об этом читать: +------------------- +- RoFormer: Enhanced Transformer with Rotary Position Embedding + https://arxiv.org/abs/2104.09864 +- Llama: Open and Efficient Foundation Language Models + https://arxiv.org/abs/2302.13971 +- Визуализация позиционных кодировок: + https://kazemnejad.com/blog/transformer_architecture_positional_encoding/ + +Пример использования: +--------------------- +>>> rope = RoPE(head_size=64, max_seq_len=2048) +>>> x = torch.randn(2, 8, 128, 64) # [batch, num_heads, seq_len, head_size] +>>> x_enc = rope(x) # здесь вектор x обогатится позиционной информацией -Преимущества: -- Относительное позиционное кодирование -- Лучшая экстраполяция на длинные последовательности -- Сохранение нормы векторов """ import torch @@ -25,32 +55,72 @@ from typing import Optional class RoPE(nn.Module): """ - Rotary Positional Embeddings (RoPE) для механизма внимания. + Реализация RoPE (Rotary Positional Embeddings) для self-attention в трансформерах. - Кодирует позиционную информацию через вращение векторов запросов и ключей - в многомерном пространстве с использованием синусов и косинусов. + Этот слой добавляет позиционную информацию к векторам внимания (Q, K) — + не с помощью простого сложения с positional embedding, а с помощью математического + вращения (как если бы вы крутили стрелку на круге) для каждой пары компонент + (even/odd) в каждом attention head. - Args: - head_size: Размерность головы внимания (должен быть четным) - max_seq_len: Максимальная длина последовательности - base: Базовое значение для вычисления частот (по умолчанию 10000) + Формула (для каждого токена и каждой пары компонент внутри head): + θ_i = base^(-2i / d) + out_{m,2i} = x_{m,2i} * cos(m * θ_i) - x_{m,2i+1} * sin(m * θ_i) + out_{m,2i+1} = x_{m,2i+1} * cos(m * θ_i) + x_{m,2i} * sin(m * θ_i) + где d — head_size, base обычно 10_000, степень i по head axis. - Attributes: - cos_matrix: Буферизованная матрица косинусов формы [max_seq_len, head_size//2] - sin_matrix: Буферизованная матрица синусов формы [max_seq_len, head_size//2] + Какие входы принимает: + ---------------------- + - x: обязательно размерности [batch, num_heads, seq_len, head_size]! + - head_size (размер внимания) должен быть чётным. + - start_pos: опционально, позволяет сдвигать позиционный offset для генерации с кэшем. + + Что возвращает: + --------------- + - Тот же тензор (x), только со встроенной позиционной информацией (“повёрнутый” RoPE-кодировкой). + - Форма и тип выходного тензора не меняются. + + Где используется: + ----------------- + - В любых современных LLM (Llama, Mistral, GPT-NeoX и др.) для повышения устойчивости и generalization transformer's attention. + + Пример использования: + --------------------- + >>> rope = RoPE(head_size=64, max_seq_len=2048) + >>> x = torch.randn(2, 8, 128, 64) # (batch, num_heads, seq_len, head_size) + >>> x_encoded = rope(x) + + Подробнее про математику и примеры с визуализацией: + --------------------------------------------------- + - RoFormer: https://arxiv.org/abs/2104.09864 + - Llama: https://arxiv.org/abs/2302.13971 + - Демонстрация наглядно: https://kazemnejad.com/blog/transformer_architecture_positional_encoding/ """ def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000): """ - Инициализация RoPE эмбеддингов. + Инициализация объекта RoPE — настраивает и предвычисляет все необходимые + параметры для ротационного позиционного кодирования. - Args: - head_size: Размерность головы внимания (должен быть четным) - max_seq_len: Максимальная поддерживаемая длина последовательности - base: Базовое значение для вычисления частот (типично 10000) + Аргументы: + ---------- + head_size : int + Размер одного attention head (последнего измерения вектора) — сколько компонент + (float-значений) отвечает за одну "голову". Должен быть ЧЁТНЫМ числом, иначе RoPE не применим. + Обычно head_size = embed_dim // num_heads. + max_seq_len : int + Максимальная длина последовательности, которую RoPE сможет обработать. + Если ваш текст длиннее этого числа — будет ошибка! Например, для GPT2 обычно 1024, у LLaMA — до 4096. + Это число определяет размер внутренних буферов cos/sin. + base : int, по умолчанию 10_000 + База для вычисления частот вращения (θ_i) для каждой компоненты. + В оригинальных статьях почти всегда используют base=10000. + Менять этот параметр не нужно, если вы не исследуете математические детали. - Raises: - AssertionError: Если head_size не четный + Что происходит внутри: + ---------------------- + - Проверяется чётность head_size. + - Для каждого возможного положения в пределах max_seq_len и каждой пары component высчитываются уникальные cos/sin значения (матрицы частот). + - Эти матрицы используются далее для быстрого наложения позиционного "вращения" токенов внутри attention. """ super().__init__() assert head_size % 2 == 0, "head_size должен быть четным" @@ -70,28 +140,51 @@ class RoPE(nn.Module): def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor: """ - Применение ротационного позиционного кодирования к входному тензору. + Применяет ротационное позиционное кодирование (RoPE) к входному тензору. - Args: - x: Входной тензор формы [batch_size, seq_len, head_size] + Что делает эта функция: + ----------------------- + Для каждого токена в последовательности внутри каждого attention head + "поворачивает" его вектор в подпространстве (even/odd пар) на свой уникальный угол, + зависящий от позиции токена. Это позволяет attention "понимать расстояния" между токенами. - Returns: - Тензор с примененным RoPE формы [batch_size, seq_len, head_size] + Аргументы: + ---------- + x : torch.Tensor + Входной тензор строго формы [batch, num_heads, seq_len, head_size]. + Это обычно либо Q, либо K из механизма внимания. + start_pos : int, по умолчанию 0 + Сдвиг начала позиции (нужно при генерации с кэшем, почти всегда оставить 0 если не пишете автогенератор). - Алгоритм: - 1. Разделение векторов на четные и нечетные компоненты - 2. Применение вращения через синусы и косинусы - 3. Объединение компонент обратно + Возвращает: + ----------- + torch.Tensor с теми же формой и типом, что и x, но уже с наложенным позиционным кодированием. + + Важно: + ------- + - Если передан тензор не 4D, будет выброшено исключение! + - Не изменяет значения "на месте", всегда возвращает новый тензор. + + Пример: + ------- + >>> rope = RoPE(head_size=64, max_seq_len=1024) + >>> q = torch.randn(2, 8, 32, 64) # batch, num_heads, seq_len, head_size + >>> q_rope = rope(q) """ - batch_size, seq_len, emb_size = x.shape + assert x.ndim == 4, "RoPE поддерживает только 4D-вход [batch, num_heads, seq_len, head_size]" + batch_size, num_heads, seq_len, head_size = x.shape # Берем нужную часть матриц и приводим к типу x cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2] sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2] - - # Разделяем на четные и нечетные компоненты - x_even = x[:, :, 0::2] # [batch_size, seq_len, head_size//2] - x_odd = x[:, :, 1::2] # [batch_size, seq_len, head_size//2] + + # Явное изменение формы для broadcasting + cos = cos.reshape(1, 1, seq_len, head_size // 2) + sin = sin.reshape(1, 1, seq_len, head_size // 2) + + # Разделяем на четные и нечетные компоненты по ПОСЛЕДНЕМУ измерению + x_even = x[..., 0::2] # [batch_size, num_heads, seq_len, head_size//2] + x_odd = x[..., 1::2] # [batch_size, num_heads, seq_len, head_size//2] # Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ) x_rotated_even = x_even * cos - x_odd * sin diff --git a/llm/src/llm/models/gpt/gpt2.py b/llm/src/llm/models/gpt/gpt2.py index 0b8d532..7e98b69 100644 --- a/llm/src/llm/models/gpt/gpt2.py +++ b/llm/src/llm/models/gpt/gpt2.py @@ -104,12 +104,20 @@ class GPT2(BaseModel): # Вычисление start_pos из кэша (если кэш передан) if cache is not None: - # При кэше обрабатываем только один токен (последний) seq_len = 1 - # Вычисляем start_pos из самого нижнего уровня кэша - if cache and cache[0] and cache[0][0]: - key_cache, _ = cache[0][0] # Первый декодер, первая голова - start_pos = key_cache.size(1) # cache_len + # Безопасно извлекаем key_cache для вычисления start_pos + if ( + isinstance(cache, (list, tuple)) + and len(cache) > 0 + and cache[0] is not None + and isinstance(cache[0], (list, tuple)) + and len(cache[0]) > 0 + and cache[0][0] is not None + and isinstance(cache[0][0], (tuple, list)) + and len(cache[0][0]) > 0 + ): + key_cache, _ = cache[0][0] + start_pos = key_cache.size(1) else: start_pos = 0 else: diff --git a/llm/src/llm/models/mistral/mistral.py b/llm/src/llm/models/mistral/mistral.py index 07754fa..607d2d4 100644 --- a/llm/src/llm/models/mistral/mistral.py +++ b/llm/src/llm/models/mistral/mistral.py @@ -4,71 +4,12 @@ from torch import Tensor import torch.nn.functional as F from math import sqrt from llm.core.base_model import BaseModel - - -class SiLU(nn.Module): - def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size] - return torch.sigmoid(x) * x - -class RMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self._eps = eps - self._w = nn.Parameter(torch.ones(dim)) - - def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size] - rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5 - norm_x = x / rms - return self._w * norm_x - -class SwiGLU(nn.Module): - def __init__(self, emb_size: int, dropout: float = 0.1): - super().__init__() - - self._gate = nn.Linear(emb_size, 4 * emb_size) - self._up = nn.Linear(emb_size, 4 * emb_size) - self._down = nn.Linear(4 * emb_size, emb_size) - self._activation = SiLU() - self._dropout = nn.Dropout(dropout) - - def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]. - gate_out = self._gate(x) # [batch, seq, 4*emb] - activation_out = self._activation(gate_out) # [batch, seq, 4*emb] - up_out = self._up(x) # [batch, seq, 4*emb] - out = up_out * activation_out # поэлементное! - out = self._down(out) # [batch, seq, emb] - return self._dropout(out) - - -class TokenEmbeddings(nn.Module): - def __init__(self, vocab_size: int, emb_size: int): - super().__init__() - self._embedding = nn.Embedding( - num_embeddings=vocab_size, - embedding_dim=emb_size - ) - - def forward(self, x: Tensor) -> Tensor: - return self._embedding(x) - - @property - def num_embeddings(self) -> int: - return self._embedding.num_embeddings - - @property - def embedding_dim(self) -> int: - return self._embedding.embedding_dim - - -class GELU(nn.Module): - def __init__(self): - super().__init__() - self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return 0.5 * x * (1 + torch.tanh( - self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3)) - )) +from llm.core.token_embeddings import TokenEmbeddings +from llm.core.silu import SiLU +from llm.core.rms_norm import RMSNorm +from llm.core.swi_glu import SwiGLU +from llm.core.gelu import GELU +from llm.core.rope import RoPE @@ -77,49 +18,49 @@ from torch import nn from typing import Optional -class RoPE(nn.Module): - - def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000): - super().__init__() - assert head_size % 2 == 0, "head_size должен быть четным" - - # Вычисление частот: θ_i = base^(-2i/d) для i ∈ [0, d/2-1] - freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size)) - - # Позиции от 0 до max_seq_len-1 - positions = torch.arange(max_seq_len).float() - - # Внешнее произведение: m * θ_i для всех позиций и частот - freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0) - - # Предвычисление матриц косинусов и синусов - self.register_buffer("cos_matrix", torch.cos(freq_matrix)) - self.register_buffer("sin_matrix", torch.sin(freq_matrix)) - - def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor: # [batch_size × seq_len × head_size] [batch_size × num_heads × seq_len × head_size] - batch_size, num_heads, seq_len, head_size = x.shape - - # Берем нужную часть матриц и приводим к типу x - cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2] - sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2] - - # Явное изменение формы для broadcasting - cos = cos.reshape(1, 1, seq_len, head_size // 2) - sin = sin.reshape(1, 1, seq_len, head_size // 2) - - # Разделяем на четные и нечетные компоненты по ПОСЛЕДНЕМУ измерению - x_even = x[..., 0::2] # [batch_size, num_heads, seq_len, head_size//2] - x_odd = x[..., 1::2] # [batch_size, num_heads, seq_len, head_size//2] - - # Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ) - x_rotated_even = x_even * cos - x_odd * sin - x_rotated_odd = x_even * sin + x_odd * cos - - # Объединяем обратно в исходную размерность - x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1) - x_rotated = x_rotated.flatten(-2) # [batch_size, seq_len, head_size] - - return x_rotated +#class RoPE(nn.Module): +# +# def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000): +# super().__init__() +# assert head_size % 2 == 0, "head_size должен быть четным" +# +# # Вычисление частот: θ_i = base^(-2i/d) для i ∈ [0, d/2-1] +# freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size)) +# +# # Позиции от 0 до max_seq_len-1 +# positions = torch.arange(max_seq_len).float() +# +# # Внешнее произведение: m * θ_i для всех позиций и частот +# freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0) +# +# # Предвычисление матриц косинусов и синусов +# self.register_buffer("cos_matrix", torch.cos(freq_matrix)) +# self.register_buffer("sin_matrix", torch.sin(freq_matrix)) +# +# def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor: # [batch_size × seq_len × head_size] [batch_size × num_heads × seq_len × head_size] +# batch_size, num_heads, seq_len, head_size = x.shape +# +# # Берем нужную часть матриц и приводим к типу x +# cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2] +# sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2] +# +# # Явное изменение формы для broadcasting +# cos = cos.reshape(1, 1, seq_len, head_size // 2) +# sin = sin.reshape(1, 1, seq_len, head_size // 2) +# +# # Разделяем на четные и нечетные компоненты по ПОСЛЕДНЕМУ измерению +# x_even = x[..., 0::2] # [batch_size, num_heads, seq_len, head_size//2] +# x_odd = x[..., 1::2] # [batch_size, num_heads, seq_len, head_size//2] +# +# # Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ) +# x_rotated_even = x_even * cos - x_odd * sin +# x_rotated_odd = x_even * sin + x_odd * cos +# +# # Объединяем обратно в исходную размерность +# x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1) +# x_rotated = x_rotated.flatten(-2) # [batch_size, seq_len, head_size] +# +# return x_rotated import torch diff --git a/llm/tests/core/test_multi_head_attention.py b/llm/tests/core/test_multi_head_attention.py index cd4246d..dd1f49b 100644 --- a/llm/tests/core/test_multi_head_attention.py +++ b/llm/tests/core/test_multi_head_attention.py @@ -19,7 +19,7 @@ class TestMultiHeadAttention: assert attention is not None # Check internal attributes - assert len(attention._heads) == num_heads + assert attention._num_heads == num_heads assert attention._layer.in_features == embed_dim assert attention._layer.out_features == embed_dim @@ -102,8 +102,10 @@ class TestMultiHeadAttention: # Check that gradients are computed for learnable parameters assert attention._layer.weight.grad is not None - if len(attention._heads) > 0: - assert attention._heads[0]._q.weight.grad is not None + # Проверяем, что также у градиентов весов q/k/v есть значения + assert attention._q.weight.grad is not None + assert attention._k.weight.grad is not None + assert attention._v.weight.grad is not None def test_device_consistency(self, embed_dim, num_heads, random_embeddings, device): """Test that MultiHeadAttention works on correct device.""" diff --git a/llm/tests/core/test_rope.py b/llm/tests/core/test_rope.py new file mode 100644 index 0000000..33c02de --- /dev/null +++ b/llm/tests/core/test_rope.py @@ -0,0 +1,55 @@ +import torch +import pytest +from llm.core.rope import RoPE + +def test_rope_shapes_and_dtype(): + rope = RoPE(head_size=8, max_seq_len=32) + x = torch.randn(2, 4, 16, 8) # [batch, num_heads, seq_len, head_size] + y = rope(x) + assert y.shape == x.shape + assert y.dtype == x.dtype + +def test_rope_raises_on_bad_ndim(): + rope = RoPE(head_size=8, max_seq_len=16) + x = torch.randn(2, 16, 8) # [batch, seq_len, head_size] (3D) + with pytest.raises(AssertionError): + _ = rope(x) + +def test_rope_preserves_norm(): + rope = RoPE(head_size=8, max_seq_len=16) + x = torch.randn(2, 3, 7, 8) + x_norm = x.norm(dim=-1) + y = rope(x) + y_norm = y.norm(dim=-1) + # Нормы могут немного отличаться из-за float, сравниваем с допуском + assert torch.allclose(x_norm, y_norm, rtol=1e-5, atol=1e-7) + +def test_rope_backward_pass(): + rope = RoPE(head_size=8, max_seq_len=16) + x = torch.randn(2, 2, 8, 8, requires_grad=True) + out = rope(x) + loss = out.sum() + loss.backward() + assert x.grad is not None + assert x.grad.shape == x.shape + +@pytest.mark.parametrize("batch,num_heads,seq_len,head_size", [ + (1, 1, 4, 8), + (2, 4, 16, 8), + (3, 2, 7, 8), +]) +def test_rope_various_shapes(batch, num_heads, seq_len, head_size): + rope = RoPE(head_size=head_size, max_seq_len=32) + x = torch.randn(batch, num_heads, seq_len, head_size) + y = rope(x) + assert y.shape == x.shape + +def test_rope_start_pos(): + rope = RoPE(head_size=8, max_seq_len=32) + x_full = torch.randn(1, 2, 8, 8) + # Сравниваем участок результата для разных start_pos + out1 = rope(x_full) + out2 = rope(x_full, start_pos=2) + assert not torch.allclose(out1, out2) + # Для одинакового start_pos и x должны совпадать + assert torch.allclose(rope(x_full, start_pos=1), rope(x_full, start_pos=1)) diff --git a/llm/tests/models/test_gpt.py b/llm/tests/models/test_gpt.py index 90040ce..61f90d2 100644 --- a/llm/tests/models/test_gpt.py +++ b/llm/tests/models/test_gpt.py @@ -145,7 +145,11 @@ class TestGPT: assert model._token_embeddings._embedding.weight.grad is not None assert model._linear.weight.grad is not None if len(model._decoders) > 0: - assert model._decoders[0]._heads._heads[0]._q.weight.grad is not None + # Проверяем через новый интерфейс attention оптимизации: + attn = model._decoders[0]._heads + assert attn._q.weight.grad is not None + assert attn._k.weight.grad is not None + assert attn._v.weight.grad is not None def test_device_consistency(self, gpt_config, random_inputs, device): """Test that GPT works on correct device."""