Mirror of https://github.com/pese-git/llm-arch-research.git (synced 2026-01-23 21:10:54 +00:00)
refactor: unify CachedDecoder implementation across models
- Completely removed the duplicate CachedDecoder from llama.py
- Modified the core CachedDecoder to support dependency injection:
  - Added a required feed_forward_layer parameter
  - Added a norm_layer parameter with an nn.LayerNorm default
  - Added a rope parameter for RoPE support
  - Removed the unused activation parameter
- Updated GPT2 to use the new CachedDecoder with FeedForward
- Updated LLaMA to use the new CachedDecoder with SwiGLU and RMSNorm
- Fixed the constructor parameter order to follow Python syntax rules

This eliminates the code duplication while preserving each architecture's specifics through dependency injection.
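For orientation, here is a condensed, illustrative sketch of how the two models are expected to build a decoder block after this change. The constructor keywords and the FeedForward/SwiGLU/RMSNorm classes are taken from the hunks below; the hyperparameter values and the rope_module stand-in are placeholders, not the repository's configs.

# Illustrative sketch only: constructor keywords come from the diff below;
# hyperparameter values are placeholders.
from llm.core.cached_decoder import CachedDecoder
from llm.core.feed_forward import FeedForward
from llm.core.swi_glu import SwiGLU
from llm.core.rms_norm import RMSNorm

emb, heads, seq, p = 768, 12, 1024, 0.1  # placeholder hyperparameters

# GPT-2 style block: GELU FeedForward, default nn.LayerNorm, no RoPE
gpt2_block = CachedDecoder(
    feed_forward_layer=FeedForward(emb_size=emb, dropout=p, activation="gelu"),
    num_heads=heads,
    emb_size=emb,
    head_size=emb // heads,
    max_seq_len=seq,
    dropout=p,
)

# LLaMA style block: SwiGLU feed-forward, RMSNorm, and the model's shared RoPE module
rope_module = None  # stand-in; the real model passes its RoPE positional embeddings here
llama_block = CachedDecoder(
    feed_forward_layer=SwiGLU(emb_size=emb, dropout=p),
    num_heads=heads,
    emb_size=emb,
    head_size=emb // heads,
    max_seq_len=seq,
    dropout=p,
    norm_layer=RMSNorm,
    rope=rope_module,
)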
llm/core/cached_decoder.py

@@ -4,6 +4,7 @@ import torch
 from torch import nn
 from .feed_forward import FeedForward
 from .multi_head_attention import MultiHeadAttention
+from .rope import RoPE
 
 class CachedDecoder(nn.Module):
     """
@@ -12,11 +13,14 @@ class CachedDecoder(nn.Module):
     """
     def __init__(
         self,
+        feed_forward_layer: nn.Module,  # Required parameter
         num_heads: int,
         emb_size: int,
         head_size: int,
         max_seq_len: int,
         dropout: float = 0.1,
+        norm_layer: type = nn.LayerNorm,  # Class, not an instance
+        rope: RoPE = None,
         activation: str = "gelu",
     ):
         super().__init__()
@@ -25,11 +29,12 @@ class CachedDecoder(nn.Module):
             emb_size=emb_size,
             head_size=head_size,
             max_seq_len=max_seq_len,
+            rope=rope,
             dropout=dropout,
         )
-        self._ff = FeedForward(emb_size=emb_size, dropout=dropout, activation=activation)
-        self._norm1 = nn.LayerNorm(emb_size)
-        self._norm2 = nn.LayerNorm(emb_size)
+        self._ff = feed_forward_layer
+        self._norm1 = norm_layer(emb_size)
+        self._norm2 = norm_layer(emb_size)
 
     def forward(
         self,
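The forward path of the core CachedDecoder is not touched by these hunks. As a reading aid, the sketch below mirrors the pre-norm residual structure of the duplicate block that llama.py deletes further down; treat it as an approximation of how the injected modules are used, not the repository's exact code.

# Approximate sketch of the block's forward pass (mirrors the removed llama.py copy).
def forward(self, x, mask=None, use_cache=True, cache=None):
    norm1_out = self._norm1(x)            # injected norm: nn.LayerNorm or RMSNorm
    attention, kv_cache = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
    out = attention + x                   # residual around attention
    ffn_out = self._ff(self._norm2(out))  # injected feed-forward: FeedForward or SwiGLU
    return (ffn_out + out, kv_cache if use_cache else None)

The remaining hunks update the GPT2 model to inject FeedForward, and llama.py to drop its private copy of this block in favour of injected SwiGLU and RMSNorm.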
@@ -5,7 +5,7 @@ from llm.core.base_model import BaseModel
 from llm.core.token_embeddings import TokenEmbeddings
 from llm.core.positional_embeddings import PositionalEmbeddings
 from llm.core.cached_decoder import CachedDecoder
+from llm.core.feed_forward import FeedForward
 
 class GPT2(BaseModel):
     def __init__(self, config):
@@ -27,6 +27,11 @@ class GPT2(BaseModel):
             num_heads=config["num_heads"],
             emb_size=config["embed_dim"],
             head_size=config["embed_dim"] // config["num_heads"],
+            feed_forward_layer=FeedForward(
+                emb_size=config["embed_dim"],
+                dropout=config["dropout"],
+                activation="gelu"
+            ),
             max_seq_len=config["max_position_embeddings"],
             dropout=config["dropout"]
         ) for _ in range(config["num_layers"])])
llama.py

@@ -1,55 +1,13 @@
 import torch
-from torch import nn
-from torch import Tensor
+from torch import nn, Tensor
 import torch.nn.functional as F
-from math import sqrt
-from torch import nn
-import torch
-import math
 
 from llm.core.base_model import BaseModel
 from llm.core.token_embeddings import TokenEmbeddings
-from llm.core.rope import RoPE
 from llm.core.swi_glu import SwiGLU
 from llm.core.rms_norm import RMSNorm
-from llm.core.gelu import GELU
-from llm.core.multi_head_attention import MultiHeadAttention
+from llm.core.rope import RoPE
+from llm.core.cached_decoder import CachedDecoder
-
-
-class CachedDecoder(nn.Module):
-    def __init__(self,
-                 num_heads: int,
-                 emb_size: int,
-                 head_size: int,
-                 max_seq_len: int,
-                 rope: RoPE,
-                 dropout: float = 0.1
-                 ):
-        super().__init__()
-        self._heads = MultiHeadAttention(
-            num_heads=num_heads,
-            emb_size=emb_size,
-            head_size=head_size,
-            max_seq_len=max_seq_len,
-            rope=rope,
-            dropout=dropout
-        )
-        self._ff = SwiGLU(emb_size=emb_size, dropout=dropout)
-        self._norm1 = RMSNorm(emb_size)
-        self._norm2 = RMSNorm(emb_size)
-
-    def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
-        norm1_out = self._norm1(x)
-        attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
-        out = attention + x
-
-        norm2_out = self._norm2(out)
-        ffn_out = self._ff(norm2_out)
-
-        if use_cache is True:
-            return (ffn_out + out, kv_caches)
-        else:
-            return (ffn_out + out, None)
 
 
 
@@ -70,9 +28,14 @@ class Llama(BaseModel):
 
         self._dropout = nn.Dropout(config["dropout"])
         self._decoders = nn.ModuleList([CachedDecoder(
+            norm_layer=RMSNorm,
             num_heads=config["num_heads"],
             emb_size=config["embed_dim"],
             head_size=config["embed_dim"] // config["num_heads"],
+            feed_forward_layer=SwiGLU(
+                emb_size=config["embed_dim"],
+                dropout=config["dropout"],
+            ),
             max_seq_len=config["max_position_embeddings"],
             rope=self._position_embeddings,
             dropout=config["dropout"],
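Not part of this commit, but it shows why each block returns a (hidden_state, kv_cache) pair: during incremental decoding, the cache a layer returns at one step is fed back to the same layer at the next step. The driver loop below is a hypothetical sketch under that assumption; decode_step and its argument names do not exist in the repository.

# Hypothetical driver loop for cached, step-by-step decoding (not repository code).
def decode_step(decoders, x, caches=None):
    caches = caches if caches is not None else [None] * len(decoders)
    new_caches = []
    for decoder, cache in zip(decoders, caches):
        x, kv = decoder(x, use_cache=True, cache=cache)  # each block returns (hidden, kv_cache)
        new_caches.append(kv)
    return x, new_caches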