diff --git a/llm/src/llm/core/cached_decoder.py b/llm/src/llm/core/cached_decoder.py
index ceeff3b..8658683 100644
--- a/llm/src/llm/core/cached_decoder.py
+++ b/llm/src/llm/core/cached_decoder.py
@@ -4,6 +4,7 @@ import torch
 from torch import nn
 from .feed_forward import FeedForward
 from .multi_head_attention import MultiHeadAttention
+from .rope import RoPE

 class CachedDecoder(nn.Module):
     """
@@ -12,11 +13,14 @@ class CachedDecoder(nn.Module):
     """
     def __init__(
         self,
+        feed_forward_layer: nn.Module,  # Required parameter
        num_heads: int,
         emb_size: int,
         head_size: int,
         max_seq_len: int,
         dropout: float = 0.1,
+        norm_layer: type = nn.LayerNorm,  # Class (not an instance)
+        rope: RoPE = None,
         activation: str = "gelu",
     ):
         super().__init__()
@@ -25,11 +29,12 @@ class CachedDecoder(nn.Module):
             emb_size=emb_size,
             head_size=head_size,
             max_seq_len=max_seq_len,
+            rope=rope,
             dropout=dropout,
         )
-        self._ff = FeedForward(emb_size=emb_size, dropout=dropout, activation=activation)
-        self._norm1 = nn.LayerNorm(emb_size)
-        self._norm2 = nn.LayerNorm(emb_size)
+        self._ff = feed_forward_layer
+        self._norm1 = norm_layer(emb_size)
+        self._norm2 = norm_layer(emb_size)

     def forward(
         self,
diff --git a/llm/src/llm/models/gpt/gpt2.py b/llm/src/llm/models/gpt/gpt2.py
index e5036a7..09726b4 100644
--- a/llm/src/llm/models/gpt/gpt2.py
+++ b/llm/src/llm/models/gpt/gpt2.py
@@ -5,7 +5,7 @@ from llm.core.base_model import BaseModel
 from llm.core.token_embeddings import TokenEmbeddings
 from llm.core.positional_embeddings import PositionalEmbeddings
 from llm.core.cached_decoder import CachedDecoder
-
+from llm.core.feed_forward import FeedForward

 class GPT2(BaseModel):
     def __init__(self, config):
@@ -27,6 +27,11 @@ class GPT2(BaseModel):
             num_heads=config["num_heads"],
             emb_size=config["embed_dim"],
             head_size=config["embed_dim"] // config["num_heads"],
+            feed_forward_layer=FeedForward(
+                emb_size=config["embed_dim"],
+                dropout=config["dropout"],
+                activation="gelu"
+            ),
             max_seq_len=config["max_position_embeddings"],
             dropout=config["dropout"]
         ) for _ in range(config["num_layers"])])
diff --git a/llm/src/llm/models/llama/llama.py b/llm/src/llm/models/llama/llama.py
index 2b45f68..240d9a3 100644
--- a/llm/src/llm/models/llama/llama.py
+++ b/llm/src/llm/models/llama/llama.py
@@ -1,55 +1,13 @@
 import torch
-from torch import nn
-from torch import Tensor
+from torch import nn, Tensor
 import torch.nn.functional as F
-from math import sqrt
-from torch import nn
-import torch
-import math

 from llm.core.base_model import BaseModel
 from llm.core.token_embeddings import TokenEmbeddings
-from llm.core.rope import RoPE
 from llm.core.swi_glu import SwiGLU
 from llm.core.rms_norm import RMSNorm
-from llm.core.gelu import GELU
-from llm.core.multi_head_attention import MultiHeadAttention
-
-
-class CachedDecoder(nn.Module):
-    def __init__(self,
-                 num_heads: int,
-                 emb_size: int,
-                 head_size: int,
-                 max_seq_len: int,
-                 rope: RoPE,
-                 dropout: float = 0.1
-                 ):
-        super().__init__()
-        self._heads = MultiHeadAttention(
-            num_heads=num_heads,
-            emb_size=emb_size,
-            head_size=head_size,
-            max_seq_len=max_seq_len,
-            rope=rope,
-            dropout=dropout
-        )
-        self._ff = SwiGLU(emb_size=emb_size, dropout=dropout)
-        self._norm1 = RMSNorm(emb_size)
-        self._norm2 = RMSNorm(emb_size)
-
-    def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
-        norm1_out = self._norm1(x)
-        attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
-        out = attention + x
-
-        norm2_out = self._norm2(out)
-        ffn_out = self._ff(norm2_out)
-
-        if use_cache is True:
-            return (ffn_out + out, kv_caches)
-        else:
-            return (ffn_out + out, None)
+from llm.core.rope import RoPE
+from llm.core.cached_decoder import CachedDecoder
@@ -70,9 +28,14 @@ class Llama(BaseModel):
         self._dropout = nn.Dropout(config["dropout"])

         self._decoders = nn.ModuleList([CachedDecoder(
+            norm_layer=RMSNorm,
             num_heads=config["num_heads"],
             emb_size=config["embed_dim"],
             head_size=config["embed_dim"] // config["num_heads"],
+            feed_forward_layer=SwiGLU(
+                emb_size=config["embed_dim"],
+                dropout=config["dropout"],
+            ),
             max_seq_len=config["max_position_embeddings"],
             rope=self._position_embeddings,
             dropout=config["dropout"],
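
For context, a minimal sketch of how the unified CachedDecoder can now be wired up for both model families. Keyword names follow the diff above; the FeedForward, SwiGLU, and RoPE constructor signatures are assumed from their call sites, so treat this as illustrative rather than exact.

    from llm.core.cached_decoder import CachedDecoder
    from llm.core.feed_forward import FeedForward
    from llm.core.swi_glu import SwiGLU
    from llm.core.rms_norm import RMSNorm
    from llm.core.rope import RoPE

    embed_dim, num_heads, max_len, dropout = 768, 12, 1024, 0.1

    # GPT-2-style block: default nn.LayerNorm, GELU feed-forward, no RoPE.
    gpt_block = CachedDecoder(
        feed_forward_layer=FeedForward(emb_size=embed_dim, dropout=dropout, activation="gelu"),
        num_heads=num_heads,
        emb_size=embed_dim,
        head_size=embed_dim // num_heads,
        max_seq_len=max_len,
        dropout=dropout,
    )

    # Llama-style block: RMSNorm + SwiGLU + rotary position embeddings.
    # The RoPE constructor arguments below are assumed; llama.py passes its
    # existing self._position_embeddings instance here instead of building a new one.
    llama_block = CachedDecoder(
        feed_forward_layer=SwiGLU(emb_size=embed_dim, dropout=dropout),
        num_heads=num_heads,
        emb_size=embed_dim,
        head_size=embed_dim // num_heads,
        max_seq_len=max_len,
        dropout=dropout,
        norm_layer=RMSNorm,
        rope=RoPE(head_size=embed_dim // num_heads, max_seq_len=max_len),
    )

The design choice here is dependency injection: the decoder block no longer hard-codes its feed-forward and normalization layers, so the duplicated CachedDecoder in llama.py can be deleted and both GPT2 and Llama share one implementation.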