refactor: extract LLaMA components to separate modules in core directory

- Moved GELU, RMSNorm, RoPE, SiLU, and SwiGLU implementations from llama.py to dedicated files in core/
- Updated feed_forward.py to use new modular components
- Modified llama.py to import components from core modules instead of local definitions
- Improved code organization and reusability of the activation, normalization, and positional-encoding layers

This refactor enables better code reuse across different model architectures and follows the single responsibility principle.
Author: Sergey Penkovsky
Date: 2025-10-06 14:09:19 +03:00
Commit: 211adf574c
Parent: f30cd530a9

7 changed files with 105 additions and 103 deletions
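
The extracted components are meant to be importable from llm.core by any model, not just llama.py. A minimal reuse sketch (not part of the diff), assuming the package layout shown in the file headers below (llm/src/llm/core/*.py) and purely illustrative tensor sizes:

import torch
from llm.core.gelu import GELU
from llm.core.rms_norm import RMSNorm
from llm.core.rope import RoPE
from llm.core.swi_glu import SwiGLU

emb_size, head_size, max_seq_len = 64, 16, 128   # illustrative sizes
x = torch.randn(2, 10, emb_size)                 # [batch, seq, emb]

norm = RMSNorm(emb_size)                         # normalization layer
ffn = SwiGLU(emb_size, dropout=0.0)              # gated feed-forward block
rope = RoPE(head_size, max_seq_len)              # rotary position encoding

out = ffn(norm(x))                               # [2, 10, 64]
q = rope(torch.randn(2, 10, head_size))          # rotary-encoded per-head queries
print(out.shape, q.shape, GELU()(x).shape)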

llm/src/llm/core/feed_forward.py

@@ -1,16 +1,8 @@
from torch import nn
import torch
import math
from .gelu import GELU
class GELU(nn.Module):
    def __init__(self):
        super().__init__()
        self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 0.5 * x * (1 + torch.tanh(
            self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))
        ))
class FeedForward(nn.Module):
    """

llm/src/llm/core/gelu.py (new file)

@@ -0,0 +1,12 @@
import math
import torch
from torch import nn

class GELU(nn.Module):
    def __init__(self):
        super().__init__()
        self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 0.5 * x * (1 + torch.tanh(
            self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))
        ))
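
The class above is the tanh approximation of GELU, so it should agree closely with PyTorch's built-in approximate GELU. A quick sanity-check sketch (not part of the commit), assuming a PyTorch version that supports the approximate='tanh' argument:

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
ours = GELU()(x)                             # tanh-approximation GELU from above
ref = F.gelu(x, approximate="tanh")          # built-in tanh approximation
print(torch.allclose(ours, ref, atol=1e-6))  # expected: True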

llm/src/llm/core/rms_norm.py (new file)

@@ -0,0 +1,13 @@
import torch
from torch import nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self._eps = eps
        self._w = nn.Parameter(torch.ones(dim))
    def forward(self, x: torch.Tensor):  # [batch_size × seq_len × emb_size]
        rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5
        norm_x = x / rms
        return self._w * norm_x
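
A brief property check for the module above (not part of the diff, illustrative sizes): with the default unit weight, each position is rescaled so that its root mean square over the embedding dimension is approximately 1:

import torch

norm = RMSNorm(dim=32)
x = torch.randn(2, 5, 32) * 3.0
rms_out = norm(x).pow(2).mean(-1).sqrt()                             # per-position RMS of the output
print(torch.allclose(rms_out, torch.ones_like(rms_out), atol=1e-3))  # expected: True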

llm/src/llm/core/rope.py (new file)

@@ -0,0 +1,44 @@
import torch
from torch import nn

class RoPE(nn.Module):
    def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):
        super().__init__()
        assert head_size % 2 == 0, "head_size must be even"
        # Inverse frequencies
        freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size))
        # Positions
        positions = torch.arange(max_seq_len).float()
        # Frequency matrix (outer product); equivalent to torch.outer(positions, freqs)
        freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0)
        # Cosine and sine matrices
        self.register_buffer('cos_matrix', torch.cos(freq_matrix))
        self.register_buffer('sin_matrix', torch.sin(freq_matrix))
    def forward(self, x: torch.Tensor):  # expects a float tensor of shape [batch_size × seq_len × head_size]
        seq_len = x.size(1)
        # Slice the matrices to the current sequence length and cast to x's dtype
        cos = self.cos_matrix[:seq_len].to(x.dtype)  # [seq_len, head_size//2]
        sin = self.sin_matrix[:seq_len].to(x.dtype)  # [seq_len, head_size//2]
        # Split into even- and odd-indexed features
        x_even = x[:, :, 0::2]  # [batch_size, seq_len, head_size//2]
        x_odd = x[:, :, 1::2]   # [batch_size, seq_len, head_size//2]
        # Apply the rotation
        x_rotated_even = x_even * cos - x_odd * sin
        x_rotated_odd = x_even * sin + x_odd * cos
        # Interleave back together
        x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)
        x_rotated = x_rotated.flatten(-2)  # [batch_size, seq_len, head_size]
        return x_rotated
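
Since RoPE rotates each (even, odd) feature pair by a position-dependent angle, it preserves both the shape and the per-position norm of its input. A short check sketch (not part of the commit, illustrative sizes):

import torch

rope = RoPE(head_size=8, max_seq_len=32)
x = torch.randn(2, 10, 8)
y = rope(x)
print(y.shape)                                                    # torch.Size([2, 10, 8])
print(torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-5))  # expected: True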

llm/src/llm/core/silu.py (new file)

@@ -0,0 +1,6 @@
import torch
from torch import nn

class SiLU(nn.Module):
    def forward(self, x: torch.Tensor):  # [batch_size × seq_len × emb_size]
        return torch.sigmoid(x) * x
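
This matches PyTorch's built-in SiLU, so a one-line sketch (not part of the commit) can confirm the module's behavior:

import torch
import torch.nn.functional as F

x = torch.randn(3, 7)
print(torch.allclose(SiLU()(x), F.silu(x), atol=1e-6))  # expected: True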

llm/src/llm/core/swi_glu.py (new file)

@@ -0,0 +1,21 @@
import torch
from torch import nn
from .silu import SiLU

class SwiGLU(nn.Module):
    def __init__(self, emb_size: int, dropout: float = 0.1):
        super().__init__()
        self._gate = nn.Linear(emb_size, 4 * emb_size)
        self._up = nn.Linear(emb_size, 4 * emb_size)
        self._down = nn.Linear(4 * emb_size, emb_size)
        self._activation = SiLU()
        self._dropout = nn.Dropout(dropout)
    def forward(self, x: torch.Tensor):  # [batch_size × seq_len × emb_size]
        gate_out = self._gate(x)                     # [batch, seq, 4*emb]
        activation_out = self._activation(gate_out)  # [batch, seq, 4*emb]
        up_out = self._up(x)                         # [batch, seq, 4*emb]
        out = up_out * activation_out                # element-wise product
        out = self._down(out)                        # [batch, seq, emb]
        return self._dropout(out)
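
A short usage sketch for the block above (not part of the diff; illustrative sizes, dropout disabled for determinism): the input is projected up to 4*emb_size twice, the up-projection is gated by SiLU(gate), and the result is projected back down, so the output shape matches the input:

import torch

ffn = SwiGLU(emb_size=64, dropout=0.0)
x = torch.randn(2, 10, 64)   # [batch, seq, emb]
print(ffn(x).shape)          # torch.Size([2, 10, 64])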

llama.py

@@ -3,88 +3,16 @@ from torch import nn
from torch import Tensor
import torch.nn.functional as F
from math import sqrt
from torch import nn
import torch
import math
from llm.core.base_model import BaseModel
from llm.core.token_embeddings import TokenEmbeddings
class SiLU(nn.Module):
    def forward(self, x: torch.Tensor):  # [batch_size × seq_len × emb_size]
        return torch.sigmoid(x) * x
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self._eps = eps
        self._w = nn.Parameter(torch.ones(dim))
    def forward(self, x: torch.Tensor):  # [batch_size × seq_len × emb_size]
        rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5
        norm_x = x / rms
        return self._w * norm_x
class SwiGLU(nn.Module):
    def __init__(self, emb_size: int, dropout: float = 0.1):
        super().__init__()
        self._gate = nn.Linear(emb_size, 4 * emb_size)
        self._up = nn.Linear(emb_size, 4 * emb_size)
        self._down = nn.Linear(4 * emb_size, emb_size)
        self._activation = SiLU()
        self._dropout = nn.Dropout(dropout)
    def forward(self, x: torch.Tensor):  # [batch_size × seq_len × emb_size]
        gate_out = self._gate(x)                     # [batch, seq, 4*emb]
        activation_out = self._activation(gate_out)  # [batch, seq, 4*emb]
        up_out = self._up(x)                         # [batch, seq, 4*emb]
        out = up_out * activation_out                # element-wise product
        out = self._down(out)                        # [batch, seq, emb]
        return self._dropout(out)
class RoPE(nn.Module):
    def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):
        super().__init__()
        assert head_size % 2 == 0, "head_size must be even"
        # Inverse frequencies
        freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size))
        # Positions
        positions = torch.arange(max_seq_len).float()
        # Frequency matrix (outer product); equivalent to torch.outer(positions, freqs)
        freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0)
        # Cosine and sine matrices
        self.register_buffer('cos_matrix', torch.cos(freq_matrix))
        self.register_buffer('sin_matrix', torch.sin(freq_matrix))
    def forward(self, x: torch.Tensor):  # expects a float tensor of shape [batch_size × seq_len × head_size]
        seq_len = x.size(1)
        # Slice the matrices to the current sequence length and cast to x's dtype
        cos = self.cos_matrix[:seq_len].to(x.dtype)  # [seq_len, head_size//2]
        sin = self.sin_matrix[:seq_len].to(x.dtype)  # [seq_len, head_size//2]
        # Split into even- and odd-indexed features
        x_even = x[:, :, 0::2]  # [batch_size, seq_len, head_size//2]
        x_odd = x[:, :, 1::2]   # [batch_size, seq_len, head_size//2]
        # Apply the rotation
        x_rotated_even = x_even * cos - x_odd * sin
        x_rotated_odd = x_even * sin + x_odd * cos
        # Interleave back together
        x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)
        x_rotated = x_rotated.flatten(-2)  # [batch_size, seq_len, head_size]
        return x_rotated
from llm.core.rope import RoPE
from llm.core.swi_glu import SwiGLU
from llm.core.rms_norm import RMSNorm
from llm.core.gelu import GELU
class HeadAttention(nn.Module):
@@ -134,9 +62,6 @@ class HeadAttention(nn.Module):
else:
return (x_out, None)
from torch import nn
import torch
import math
class MultiHeadAttention(nn.Module):
def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE, dropout: float = 0.1):
@@ -177,17 +102,6 @@ class MultiHeadAttention(nn.Module):
return (final_output, None)
class GELU(nn.Module):
    def __init__(self):
        super().__init__()
        self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 0.5 * x * (1 + torch.tanh(
            self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))
        ))
class Decoder(nn.Module):
def __init__(self,
num_heads: int,