From 3bc2848cf0bbf9af32e4b19b1d99fd83cd59db2f Mon Sep 17 00:00:00 2001
From: Sergey Penkovsky
Date: Mon, 6 Oct 2025 14:57:29 +0300
Subject: [PATCH] refactor: unify CachedDecoder implementation across models

- Removed the duplicate CachedDecoder implementation from llama.py
- Modified the core CachedDecoder to support dependency injection:
  - Added feed_forward_layer parameter (required)
  - Added norm_layer parameter with LayerNorm as the default
  - Added rope parameter for RoPE support
  - Removed unused activation parameter
- Updated GPT2 to use the shared CachedDecoder with FeedForward
- Updated LLaMA to use the shared CachedDecoder with SwiGLU and RMSNorm
- Reordered constructor parameters so that required parameters precede
  those with defaults, as Python requires

This eliminates the duplicated decoder code while preserving each model's
architecture-specific components through dependency injection.
---
 llm/src/llm/core/cached_decoder.py | 11 +++++--
 llm/src/llm/models/gpt/gpt2.py     |  7 +++-
 llm/src/llm/models/llama/llama.py  | 53 +++++-------------------------
 3 files changed, 22 insertions(+), 49 deletions(-)

diff --git a/llm/src/llm/core/cached_decoder.py b/llm/src/llm/core/cached_decoder.py
index ceeff3b..8658683 100644
--- a/llm/src/llm/core/cached_decoder.py
+++ b/llm/src/llm/core/cached_decoder.py
@@ -4,6 +4,7 @@ import torch
 from torch import nn
 from .feed_forward import FeedForward
 from .multi_head_attention import MultiHeadAttention
+from .rope import RoPE
 
 class CachedDecoder(nn.Module):
     """
@@ -12,11 +13,14 @@ class CachedDecoder(nn.Module):
     """
     def __init__(
         self,
+        feed_forward_layer: nn.Module,  # required parameter
         num_heads: int,
         emb_size: int,
         head_size: int,
         max_seq_len: int,
         dropout: float = 0.1,
+        norm_layer: type = nn.LayerNorm,  # class, not an instance
+        rope: RoPE = None,
         activation: str = "gelu",
     ):
         super().__init__()
@@ -25,11 +29,12 @@ class CachedDecoder(nn.Module):
             emb_size=emb_size,
             head_size=head_size,
             max_seq_len=max_seq_len,
+            rope=rope,
             dropout=dropout,
         )
-        self._ff = FeedForward(emb_size=emb_size, dropout=dropout, activation=activation)
-        self._norm1 = nn.LayerNorm(emb_size)
-        self._norm2 = nn.LayerNorm(emb_size)
+        self._ff = feed_forward_layer
+        self._norm1 = norm_layer(emb_size)
+        self._norm2 = norm_layer(emb_size)
 
     def forward(
         self,
diff --git a/llm/src/llm/models/gpt/gpt2.py b/llm/src/llm/models/gpt/gpt2.py
index e5036a7..09726b4 100644
--- a/llm/src/llm/models/gpt/gpt2.py
+++ b/llm/src/llm/models/gpt/gpt2.py
@@ -5,7 +5,7 @@ from llm.core.base_model import BaseModel
 from llm.core.token_embeddings import TokenEmbeddings
 from llm.core.positional_embeddings import PositionalEmbeddings
 from llm.core.cached_decoder import CachedDecoder
-
+from llm.core.feed_forward import FeedForward
 
 class GPT2(BaseModel):
     def __init__(self, config):
@@ -27,6 +27,11 @@
             num_heads=config["num_heads"],
             emb_size=config["embed_dim"],
             head_size=config["embed_dim"] // config["num_heads"],
+            feed_forward_layer=FeedForward(
+                emb_size=config["embed_dim"],
+                dropout=config["dropout"],
+                activation="gelu"
+            ),
             max_seq_len=config["max_position_embeddings"],
             dropout=config["dropout"]
         ) for _ in range(config["num_layers"])])
diff --git a/llm/src/llm/models/llama/llama.py b/llm/src/llm/models/llama/llama.py
index 2b45f68..240d9a3 100644
--- a/llm/src/llm/models/llama/llama.py
+++ b/llm/src/llm/models/llama/llama.py
@@ -1,55 +1,13 @@
 import torch
-from torch import nn
-from torch import Tensor
+from torch import nn, Tensor
 import torch.nn.functional as F
-from math import sqrt
-from torch import nn
-import torch
-import math
 from llm.core.base_model import BaseModel
 from llm.core.token_embeddings import TokenEmbeddings
-from llm.core.rope import RoPE
 from llm.core.swi_glu import SwiGLU
 from llm.core.rms_norm import RMSNorm
-from llm.core.gelu import GELU
-from llm.core.multi_head_attention import MultiHeadAttention
-
-
-class CachedDecoder(nn.Module):
-    def __init__(self,
-                 num_heads: int,
-                 emb_size: int,
-                 head_size: int,
-                 max_seq_len: int,
-                 rope: RoPE,
-                 dropout: float = 0.1
-                 ):
-        super().__init__()
-        self._heads = MultiHeadAttention(
-            num_heads=num_heads,
-            emb_size=emb_size,
-            head_size=head_size,
-            max_seq_len=max_seq_len,
-            rope=rope,
-            dropout=dropout
-        )
-        self._ff = SwiGLU(emb_size=emb_size, dropout=dropout)
-        self._norm1 = RMSNorm(emb_size)
-        self._norm2 = RMSNorm(emb_size)
-
-    def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
-        norm1_out = self._norm1(x)
-        attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
-        out = attention + x
-
-        norm2_out = self._norm2(out)
-        ffn_out = self._ff(norm2_out)
-
-        if use_cache is True:
-            return (ffn_out + out, kv_caches)
-        else:
-            return (ffn_out + out, None)
+from llm.core.rope import RoPE
+from llm.core.cached_decoder import CachedDecoder
 
 
@@ -70,9 +28,14 @@ class Llama(BaseModel):
         self._dropout = nn.Dropout(config["dropout"])
         self._decoders = nn.ModuleList([CachedDecoder(
+            norm_layer=RMSNorm,
             num_heads=config["num_heads"],
             emb_size=config["embed_dim"],
             head_size=config["embed_dim"] // config["num_heads"],
+            feed_forward_layer=SwiGLU(
+                emb_size=config["embed_dim"],
+                dropout=config["dropout"],
+            ),
             max_seq_len=config["max_position_embeddings"],
             rope=self._position_embeddings,
             dropout=config["dropout"],
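
Below is a short usage sketch, not part of the diff, showing the dependency-injection pattern the patch introduces: the same CachedDecoder configured once in a GPT-2 flavour and once in a LLaMA flavour. Module paths and constructor arguments follow the hunks above; the concrete sizes are illustrative only, and because the RoPE constructor is not shown in this patch, the LLaMA-style example leaves rope=None where Llama.__init__ would pass its shared RoPE instance.

    from llm.core.cached_decoder import CachedDecoder
    from llm.core.feed_forward import FeedForward
    from llm.core.swi_glu import SwiGLU
    from llm.core.rms_norm import RMSNorm

    # Illustrative sizes only; real values come from the model config.
    emb_size, num_heads, max_seq_len, dropout = 768, 12, 1024, 0.1

    # GPT-2 flavour: GELU FeedForward plus the default nn.LayerNorm, no RoPE.
    gpt2_block = CachedDecoder(
        feed_forward_layer=FeedForward(emb_size=emb_size, dropout=dropout,
                                       activation="gelu"),
        num_heads=num_heads,
        emb_size=emb_size,
        head_size=emb_size // num_heads,
        max_seq_len=max_seq_len,
        dropout=dropout,
    )

    # LLaMA flavour: SwiGLU feed-forward, RMSNorm, and rotary position embeddings.
    llama_block = CachedDecoder(
        feed_forward_layer=SwiGLU(emb_size=emb_size, dropout=dropout),
        num_heads=num_heads,
        emb_size=emb_size,
        head_size=emb_size // num_heads,
        max_seq_len=max_seq_len,
        dropout=dropout,
        norm_layer=RMSNorm,
        rope=None,  # Llama.__init__ passes its shared RoPE instance here
    )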