From d99d605b35b4570556adc5f9316c9aca83ec830a Mon Sep 17 00:00:00 2001
From: Sergey Penkovsky
Date: Mon, 6 Oct 2025 14:22:44 +0300
Subject: [PATCH] refactor: partial removal of duplicate code by using core
 modules

- Removed duplicate HeadAttention and MultiHeadAttention implementations
  from llama.py
- Now importing MultiHeadAttention from the core module
- Added an optional RoPE parameter to the core HeadAttention constructor
- Kept the LLaMA-specific CachedDecoder implementation (uses SwiGLU and RMSNorm)
- The core CachedDecoder uses different components (FeedForward and LayerNorm)
- Improved code reuse for attention components while keeping the
  LLaMA-specific decoder

This is a partial refactor: attention components are now shared, but the
decoder remains LLaMA-specific due to different normalization and activation
requirements.
---
 llm/src/llm/core/head_attention.py       |  9 ++-
 llm/src/llm/core/multi_head_attention.py |  6 +-
 llm/src/llm/models/llama/llama.py        | 92 +-----------------------
 3 files changed, 15 insertions(+), 92 deletions(-)

diff --git a/llm/src/llm/core/head_attention.py b/llm/src/llm/core/head_attention.py
index 0cc8ccf..ff9909c 100644
--- a/llm/src/llm/core/head_attention.py
+++ b/llm/src/llm/core/head_attention.py
@@ -2,6 +2,7 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 from math import sqrt
+from .rope import RoPE
 
 class HeadAttention(nn.Module):
     """
@@ -30,11 +31,12 @@ class HeadAttention(nn.Module):
     - Автоматически адаптируется к разным версиям PyTorch
     - Поддерживает batch-обработку входных данных
     """
-    def __init__(self, emb_size: int, head_size: int, max_seq_len: int):
+    def __init__(self, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE = None):
         super().__init__()
         self._emb_size = emb_size
         self._head_size = head_size
         self._max_seq_len = max_seq_len
+        self._rope = rope
 
         # Линейные преобразования для Q, K, V
         self._k = nn.Linear(emb_size, head_size)
@@ -73,6 +75,11 @@ class HeadAttention(nn.Module):
         q = self._q(x)  # [B, T, hs]
         v = self._v(x)  # [B, T, hs]
 
+        if self._rope is not None:
+            # ✅ Применяем RoPE к Q и K (НЕ к V!)
+            q = self._rope(q)  # [B, T, hs]
+            k = self._rope(k)  # [B, T, hs]
+
         if cache is not None:
             k_cache, v_cache = cache
             k = torch.cat([k_cache, k], dim=1)  # [B, cache_len + T, hs]
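Note on the hunks above: they only call self._rope(q) and self._rope(k); the RoPE module
itself lives in llm/src/llm/core/rope.py and is not part of this patch. As a reading aid,
the sketch below is a minimal rotary embedding with the same call contract (a
[B, T, head_size] tensor in, the same shape out). Its constructor and internals are
assumptions for illustration, not the repo's implementation. Rotating Q and K but not V
is the standard choice: the rotation only has to enter the query-key dot product, where
it turns absolute positions into relative offsets. One caveat worth verifying separately:
with a KV cache, a new token's rotary angle should correspond to its absolute position
(cache length plus offset), and whether the core RoPE accounts for that is not visible
in this hunk.

    # Illustrative only: a minimal rotary position embedding with the call
    # contract the patch assumes (x of shape [B, T, head_size] -> same shape).
    # The constructor arguments are assumptions; the core RoPE may differ.
    import torch
    from torch import nn

    class MinimalRoPE(nn.Module):
        def __init__(self, head_size: int, max_seq_len: int, base: float = 10000.0):
            super().__init__()
            assert head_size % 2 == 0, "rotary embeddings rotate dimension pairs"
            # One rotation frequency per (even, odd) dimension pair.
            inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2).float() / head_size))
            # Angle for every (position, frequency) pair: [max_seq_len, head_size // 2].
            angles = torch.arange(max_seq_len).float()[:, None] * inv_freq[None, :]
            self.register_buffer("_cos", angles.cos())
            self.register_buffer("_sin", angles.sin())

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Rotate each (even, odd) pair of x by its position-dependent angle.
            seq_len = x.shape[1]
            cos = self._cos[:seq_len]  # [T, head_size // 2]
            sin = self._sin[:seq_len]
            x_even, x_odd = x[..., 0::2], x[..., 1::2]
            out = torch.empty_like(x)
            out[..., 0::2] = x_even * cos - x_odd * sin
            out[..., 1::2] = x_even * sin + x_odd * cos
            return out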
diff --git a/llm/src/llm/core/multi_head_attention.py b/llm/src/llm/core/multi_head_attention.py
index 3bf5fc8..b865714 100644
--- a/llm/src/llm/core/multi_head_attention.py
+++ b/llm/src/llm/core/multi_head_attention.py
@@ -1,6 +1,7 @@
 from torch import nn
 import torch
 from .head_attention import HeadAttention
+from .rope import RoPE
 
 class MultiHeadAttention(nn.Module):
     """
@@ -31,7 +32,7 @@ class MultiHeadAttention(nn.Module):
     >>> self.attention = MultiHeadAttention(...)
     >>> x = self.attention(x, mask)
     """
-    def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, dropout: float = 0.1):
+    def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE = None, dropout: float = 0.1):
         """
         Инициализация многоголового внимания.
 
@@ -52,7 +53,8 @@ class MultiHeadAttention(nn.Module):
             HeadAttention(
                 emb_size=emb_size,
                 head_size=head_size,
-                max_seq_len=max_seq_len
+                max_seq_len=max_seq_len,
+                rope=rope,
             ) for _ in range(num_heads)
         ])
         self._layer = nn.Linear(head_size * num_heads, emb_size)
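Because rope defaults to None in both constructors, existing users of the core
MultiHeadAttention keep their current behaviour; only the LLaMA model needs to pass a
RoPE instance, and a single instance is shared by all heads. A usage sketch under
assumptions: the RoPE constructor call below is a guess (check llm/src/llm/core/rope.py),
and the (output, per-head KV cache) return value follows the contract shown by the copy
removed from llama.py.

    # Usage sketch; RoPE's constructor signature here is an assumption.
    import torch
    from llm.core.rope import RoPE
    from llm.core.multi_head_attention import MultiHeadAttention

    emb_size, num_heads, max_seq_len = 256, 8, 128
    head_size = emb_size // num_heads

    rope = RoPE(head_size, max_seq_len)          # assumed signature
    mha = MultiHeadAttention(
        num_heads=num_heads,
        emb_size=emb_size,
        head_size=head_size,
        max_seq_len=max_seq_len,
        rope=rope,                               # one RoPE instance shared by all heads
    )

    x = torch.randn(2, 16, emb_size)             # [B, T, emb_size]
    out, kv_caches = mha(x, use_cache=True)      # [B, T, emb_size], one (k, v) per head

    # Legacy call sites keep working unchanged: rope defaults to None,
    # so queries and keys are simply not rotated.
    plain_mha = MultiHeadAttention(num_heads, emb_size, head_size, max_seq_len)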
diff --git a/llm/src/llm/models/llama/llama.py b/llm/src/llm/models/llama/llama.py
index d220c1a..2b45f68 100644
--- a/llm/src/llm/models/llama/llama.py
+++ b/llm/src/llm/models/llama/llama.py
@@ -13,96 +13,10 @@ from llm.core.rope import RoPE
 from llm.core.swi_glu import SwiGLU
 from llm.core.rms_norm import RMSNorm
 from llm.core.gelu import GELU
+from llm.core.multi_head_attention import MultiHeadAttention
 
-class HeadAttention(nn.Module):
-
-    def __init__(self, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE):
-        super().__init__()
-        self._emb_size = emb_size
-        self._head_size = head_size
-        self._max_seq_len = max_seq_len
-        self._rope = rope
-
-        self._k = nn.Linear(emb_size, head_size)
-        self._q = nn.Linear(emb_size, head_size)
-        self._v = nn.Linear(emb_size, head_size)
-
-        mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
-        self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())
-
-    def forward(self, x: torch.Tensor, use_cache: bool = True, cache: tuple = None) -> tuple:
-        seq_len = x.shape[1]
-        if seq_len > self._max_seq_len:
-            raise ValueError(f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}")
-
-        k = self._k(x)  # [B, T, hs]
-        q = self._q(x)  # [B, T, hs]
-        v = self._v(x)  # [B, T, hs]
-
-        # ✅ Применяем RoPE к Q и K (НЕ к V!)
-        q = self._rope(q)  # [B, T, hs]
-        k = self._rope(k)  # [B, T, hs]
-
-        if cache is not None:
-            k_cache, v_cache = cache
-            k = torch.cat([k_cache, k], dim=1)  # [B, cache_len + T, hs]
-            v = torch.cat([v_cache, v], dim=1)  # [B, cache_len + T, hs]
-
-        scores = q @ k.transpose(-2, -1) / sqrt(self._head_size)
-
-        if cache is None:
-            scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))
-
-        weights = F.softmax(scores, dim=-1)
-        x_out = weights @ v  # [B, T, hs]
-
-        if use_cache is True:
-            return (x_out, (k, v))
-        else:
-            return (x_out, None)
-
-
-class MultiHeadAttention(nn.Module):
-    def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE, dropout: float = 0.1):
-
-        super().__init__()
-        self._heads = nn.ModuleList([
-            HeadAttention(
-                emb_size=emb_size,
-                head_size=head_size,
-                max_seq_len=max_seq_len,
-                rope=rope,
-            ) for _ in range(num_heads)
-        ])
-        self._layer = nn.Linear(head_size * num_heads, emb_size)
-        self._dropout = nn.Dropout(dropout)
-
-    def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None):
-
-        attention_results = []
-        for i, head in enumerate(self._heads):
-            head_cache = cache[i] if cache is not None else None
-            result = head(x, use_cache=use_cache, cache=head_cache)
-            attention_results.append(result)
-
-        outputs, caches = zip(*attention_results)
-        attention_outputs = list(outputs)
-        kv_caches = list(caches)
-
-        concatenated_attention = torch.cat(attention_outputs, dim=-1)
-
-        projected_output = self._layer(concatenated_attention)
-
-        final_output = self._dropout(projected_output)
-
-        if use_cache is True:
-            return (final_output, kv_caches)
-        else:
-            return (final_output, None)
-
-
-class Decoder(nn.Module):
+class CachedDecoder(nn.Module):
     def __init__(self,
                  num_heads: int,
                  emb_size: int,
                  head_size: int,
@@ -155,7 +69,7 @@ class Llama(BaseModel):
         )
         self._dropout = nn.Dropout(config["dropout"])
 
-        self._decoders = nn.ModuleList([Decoder(
+        self._decoders = nn.ModuleList([CachedDecoder(
            num_heads=config["num_heads"],
            emb_size=config["embed_dim"],
            head_size=config["embed_dim"] // config["num_heads"],
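The llama.py hunks show only the rename of Decoder to CachedDecoder and the switch to
the shared MultiHeadAttention import; the decoder body stays in llama.py because, as the
commit message notes, it is built around RMSNorm and SwiGLU rather than the core
decoder's LayerNorm and FeedForward. For orientation, a pre-norm LLaMA-style block over
the shared attention module typically looks like the sketch below. It is illustrative
only; the RMSNorm and SwiGLU constructor signatures are assumptions, and the real
CachedDecoder may differ in its details.

    # Rough shape of a LLaMA-style decoder block on top of the shared attention.
    # Illustrative only; RMSNorm/SwiGLU signatures are assumed, and the actual
    # CachedDecoder in llama.py may be structured differently.
    import torch
    from torch import nn
    from llm.core.multi_head_attention import MultiHeadAttention
    from llm.core.rms_norm import RMSNorm
    from llm.core.swi_glu import SwiGLU
    from llm.core.rope import RoPE

    class LlamaStyleBlock(nn.Module):
        def __init__(self, num_heads: int, emb_size: int, head_size: int,
                     max_seq_len: int, rope: RoPE, dropout: float = 0.1):
            super().__init__()
            self._attn_norm = RMSNorm(emb_size)   # pre-norm before attention (assumed signature)
            self._attn = MultiHeadAttention(
                num_heads=num_heads,
                emb_size=emb_size,
                head_size=head_size,
                max_seq_len=max_seq_len,
                rope=rope,
                dropout=dropout,
            )
            self._ffn_norm = RMSNorm(emb_size)    # pre-norm before the gated MLP
            self._ffn = SwiGLU(emb_size)          # assumed signature

        def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None):
            # Residual attention sub-block; the per-head KV cache is threaded through.
            attn_out, kv_cache = self._attn(self._attn_norm(x), use_cache=use_cache, cache=cache)
            x = x + attn_out
            # Residual feed-forward sub-block.
            x = x + self._ffn(self._ffn_norm(x))
            return (x, kv_cache) if use_cache else (x, None)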