Mirror of https://github.com/pese-git/llm-arch-research.git (synced 2026-01-24 05:21:16 +00:00)
refactor: partial removal of duplicate code by using core modules
- Removed duplicate HeadAttention and MultiHeadAttention implementations from llama.py
- Now importing MultiHeadAttention from core module
- Added RoPE support parameter to core HeadAttention constructor
- Kept LLaMA-specific CachedDecoder implementation (uses SwiGLU and RMSNorm)
- Core CachedDecoder uses different components (FeedForward and LayerNorm)
- Improved code reuse for attention components while maintaining LLaMA-specific decoder

This is a partial refactor - attention components are now shared, but the decoder remains LLaMA-specific due to different normalization and activation requirements.
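As a quick illustration of the shared attention path after this change, here is a minimal usage sketch (not part of the commit). The MultiHeadAttention signature follows the diff below; the hyperparameter values and the RoPE constructor arguments are assumptions:

    import torch
    from llm.core.rope import RoPE
    from llm.core.multi_head_attention import MultiHeadAttention

    emb_size, num_heads, max_seq_len = 512, 8, 1024
    head_size = emb_size // num_heads

    # One RoPE module is shared by all heads; pass rope=None to skip rotary embeddings.
    rope = RoPE(head_size, max_seq_len)  # assumed constructor arguments

    attn = MultiHeadAttention(
        num_heads=num_heads,
        emb_size=emb_size,
        head_size=head_size,
        max_seq_len=max_seq_len,
        rope=rope,
        dropout=0.1,
    )

    x = torch.randn(2, 16, emb_size)          # [batch, seq_len, emb_size]
    out, kv_caches = attn(x, use_cache=True)  # out: [2, 16, emb_size]; kv_caches: one (k, v) pair per head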
@@ -2,6 +2,7 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 from math import sqrt
+from .rope import RoPE

 class HeadAttention(nn.Module):
     """
@@ -30,11 +31,12 @@ class HeadAttention(nn.Module):
     - Automatically adapts to different PyTorch versions
     - Supports batch processing of inputs
     """
-    def __init__(self, emb_size: int, head_size: int, max_seq_len: int):
+    def __init__(self, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE = None):
         super().__init__()
         self._emb_size = emb_size
         self._head_size = head_size
         self._max_seq_len = max_seq_len
+        self._rope = rope

         # Linear projections for Q, K, V
         self._k = nn.Linear(emb_size, head_size)
@@ -73,6 +75,11 @@ class HeadAttention(nn.Module):
         q = self._q(x)  # [B, T, hs]
         v = self._v(x)  # [B, T, hs]

+        if self._rope is not None:
+            # ✅ Apply RoPE to Q and K (NOT to V!)
+            q = self._rope(q)  # [B, T, hs]
+            k = self._rope(k)  # [B, T, hs]
+
         if cache is not None:
             k_cache, v_cache = cache
             k = torch.cat([k_cache, k], dim=1)  # [B, cache_len + T, hs]
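For the core HeadAttention above, existing callers are unaffected because rope defaults to None, and the KV-cache path is unchanged. A small incremental-decoding sketch, assuming the cached forward signature matches the duplicate being removed from llama.py further down:

    import torch
    from llm.core.head_attention import HeadAttention

    head = HeadAttention(emb_size=64, head_size=16, max_seq_len=128)  # rope omitted -> behaves as before

    prompt = torch.randn(1, 4, 64)             # [B, T, emb_size]
    out, cache = head(prompt, use_cache=True)  # cache == (k, v), each [1, 4, 16]

    new_token = torch.randn(1, 1, 64)
    out, cache = head(new_token, use_cache=True, cache=cache)  # keys/values now span 5 positions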
@@ -1,6 +1,7 @@
 from torch import nn
 import torch
 from .head_attention import HeadAttention
+from .rope import RoPE

 class MultiHeadAttention(nn.Module):
     """
@@ -31,7 +32,7 @@ class MultiHeadAttention(nn.Module):
     >>> self.attention = MultiHeadAttention(...)
     >>> x = self.attention(x, mask)
     """
-    def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, dropout: float = 0.1):
+    def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE = None, dropout: float = 0.1):
         """
         Initialize multi-head attention.

@@ -52,7 +53,8 @@ class MultiHeadAttention(nn.Module):
             HeadAttention(
                 emb_size=emb_size,
                 head_size=head_size,
-                max_seq_len=max_seq_len
+                max_seq_len=max_seq_len,
+                rope=rope,
             ) for _ in range(num_heads)
         ])
         self._layer = nn.Linear(head_size * num_heads, emb_size)
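In the core MultiHeadAttention, the new rope parameter sits before dropout, so construction without rotary embeddings is unchanged, but any caller that used to pass dropout positionally now needs the keyword form. A short sketch (RoPE constructor arguments are an assumption):

    from llm.core.rope import RoPE
    from llm.core.multi_head_attention import MultiHeadAttention

    # Unchanged behaviour: rope defaults to None.
    mha_plain = MultiHeadAttention(num_heads=4, emb_size=64, head_size=16, max_seq_len=128, dropout=0.2)

    # With rotary embeddings: one RoPE instance is passed to every head.
    mha_rope = MultiHeadAttention(num_heads=4, emb_size=64, head_size=16, max_seq_len=128,
                                  rope=RoPE(16, 128), dropout=0.2)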
@@ -13,96 +13,10 @@ from llm.core.rope import RoPE
 from llm.core.swi_glu import SwiGLU
 from llm.core.rms_norm import RMSNorm
 from llm.core.gelu import GELU
+from llm.core.multi_head_attention import MultiHeadAttention


-class HeadAttention(nn.Module):
-
-    def __init__(self, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE):
-        super().__init__()
-        self._emb_size = emb_size
-        self._head_size = head_size
-        self._max_seq_len = max_seq_len
-        self._rope = rope
-
-        self._k = nn.Linear(emb_size, head_size)
-        self._q = nn.Linear(emb_size, head_size)
-        self._v = nn.Linear(emb_size, head_size)
-
-        mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
-        self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())
-
-    def forward(self, x: torch.Tensor, use_cache: bool = True, cache: tuple = None) -> tuple:
-        seq_len = x.shape[1]
-        if seq_len > self._max_seq_len:
-            raise ValueError(f"Sequence length {seq_len} exceeds the maximum of {self._max_seq_len}")
-
-        k = self._k(x)  # [B, T, hs]
-        q = self._q(x)  # [B, T, hs]
-        v = self._v(x)  # [B, T, hs]
-
-        # ✅ Apply RoPE to Q and K (NOT to V!)
-        q = self._rope(q)  # [B, T, hs]
-        k = self._rope(k)  # [B, T, hs]
-
-        if cache is not None:
-            k_cache, v_cache = cache
-            k = torch.cat([k_cache, k], dim=1)  # [B, cache_len + T, hs]
-            v = torch.cat([v_cache, v], dim=1)  # [B, cache_len + T, hs]
-
-        scores = q @ k.transpose(-2, -1) / sqrt(self._head_size)
-
-        if cache is None:
-            scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))
-
-        weights = F.softmax(scores, dim=-1)
-        x_out = weights @ v  # [B, T, hs]
-
-        if use_cache is True:
-            return (x_out, (k, v))
-        else:
-            return (x_out, None)
-
-
-class MultiHeadAttention(nn.Module):
-    def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE, dropout: float = 0.1):
-
-        super().__init__()
-        self._heads = nn.ModuleList([
-            HeadAttention(
-                emb_size=emb_size,
-                head_size=head_size,
-                max_seq_len=max_seq_len,
-                rope=rope,
-            ) for _ in range(num_heads)
-        ])
-        self._layer = nn.Linear(head_size * num_heads, emb_size)
-        self._dropout = nn.Dropout(dropout)
-
-    def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None):
-
-        attention_results = []
-        for i, head in enumerate(self._heads):
-            head_cache = cache[i] if cache is not None else None
-            result = head(x, use_cache=use_cache, cache=head_cache)
-            attention_results.append(result)
-
-        outputs, caches = zip(*attention_results)
-        attention_outputs = list(outputs)
-        kv_caches = list(caches)
-
-        concatenated_attention = torch.cat(attention_outputs, dim=-1)
-
-        projected_output = self._layer(concatenated_attention)
-
-        final_output = self._dropout(projected_output)
-
-        if use_cache is True:
-            return (final_output, kv_caches)
-        else:
-            return (final_output, None)
-
-
-class Decoder(nn.Module):
+class CachedDecoder(nn.Module):
     def __init__(self,
                  num_heads: int,
                  emb_size: int,
@@ -155,7 +69,7 @@ class Llama(BaseModel):
         )

         self._dropout = nn.Dropout(config["dropout"])
-        self._decoders = nn.ModuleList([Decoder(
+        self._decoders = nn.ModuleList([CachedDecoder(
            num_heads=config["num_heads"],
            emb_size=config["embed_dim"],
            head_size=config["embed_dim"] // config["num_heads"],
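For reference, a rough sketch of how the retained LLaMA-specific CachedDecoder can now be composed from the shared attention module. The residual wiring is an assumption, and the in-repo SwiGLU and RMSNorm (whose constructors are not shown in this diff) are replaced with PyTorch stand-ins here:

    import torch
    from torch import nn
    from llm.core.multi_head_attention import MultiHeadAttention

    class CachedDecoderSketch(nn.Module):
        """Illustrative only: the real CachedDecoder keeps llm.core.swi_glu.SwiGLU and
        llm.core.rms_norm.RMSNorm, while the core CachedDecoder uses FeedForward and LayerNorm."""

        def __init__(self, num_heads, emb_size, head_size, max_seq_len, rope=None, dropout=0.1):
            super().__init__()
            # Shared attention from core; RoPE is applied to Q/K inside each head.
            self._attn = MultiHeadAttention(num_heads=num_heads, emb_size=emb_size,
                                            head_size=head_size, max_seq_len=max_seq_len,
                                            rope=rope, dropout=dropout)
            self._norm1 = nn.RMSNorm(emb_size)   # stand-in for the repo's RMSNorm (PyTorch >= 2.4)
            self._norm2 = nn.RMSNorm(emb_size)
            self._ff = nn.Sequential(            # stand-in for the repo's SwiGLU feed-forward
                nn.Linear(emb_size, 4 * emb_size),
                nn.SiLU(),
                nn.Linear(4 * emb_size, emb_size),
            )

        def forward(self, x, mask=None, use_cache=True, cache=None):
            attn_out, kv_cache = self._attn(self._norm1(x), mask, use_cache=use_cache, cache=cache)
            x = x + attn_out                     # pre-norm residual connections (assumption)
            x = x + self._ff(self._norm2(x))
            return x, kv_cache

The Llama model in the last hunk then stacks these decoders from its config, deriving head_size as config["embed_dim"] // config["num_heads"].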