mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
refactor: partial removal of duplicate code by using core modules
- Removed duplicate HeadAttention and MultiHeadAttention implementations from llama.py - Now importing MultiHeadAttention from core module - Added RoPE support parameter to core HeadAttention constructor - Kept LLaMA-specific CachedDecoder implementation (uses SwiGLU and RMSNorm) - Core CachedDecoder uses different components (FeedForward and LayerNorm) - Improved code reuse for attention components while maintaining LLaMA-specific decoder This is a partial refactor - attention components are now shared, but decoder remains LLaMA-specific due to different normalization and activation requirements.
This commit is contained in:
@@ -2,6 +2,7 @@ import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from math import sqrt
|
||||
from .rope import RoPE
|
||||
|
||||
class HeadAttention(nn.Module):
|
||||
"""
|
||||
@@ -30,11 +31,12 @@ class HeadAttention(nn.Module):
|
||||
- Автоматически адаптируется к разным версиям PyTorch
|
||||
- Поддерживает batch-обработку входных данных
|
||||
"""
|
||||
def __init__(self, emb_size: int, head_size: int, max_seq_len: int):
|
||||
def __init__(self, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE = None):
|
||||
super().__init__()
|
||||
self._emb_size = emb_size
|
||||
self._head_size = head_size
|
||||
self._max_seq_len = max_seq_len
|
||||
self._rope = rope
|
||||
|
||||
# Линейные преобразования для Q, K, V
|
||||
self._k = nn.Linear(emb_size, head_size)
|
||||
@@ -73,6 +75,11 @@ class HeadAttention(nn.Module):
|
||||
q = self._q(x) # [B, T, hs]
|
||||
v = self._v(x) # [B, T, hs]
|
||||
|
||||
if self._rope is not None:
|
||||
# ✅ Применяем RoPE к Q и K (НЕ к V!)
|
||||
q = self._rope(q) # [B, T, hs]
|
||||
k = self._rope(k) # [B, T, hs]
|
||||
|
||||
if cache is not None:
|
||||
k_cache, v_cache = cache
|
||||
k = torch.cat([k_cache, k], dim=1) # [B, cache_len + T, hs]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
from .head_attention import HeadAttention
|
||||
from .rope import RoPE
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
"""
|
||||
@@ -31,7 +32,7 @@ class MultiHeadAttention(nn.Module):
|
||||
>>> self.attention = MultiHeadAttention(...)
|
||||
>>> x = self.attention(x, mask)
|
||||
"""
|
||||
def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, dropout: float = 0.1):
|
||||
def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE = None, dropout: float = 0.1):
|
||||
"""
|
||||
Инициализация многоголового внимания.
|
||||
|
||||
@@ -52,7 +53,8 @@ class MultiHeadAttention(nn.Module):
|
||||
HeadAttention(
|
||||
emb_size=emb_size,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len
|
||||
max_seq_len=max_seq_len,
|
||||
rope=rope,
|
||||
) for _ in range(num_heads)
|
||||
])
|
||||
self._layer = nn.Linear(head_size * num_heads, emb_size)
|
||||
|
||||
@@ -13,96 +13,10 @@ from llm.core.rope import RoPE
|
||||
from llm.core.swi_glu import SwiGLU
|
||||
from llm.core.rms_norm import RMSNorm
|
||||
from llm.core.gelu import GELU
|
||||
from llm.core.multi_head_attention import MultiHeadAttention
|
||||
|
||||
|
||||
class HeadAttention(nn.Module):
|
||||
|
||||
def __init__(self, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE):
|
||||
super().__init__()
|
||||
self._emb_size = emb_size
|
||||
self._head_size = head_size
|
||||
self._max_seq_len = max_seq_len
|
||||
self._rope = rope
|
||||
|
||||
self._k = nn.Linear(emb_size, head_size)
|
||||
self._q = nn.Linear(emb_size, head_size)
|
||||
self._v = nn.Linear(emb_size, head_size)
|
||||
|
||||
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
|
||||
self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())
|
||||
|
||||
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: tuple = None) -> tuple:
|
||||
seq_len = x.shape[1]
|
||||
if seq_len > self._max_seq_len:
|
||||
raise ValueError(f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}")
|
||||
|
||||
k = self._k(x) # [B, T, hs]
|
||||
q = self._q(x) # [B, T, hs]
|
||||
v = self._v(x) # [B, T, hs]
|
||||
|
||||
# ✅ Применяем RoPE к Q и K (НЕ к V!)
|
||||
q = self._rope(q) # [B, T, hs]
|
||||
k = self._rope(k) # [B, T, hs]
|
||||
|
||||
if cache is not None:
|
||||
k_cache, v_cache = cache
|
||||
k = torch.cat([k_cache, k], dim=1) # [B, cache_len + T, hs]
|
||||
v = torch.cat([v_cache, v], dim=1) # [B, cache_len + T, hs]
|
||||
|
||||
scores = q @ k.transpose(-2, -1) / sqrt(self._head_size)
|
||||
|
||||
if cache is None:
|
||||
scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))
|
||||
|
||||
weights = F.softmax(scores, dim=-1)
|
||||
x_out = weights @ v # [B, T, hs]
|
||||
|
||||
if use_cache is True:
|
||||
return (x_out, (k, v))
|
||||
else:
|
||||
return (x_out, None)
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE, dropout: float = 0.1):
|
||||
|
||||
super().__init__()
|
||||
self._heads = nn.ModuleList([
|
||||
HeadAttention(
|
||||
emb_size=emb_size,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len,
|
||||
rope=rope,
|
||||
) for _ in range(num_heads)
|
||||
])
|
||||
self._layer = nn.Linear(head_size * num_heads, emb_size)
|
||||
self._dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None):
|
||||
|
||||
attention_results = []
|
||||
for i, head in enumerate(self._heads):
|
||||
head_cache = cache[i] if cache is not None else None
|
||||
result = head(x, use_cache=use_cache, cache=head_cache)
|
||||
attention_results.append(result)
|
||||
|
||||
outputs, caches = zip(*attention_results)
|
||||
attention_outputs = list(outputs)
|
||||
kv_caches = list(caches)
|
||||
|
||||
concatenated_attention = torch.cat(attention_outputs, dim=-1)
|
||||
|
||||
projected_output = self._layer(concatenated_attention)
|
||||
|
||||
final_output = self._dropout(projected_output)
|
||||
|
||||
if use_cache is True:
|
||||
return (final_output, kv_caches)
|
||||
else:
|
||||
return (final_output, None)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
class CachedDecoder(nn.Module):
|
||||
def __init__(self,
|
||||
num_heads: int,
|
||||
emb_size: int,
|
||||
@@ -155,7 +69,7 @@ class Llama(BaseModel):
|
||||
)
|
||||
|
||||
self._dropout = nn.Dropout(config["dropout"])
|
||||
self._decoders = nn.ModuleList([Decoder(
|
||||
self._decoders = nn.ModuleList([CachedDecoder(
|
||||
num_heads=config["num_heads"],
|
||||
emb_size=config["embed_dim"],
|
||||
head_size=config["embed_dim"] // config["num_heads"],
|
||||
|
||||
Reference in New Issue
Block a user