refactor: unify CachedDecoder implementation across models

- Completely removed duplicate CachedDecoder from llama.py
- Modified core CachedDecoder to support dependency injection:
  - Added feed_forward_layer parameter (required)
  - Added norm_layer parameter with LayerNorm default
  - Added rope parameter for RoPE support
  - Removed unused activation parameter
- Updated GPT2 to use new CachedDecoder with FeedForward
- Updated LLaMA to use new CachedDecoder with SwiGLU and RMSNorm
- Reordered constructor parameters so the required feed_forward_layer comes before parameters with defaults (Python syntax requirement)

This eliminates the duplicated decoder code while preserving each model's architecture-specific pieces (FeedForward/LayerNorm for GPT2, SwiGLU/RMSNorm/RoPE for LLaMA) through dependency injection.
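
For illustration, a minimal sketch of how the unified block is now wired up for each model. Module names and constructor keywords are taken from the diffs below; the dimensions are arbitrary, and the RoPE construction is elided because its signature is not part of this change:

    from llm.core.cached_decoder import CachedDecoder
    from llm.core.feed_forward import FeedForward
    from llm.core.swi_glu import SwiGLU
    from llm.core.rms_norm import RMSNorm

    emb, heads, seq, drop = 768, 12, 1024, 0.1   # arbitrary example sizes

    # GPT-2 flavour: injected FeedForward, default nn.LayerNorm, no RoPE.
    gpt2_block = CachedDecoder(
        feed_forward_layer=FeedForward(emb_size=emb, dropout=drop, activation="gelu"),
        num_heads=heads,
        emb_size=emb,
        head_size=emb // heads,
        max_seq_len=seq,
        dropout=drop,
    )

    # LLaMA flavour: injected SwiGLU, RMSNorm passed as a class, plus the model's
    # shared RoPE instance (in llama.py: rope=self._position_embeddings).
    llama_block = CachedDecoder(
        feed_forward_layer=SwiGLU(emb_size=emb, dropout=drop),
        norm_layer=RMSNorm,
        # rope=self._position_embeddings,   # RoPE setup not shown in this commit
        num_heads=heads,
        emb_size=emb,
        head_size=emb // heads,
        max_seq_len=seq,
        dropout=drop,
    )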
Author: Sergey Penkovsky
Date: 2025-10-06 14:57:29 +03:00
Parent: d99d605b35
Commit: 3bc2848cf0
3 changed files with 22 additions and 49 deletions

File: llm/core/cached_decoder.py

@@ -4,6 +4,7 @@ import torch
 from torch import nn
 from .feed_forward import FeedForward
 from .multi_head_attention import MultiHeadAttention
+from .rope import RoPE

 class CachedDecoder(nn.Module):
     """
@@ -12,11 +13,14 @@ class CachedDecoder(nn.Module):
     """
     def __init__(
         self,
+        feed_forward_layer: nn.Module,  # required parameter
         num_heads: int,
         emb_size: int,
         head_size: int,
         max_seq_len: int,
         dropout: float = 0.1,
+        norm_layer: type = nn.LayerNorm,  # a class, not an instance
+        rope: RoPE = None,
         activation: str = "gelu",
     ):
         super().__init__()
@@ -25,11 +29,12 @@ class CachedDecoder(nn.Module):
             emb_size=emb_size,
             head_size=head_size,
             max_seq_len=max_seq_len,
+            rope=rope,
             dropout=dropout,
         )
-        self._ff = FeedForward(emb_size=emb_size, dropout=dropout, activation=activation)
-        self._norm1 = nn.LayerNorm(emb_size)
-        self._norm2 = nn.LayerNorm(emb_size)
+        self._ff = feed_forward_layer
+        self._norm1 = norm_layer(emb_size)
+        self._norm2 = norm_layer(emb_size)

     def forward(
         self,
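
The diff above only touches the constructor; for orientation, here is a sketch of the block's forward pass, reconstructed from the duplicate implementation removed from llama.py further down. It assumes the core version keeps the same pre-norm residual structure and (hidden_states, kv_cache) return value:

    def forward(self, x, mask=None, use_cache=True, cache=None):
        # pre-norm self-attention with a residual connection;
        # the attention heads also return the updated KV cache
        attention, kv_caches = self._heads(self._norm1(x), mask,
                                           use_cache=use_cache, cache=cache)
        out = attention + x
        # pre-norm feed-forward: self._ff is whatever module was injected
        # (FeedForward for GPT2, SwiGLU for LLaMA), followed by a residual
        ffn_out = self._ff(self._norm2(out))
        return (ffn_out + out, kv_caches if use_cache else None)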

File: GPT2 model

@@ -5,7 +5,7 @@ from llm.core.base_model import BaseModel
 from llm.core.token_embeddings import TokenEmbeddings
 from llm.core.positional_embeddings import PositionalEmbeddings
 from llm.core.cached_decoder import CachedDecoder
+from llm.core.feed_forward import FeedForward

 class GPT2(BaseModel):
     def __init__(self, config):
@@ -27,6 +27,11 @@ class GPT2(BaseModel):
             num_heads=config["num_heads"],
             emb_size=config["embed_dim"],
             head_size=config["embed_dim"] // config["num_heads"],
+            feed_forward_layer=FeedForward(
+                emb_size=config["embed_dim"],
+                dropout=config["dropout"],
+                activation="gelu"
+            ),
             max_seq_len=config["max_position_embeddings"],
             dropout=config["dropout"]
         ) for _ in range(config["num_layers"])])
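
A hypothetical smoke test for one of these blocks (not part of the commit). It reuses gpt2_block from the sketch above and assumes the usual KV-cache convention of feeding only the newly generated position together with the cache returned by the previous call:

    import torch

    x = torch.randn(2, 16, 768)                        # (batch, seq_len, emb_size)
    hidden, kv_cache = gpt2_block(x, use_cache=True)   # prefill: cache the keys/values
    assert hidden.shape == x.shape

    nxt = torch.randn(2, 1, 768)                       # embedding of the next token
    hidden, kv_cache = gpt2_block(nxt, use_cache=True, cache=kv_cache)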

File: llama.py

@@ -1,55 +1,13 @@
 import torch
-from torch import nn
-from torch import Tensor
+from torch import nn, Tensor
 import torch.nn.functional as F
-from math import sqrt
-from torch import nn
-import torch
-import math
 from llm.core.base_model import BaseModel
 from llm.core.token_embeddings import TokenEmbeddings
-from llm.core.rope import RoPE
 from llm.core.swi_glu import SwiGLU
 from llm.core.rms_norm import RMSNorm
-from llm.core.gelu import GELU
-from llm.core.multi_head_attention import MultiHeadAttention
+from llm.core.rope import RoPE
+from llm.core.cached_decoder import CachedDecoder

-
-class CachedDecoder(nn.Module):
-    def __init__(self,
-                 num_heads: int,
-                 emb_size: int,
-                 head_size: int,
-                 max_seq_len: int,
-                 rope: RoPE,
-                 dropout: float = 0.1
-                 ):
-        super().__init__()
-        self._heads = MultiHeadAttention(
-            num_heads=num_heads,
-            emb_size=emb_size,
-            head_size=head_size,
-            max_seq_len=max_seq_len,
-            rope=rope,
-            dropout=dropout
-        )
-        self._ff = SwiGLU(emb_size=emb_size, dropout=dropout)
-        self._norm1 = RMSNorm(emb_size)
-        self._norm2 = RMSNorm(emb_size)
-
-    def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
-        norm1_out = self._norm1(x)
-        attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
-        out = attention + x
-        norm2_out = self._norm2(out)
-        ffn_out = self._ff(norm2_out)
-        if use_cache is True:
-            return (ffn_out + out, kv_caches)
-        else:
-            return (ffn_out + out, None)
@@ -70,9 +28,14 @@ class Llama(BaseModel):
         self._dropout = nn.Dropout(config["dropout"])
         self._decoders = nn.ModuleList([CachedDecoder(
+            norm_layer=RMSNorm,
             num_heads=config["num_heads"],
             emb_size=config["embed_dim"],
             head_size=config["embed_dim"] // config["num_heads"],
+            feed_forward_layer=SwiGLU(
+                emb_size=config["embed_dim"],
+                dropout=config["dropout"],
+            ),
             max_seq_len=config["max_position_embeddings"],
             rope=self._position_embeddings,
             dropout=config["dropout"],
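
Two design details worth noting (a reading of the diff, not stated in the commit message): feed_forward_layer is injected as an instance, so it must be constructed inside the nn.ModuleList comprehension, giving every decoder layer its own feed-forward weights, whereas norm_layer is injected as a class because each block instantiates it twice (_norm1 and _norm2). The single RoPE instance, by contrast, is shared by all layers, which is harmless as long as it carries no learned parameters, as is standard for rotary embeddings. The payoff of the unified block is that a new architecture only needs a different injection; a hypothetical mix of RMSNorm with the plain GELU FeedForward would be, reusing the names from the sketch above:

    # Hypothetical third configuration -- no new decoder class required:
    block = CachedDecoder(
        feed_forward_layer=FeedForward(emb_size=emb, dropout=drop, activation="gelu"),
        norm_layer=RMSNorm,
        num_heads=heads,
        emb_size=emb,
        head_size=emb // heads,
        max_seq_len=seq,
        dropout=drop,
    )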