From 211adf574ce9f56d4073c768ec4a98dce7d4ed27 Mon Sep 17 00:00:00 2001
From: Sergey Penkovsky
Date: Mon, 6 Oct 2025 14:09:19 +0300
Subject: [PATCH] refactor: extract LLaMA components to separate modules in core directory

- Moved GELU, RMSNorm, RoPE, SiLU, and SwiGLU implementations from llama.py
  to dedicated files in core/
- Updated feed_forward.py to use the new modular components
- Modified llama.py to import components from core modules instead of local
  definitions
- Improved code organization and reusability of activation functions and
  normalization layers

This refactor enables better code reuse across different model architectures
and follows the single responsibility principle.
---
 llm/src/llm/core/feed_forward.py  |  10 +--
 llm/src/llm/core/gelu.py          |  13 ++++
 llm/src/llm/core/rms_norm.py      |  13 ++++
 llm/src/llm/core/rope.py          |  44 +++++++++++++
 llm/src/llm/core/silu.py          |   6 ++
 llm/src/llm/core/swi_glu.py       |  21 ++++++
 llm/src/llm/models/llama/llama.py | 102 +++--------------------
 7 files changed, 106 insertions(+), 103 deletions(-)
 create mode 100644 llm/src/llm/core/gelu.py
 create mode 100644 llm/src/llm/core/rms_norm.py
 create mode 100644 llm/src/llm/core/rope.py
 create mode 100644 llm/src/llm/core/silu.py
 create mode 100644 llm/src/llm/core/swi_glu.py

diff --git a/llm/src/llm/core/feed_forward.py b/llm/src/llm/core/feed_forward.py
index e1d7576..93fb3d3 100644
--- a/llm/src/llm/core/feed_forward.py
+++ b/llm/src/llm/core/feed_forward.py
@@ -1,16 +1,8 @@
 from torch import nn
 import torch
 import math
+from .gelu import GELU
 
-class GELU(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return 0.5 * x * (1 + torch.tanh(
-            self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))
-        ))
 
 class FeedForward(nn.Module):
     """
diff --git a/llm/src/llm/core/gelu.py b/llm/src/llm/core/gelu.py
new file mode 100644
index 0000000..e311c6e
--- /dev/null
+++ b/llm/src/llm/core/gelu.py
@@ -0,0 +1,13 @@
+import math
+import torch
+from torch import nn
+
+class GELU(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return 0.5 * x * (1 + torch.tanh(
+            self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))
+        ))
\ No newline at end of file
diff --git a/llm/src/llm/core/rms_norm.py b/llm/src/llm/core/rms_norm.py
new file mode 100644
index 0000000..37fafe8
--- /dev/null
+++ b/llm/src/llm/core/rms_norm.py
@@ -0,0 +1,13 @@
+import torch
+from torch import nn
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self._eps = eps
+        self._w = nn.Parameter(torch.ones(dim))
+
+    def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]
+        rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5
+        norm_x = x / rms
+        return self._w * norm_x
\ No newline at end of file
diff --git a/llm/src/llm/core/rope.py b/llm/src/llm/core/rope.py
new file mode 100644
index 0000000..966b981
--- /dev/null
+++ b/llm/src/llm/core/rope.py
@@ -0,0 +1,44 @@
+import torch
+from torch import nn
+
+class RoPE(nn.Module):
+    def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):
+        super().__init__()
+        assert head_size % 2 == 0, "head_size must be even"
+
+        # Inverse frequencies
+        freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size))
+
+        # Positions
+        positions = torch.arange(max_seq_len).float()
+
+        # Frequency matrix (outer product)
+        #freq_matrix = torch.outer(positions, freqs)
+        freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0)
+
+        # Cosine and sine matrices
+        self.register_buffer('cos_matrix', torch.cos(freq_matrix))
+        self.register_buffer('sin_matrix', torch.sin(freq_matrix))
+
+
+    def forward(self, x: torch.Tensor): # expects a float tensor of shape [batch_size × seq_len × head_size]
+        seq_len = x.size(1)
+        # Slice the matrices to the current sequence length and cast to x's dtype
+        cos = self.cos_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]
+        sin = self.sin_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]
+
+
+        # Split into even and odd feature positions
+        x_even = x[:, :, 0::2] # [batch_size, seq_len, head_size//2]
+        x_odd = x[:, :, 1::2] # [batch_size, seq_len, head_size//2]
+
+        # Apply the rotation
+        x_rotated_even = x_even * cos - x_odd * sin
+        x_rotated_odd = x_even * sin + x_odd * cos
+
+
+        # Interleave the rotated halves back together
+        x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)
+        x_rotated = x_rotated.flatten(-2) # [batch_size, seq_len, head_size]
+
+        return x_rotated
\ No newline at end of file
diff --git a/llm/src/llm/core/silu.py b/llm/src/llm/core/silu.py
new file mode 100644
index 0000000..b557280
--- /dev/null
+++ b/llm/src/llm/core/silu.py
@@ -0,0 +1,6 @@
+import torch
+from torch import nn
+
+class SiLU(nn.Module):
+    def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]
+        return torch.sigmoid(x) * x
\ No newline at end of file
diff --git a/llm/src/llm/core/swi_glu.py b/llm/src/llm/core/swi_glu.py
new file mode 100644
index 0000000..0c3ade7
--- /dev/null
+++ b/llm/src/llm/core/swi_glu.py
@@ -0,0 +1,21 @@
+import torch
+from torch import nn
+from .silu import SiLU
+
+class SwiGLU(nn.Module):
+    def __init__(self, emb_size: int, dropout: float = 0.1):
+        super().__init__()
+
+        self._gate = nn.Linear(emb_size, 4 * emb_size)
+        self._up = nn.Linear(emb_size, 4 * emb_size)
+        self._down = nn.Linear(4 * emb_size, emb_size)
+        self._activation = SiLU()
+        self._dropout = nn.Dropout(dropout)
+
+    def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]
+        gate_out = self._gate(x) # [batch, seq, 4*emb]
+        activation_out = self._activation(gate_out) # [batch, seq, 4*emb]
+        up_out = self._up(x) # [batch, seq, 4*emb]
+        out = up_out * activation_out # element-wise!
+        out = self._down(out) # [batch, seq, emb]
+        return self._dropout(out)
\ No newline at end of file
diff --git a/llm/src/llm/models/llama/llama.py b/llm/src/llm/models/llama/llama.py
index 793862a..d220c1a 100644
--- a/llm/src/llm/models/llama/llama.py
+++ b/llm/src/llm/models/llama/llama.py
@@ -3,88 +3,16 @@ from torch import nn
 from torch import Tensor
 import torch.nn.functional as F
 from math import sqrt
+from torch import nn
+import torch
+import math
 from llm.core.base_model import BaseModel
 from llm.core.token_embeddings import TokenEmbeddings
-
-
-class SiLU(nn.Module):
-    def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]
-        return torch.sigmoid(x) * x
-
-class RMSNorm(nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self._eps = eps
-        self._w = nn.Parameter(torch.ones(dim))
-
-    def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]
-        rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5
-        norm_x = x / rms
-        return self._w * norm_x
-
-class SwiGLU(nn.Module):
-    def __init__(self, emb_size: int, dropout: float = 0.1):
-        super().__init__()
-
-        self._gate = nn.Linear(emb_size, 4 * emb_size)
-        self._up = nn.Linear(emb_size, 4 * emb_size)
-        self._down = nn.Linear(4 * emb_size, emb_size)
-        self._activation = SiLU()
-        self._dropout = nn.Dropout(dropout)
-
-    def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size].
-        gate_out = self._gate(x) # [batch, seq, 4*emb]
-        activation_out = self._activation(gate_out) # [batch, seq, 4*emb]
-        up_out = self._up(x) # [batch, seq, 4*emb]
-        out = up_out * activation_out # поэлементное!
-        out = self._down(out) # [batch, seq, emb]
-        return self._dropout(out)
-
-
-
-class RoPE(nn.Module):
-    def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):
-        super().__init__()
-        assert head_size % 2 == 0, "head_size должен быть четным"
-
-        # Обратные частоты
-        freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size))
-
-        # Позиции
-        positions = torch.arange(max_seq_len).float()
-
-        # Матрица частот (внешнее произведение)
-        #freq_matrix = torch.outer(positions, freqs)
-        freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0)
-
-        # Матрицы косинусов и синусов
-        self.register_buffer('cos_matrix', torch.cos(freq_matrix))
-        self.register_buffer('sin_matrix', torch.sin(freq_matrix))
-
-
-    def forward(self, x: torch.Tensor): # Получает на вход тензор x (тип float) размером [batch_size × seq_len × head_size]
-        seq_len = x.size(1)
-        # Берем нужную часть матриц и приводим к типу x
-        cos = self.cos_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]
-        sin = self.sin_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]
-
-
-        # Разделяем на четные и нечетные
-        x_even = x[:, :, 0::2] # [batch_size, seq_len, head_size//2]
-        x_odd = x[:, :, 1::2] # [batch_size, seq_len, head_size//2]
-
-        # Применяем поворот
-        x_rotated_even = x_even * cos - x_odd * sin
-        x_rotated_odd = x_even * sin + x_odd * cos
-
-
-        # Объединяем обратно
-        x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)
-        x_rotated = x_rotated.flatten(-2) # [batch_size, seq_len, head_size]
-
-        return x_rotated
-
+from llm.core.rope import RoPE
+from llm.core.swi_glu import SwiGLU
+from llm.core.rms_norm import RMSNorm
+from llm.core.gelu import GELU
 
 
 
 class HeadAttention(nn.Module):
@@ -133,10 +61,7 @@ class HeadAttention(nn.Module):
             return (x_out, (k, v))
         else:
             return (x_out, None)
-
-from torch import nn
-import torch
-import math
+
 
 class MultiHeadAttention(nn.Module):
     def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE, dropout: float = 0.1):
@@ -176,17 +101,6 @@ class MultiHeadAttention(nn.Module):
         else:
             return (final_output, None)
-
-class GELU(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return 0.5 * x * (1 + torch.tanh(
-            self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))
-        ))
-
 
 
 class Decoder(nn.Module):
     def __init__(self,