refactor: partial removal of duplicate code by using core modules

- Removed duplicate HeadAttention and MultiHeadAttention implementations from llama.py
- Now importing MultiHeadAttention from core module
- Added RoPE support parameter to core HeadAttention constructor
- Kept LLaMA-specific CachedDecoder implementation (uses SwiGLU and RMSNorm)
- Core CachedDecoder uses different components (FeedForward and LayerNorm)
- Improved code reuse for attention components while maintaining LLaMA-specific decoder

This is a partial refactor: attention components are now shared, but the decoder remains LLaMA-specific due to its different normalization and activation requirements.
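For reference, a minimal sketch of how the shared attention could now be wired up in llama.py; the RoPE constructor arguments and the concrete dimensions below are illustrative assumptions, and only the MultiHeadAttention signature is taken from this diff:

    from llm.core.rope import RoPE
    from llm.core.multi_head_attention import MultiHeadAttention

    # Hypothetical wiring: RoPE's constructor arguments are assumed, not taken from this diff.
    rope = RoPE(head_size=64, max_seq_len=256)

    attention = MultiHeadAttention(
        num_heads=8,
        emb_size=512,
        head_size=512 // 8,   # emb_size // num_heads, as in the Llama config
        max_seq_len=256,
        rope=rope,            # pass rope=None to keep plain causal self-attention
        dropout=0.1,
    )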
Sergey Penkovsky
2025-10-06 14:22:44 +03:00
parent 211adf574c
commit d99d605b35
3 changed files with 15 additions and 92 deletions

llm/core/head_attention.py

@@ -2,6 +2,7 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 from math import sqrt
+from .rope import RoPE

 class HeadAttention(nn.Module):
     """
@@ -30,11 +31,12 @@ class HeadAttention(nn.Module):
     - Automatically adapts to different PyTorch versions
     - Supports batch processing of input data
     """
-    def __init__(self, emb_size: int, head_size: int, max_seq_len: int):
+    def __init__(self, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE = None):
         super().__init__()
         self._emb_size = emb_size
         self._head_size = head_size
         self._max_seq_len = max_seq_len
+        self._rope = rope
         # Linear projections for Q, K, V
         self._k = nn.Linear(emb_size, head_size)
@@ -73,6 +75,11 @@ class HeadAttention(nn.Module):
         q = self._q(x) # [B, T, hs]
         v = self._v(x) # [B, T, hs]

+        if self._rope is not None:
+            # ✅ Apply RoPE to Q and K (NOT to V!)
+            q = self._rope(q) # [B, T, hs]
+            k = self._rope(k) # [B, T, hs]
+
         if cache is not None:
             k_cache, v_cache = cache
             k = torch.cat([k_cache, k], dim=1) # [B, cache_len + T, hs]

llm/core/multi_head_attention.py

@@ -1,6 +1,7 @@
 from torch import nn
 import torch
 from .head_attention import HeadAttention
+from .rope import RoPE

 class MultiHeadAttention(nn.Module):
     """
@@ -31,7 +32,7 @@ class MultiHeadAttention(nn.Module):
     >>> self.attention = MultiHeadAttention(...)
     >>> x = self.attention(x, mask)
     """
-    def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, dropout: float = 0.1):
+    def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE = None, dropout: float = 0.1):
         """
         Initialization of multi-head attention.
@@ -52,7 +53,8 @@ class MultiHeadAttention(nn.Module):
             HeadAttention(
                 emb_size=emb_size,
                 head_size=head_size,
-                max_seq_len=max_seq_len
+                max_seq_len=max_seq_len,
+                rope=rope,
             ) for _ in range(num_heads)
         ])
         self._layer = nn.Linear(head_size * num_heads, emb_size)

llama.py

@@ -13,96 +13,10 @@ from llm.core.rope import RoPE
 from llm.core.swi_glu import SwiGLU
 from llm.core.rms_norm import RMSNorm
 from llm.core.gelu import GELU
+from llm.core.multi_head_attention import MultiHeadAttention


-class HeadAttention(nn.Module):
-    def __init__(self, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE):
-        super().__init__()
-        self._emb_size = emb_size
-        self._head_size = head_size
-        self._max_seq_len = max_seq_len
-        self._rope = rope
-        self._k = nn.Linear(emb_size, head_size)
-        self._q = nn.Linear(emb_size, head_size)
-        self._v = nn.Linear(emb_size, head_size)
-        mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
-        self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())
-
-    def forward(self, x: torch.Tensor, use_cache: bool = True, cache: tuple = None) -> tuple:
-        seq_len = x.shape[1]
-        if seq_len > self._max_seq_len:
-            raise ValueError(f"Sequence length {seq_len} exceeds the maximum of {self._max_seq_len}")
-        k = self._k(x) # [B, T, hs]
-        q = self._q(x) # [B, T, hs]
-        v = self._v(x) # [B, T, hs]
-        # ✅ Apply RoPE to Q and K (NOT to V!)
-        q = self._rope(q) # [B, T, hs]
-        k = self._rope(k) # [B, T, hs]
-        if cache is not None:
-            k_cache, v_cache = cache
-            k = torch.cat([k_cache, k], dim=1) # [B, cache_len + T, hs]
-            v = torch.cat([v_cache, v], dim=1) # [B, cache_len + T, hs]
-        scores = q @ k.transpose(-2, -1) / sqrt(self._head_size)
-        if cache is None:
-            scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))
-        weights = F.softmax(scores, dim=-1)
-        x_out = weights @ v # [B, T, hs]
-        if use_cache is True:
-            return (x_out, (k, v))
-        else:
-            return (x_out, None)
-
-
-class MultiHeadAttention(nn.Module):
-    def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE, dropout: float = 0.1):
-        super().__init__()
-        self._heads = nn.ModuleList([
-            HeadAttention(
-                emb_size=emb_size,
-                head_size=head_size,
-                max_seq_len=max_seq_len,
-                rope=rope,
-            ) for _ in range(num_heads)
-        ])
-        self._layer = nn.Linear(head_size * num_heads, emb_size)
-        self._dropout = nn.Dropout(dropout)
-
-    def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None):
-        attention_results = []
-        for i, head in enumerate(self._heads):
-            head_cache = cache[i] if cache is not None else None
-            result = head(x, use_cache=use_cache, cache=head_cache)
-            attention_results.append(result)
-        outputs, caches = zip(*attention_results)
-        attention_outputs = list(outputs)
-        kv_caches = list(caches)
-        concatenated_attention = torch.cat(attention_outputs, dim=-1)
-        projected_output = self._layer(concatenated_attention)
-        final_output = self._dropout(projected_output)
-        if use_cache is True:
-            return (final_output, kv_caches)
-        else:
-            return (final_output, None)
-
-
-class Decoder(nn.Module):
+class CachedDecoder(nn.Module):
     def __init__(self,
                  num_heads: int,
                  emb_size: int,
@@ -155,7 +69,7 @@ class Llama(BaseModel):
         )
         self._dropout = nn.Dropout(config["dropout"])

-        self._decoders = nn.ModuleList([Decoder(
+        self._decoders = nn.ModuleList([CachedDecoder(
             num_heads=config["num_heads"],
             emb_size=config["embed_dim"],
             head_size=config["embed_dim"] // config["num_heads"],