Mirror of https://github.com/pese-git/llm-arch-research.git, synced 2026-01-23 21:10:54 +00:00
refactor: extract LLaMA components to separate modules in core directory
- Moved GELU, RMSNorm, RoPE, SiLU, and SwiGLU implementations from llama.py to dedicated files in core/
- Updated feed_forward.py to use the new modular components
- Modified llama.py to import components from core modules instead of local definitions
- Improved code organization and reusability of activation functions and normalization layers

This refactor enables better code reuse across different model architectures and follows the single responsibility principle.
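As a rough illustration of the reuse this enables, here is a minimal sketch that composes the extracted modules through their new llm.core import paths. The block structure, dimensions, and residual wiring below are hypothetical (in the real model RoPE is applied inside the attention heads); only the imports and class signatures come from this commit.

import torch
from llm.core.rope import RoPE
from llm.core.swi_glu import SwiGLU
from llm.core.rms_norm import RMSNorm

# Hypothetical pre-norm block built only from the extracted modules.
emb_size, head_size, max_seq_len = 64, 64, 128
norm = RMSNorm(dim=emb_size)
rope = RoPE(head_size=head_size, max_seq_len=max_seq_len)
ffn = SwiGLU(emb_size=emb_size, dropout=0.0)

x = torch.randn(2, 16, emb_size)   # [batch, seq, emb]
h = rope(norm(x))                  # rotate the normalized features
y = x + ffn(norm(h))               # residual over the SwiGLU feed-forward
print(y.shape)                     # torch.Size([2, 16, 64])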
feed_forward.py
@@ -1,16 +1,8 @@
 from torch import nn
 import torch
 import math
+from .gelu import GELU


-class GELU(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return 0.5 * x * (1 + torch.tanh(
-            self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))
-        ))
-
-
 class FeedForward(nn.Module):
     """
llm/src/llm/core/gelu.py (new file)
import math  # required for math.pi below

import torch
from torch import nn


class GELU(nn.Module):
    # Tanh approximation of the Gaussian Error Linear Unit.
    def __init__(self):
        super().__init__()
        self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 0.5 * x * (1 + torch.tanh(
            self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))
        ))
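As a quick sanity check on the module above (my addition, not part of the commit), the tanh formula should agree with PyTorch's built-in approximate GELU; this assumes PyTorch 1.12+, where torch.nn.functional.gelu accepts approximate="tanh".

import torch
import torch.nn.functional as F
from llm.core.gelu import GELU

x = torch.randn(4, 8)
ours = GELU()(x)
reference = F.gelu(x, approximate="tanh")  # same tanh-based approximation
print(torch.allclose(ours, reference, atol=1e-5))  # expected: True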
llm/src/llm/core/rms_norm.py (new file)
import torch
from torch import nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self._eps = eps
        self._w = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor):  # [batch_size × seq_len × emb_size]
        rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5
        norm_x = x / rms
        return self._w * norm_x
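A short usage sketch (illustrative, not from the repository): with the default weight of ones, RMSNorm rescales each embedding vector to roughly unit root-mean-square, and _w then provides a learnable per-feature scale.

import torch
from llm.core.rms_norm import RMSNorm

norm = RMSNorm(dim=16)
x = torch.randn(2, 5, 16)          # [batch_size, seq_len, emb_size]
y = norm(x)

# With the default weight of ones, each vector has approximately unit RMS.
rms = y.pow(2).mean(-1).sqrt()
print(y.shape, rms.mean().item())  # torch.Size([2, 5, 16]) ~1.0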
llm/src/llm/core/rope.py (new file)
import torch
from torch import nn


class RoPE(nn.Module):
    def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):
        super().__init__()
        assert head_size % 2 == 0, "head_size must be even"

        # Inverse frequencies
        freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size))

        # Positions
        positions = torch.arange(max_seq_len).float()

        # Frequency matrix (outer product)
        # freq_matrix = torch.outer(positions, freqs)
        freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0)

        # Cosine and sine matrices
        self.register_buffer('cos_matrix', torch.cos(freq_matrix))
        self.register_buffer('sin_matrix', torch.sin(freq_matrix))

    def forward(self, x: torch.Tensor):  # expects a float tensor of shape [batch_size × seq_len × head_size]
        seq_len = x.size(1)
        # Take the needed slice of the matrices and cast to x's dtype
        cos = self.cos_matrix[:seq_len].to(x.dtype)  # [seq_len, head_size//2]
        sin = self.sin_matrix[:seq_len].to(x.dtype)  # [seq_len, head_size//2]

        # Split into even and odd feature indices
        x_even = x[:, :, 0::2]  # [batch_size, seq_len, head_size//2]
        x_odd = x[:, :, 1::2]   # [batch_size, seq_len, head_size//2]

        # Apply the rotation
        x_rotated_even = x_even * cos - x_odd * sin
        x_rotated_odd = x_even * sin + x_odd * cos

        # Interleave back together
        x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)
        x_rotated = x_rotated.flatten(-2)  # [batch_size, seq_len, head_size]

        return x_rotated
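An illustrative check, not part of the commit: RoPE rotates each (even, odd) pair of features by a position-dependent angle, so it changes directions but preserves the norm of every head vector.

import torch
from llm.core.rope import RoPE

rope = RoPE(head_size=8, max_seq_len=32)
x = torch.randn(2, 10, 8)        # [batch_size, seq_len, head_size]
y = rope(x)

print(y.shape)                                                     # torch.Size([2, 10, 8])
print(torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-5))   # True: rotation is norm-preserving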
llm/src/llm/core/silu.py (new file)
import torch
from torch import nn


class SiLU(nn.Module):
    def forward(self, x: torch.Tensor):  # [batch_size × seq_len × emb_size]
        return torch.sigmoid(x) * x
llm/src/llm/core/swi_glu.py (new file)
import torch
from torch import nn

from .silu import SiLU


class SwiGLU(nn.Module):
    def __init__(self, emb_size: int, dropout: float = 0.1):
        super().__init__()

        self._gate = nn.Linear(emb_size, 4 * emb_size)
        self._up = nn.Linear(emb_size, 4 * emb_size)
        self._down = nn.Linear(4 * emb_size, emb_size)
        self._activation = SiLU()
        self._dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):  # [batch_size × seq_len × emb_size]
        gate_out = self._gate(x)                     # [batch, seq, 4*emb]
        activation_out = self._activation(gate_out)  # [batch, seq, 4*emb]
        up_out = self._up(x)                         # [batch, seq, 4*emb]
        out = up_out * activation_out                # element-wise!
        out = self._down(out)                        # [batch, seq, emb]
        return self._dropout(out)
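A brief usage sketch (illustrative only): SwiGLU expands to 4 * emb_size for the gate and up projections, multiplies them element-wise, and projects back to emb_size, so the output shape matches the input; eval() disables dropout for a deterministic pass.

import torch
from llm.core.swi_glu import SwiGLU

ffn = SwiGLU(emb_size=32, dropout=0.1)
ffn.eval()                 # turn off dropout for a deterministic forward pass

x = torch.randn(2, 7, 32)  # [batch_size, seq_len, emb_size]
y = ffn(x)
print(y.shape)             # torch.Size([2, 7, 32])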
llama.py
@@ -3,88 +3,16 @@ from torch import nn
 from torch import Tensor
 import torch.nn.functional as F
 from math import sqrt
 from torch import nn
 import torch
 import math

 from llm.core.base_model import BaseModel
 from llm.core.token_embeddings import TokenEmbeddings


-class SiLU(nn.Module):
-    def forward(self, x: torch.Tensor):  # [batch_size × seq_len × emb_size]
-        return torch.sigmoid(x) * x
-
-
-class RMSNorm(nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self._eps = eps
-        self._w = nn.Parameter(torch.ones(dim))
-
-    def forward(self, x: torch.Tensor):  # [batch_size × seq_len × emb_size]
-        rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5
-        norm_x = x / rms
-        return self._w * norm_x
-
-
-class SwiGLU(nn.Module):
-    def __init__(self, emb_size: int, dropout: float = 0.1):
-        super().__init__()
-
-        self._gate = nn.Linear(emb_size, 4 * emb_size)
-        self._up = nn.Linear(emb_size, 4 * emb_size)
-        self._down = nn.Linear(4 * emb_size, emb_size)
-        self._activation = SiLU()
-        self._dropout = nn.Dropout(dropout)
-
-    def forward(self, x: torch.Tensor):  # [batch_size × seq_len × emb_size]
-        gate_out = self._gate(x)                     # [batch, seq, 4*emb]
-        activation_out = self._activation(gate_out)  # [batch, seq, 4*emb]
-        up_out = self._up(x)                         # [batch, seq, 4*emb]
-        out = up_out * activation_out                # element-wise!
-        out = self._down(out)                        # [batch, seq, emb]
-        return self._dropout(out)
-
-
-class RoPE(nn.Module):
-    def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):
-        super().__init__()
-        assert head_size % 2 == 0, "head_size must be even"
-
-        # Inverse frequencies
-        freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size))
-
-        # Positions
-        positions = torch.arange(max_seq_len).float()
-
-        # Frequency matrix (outer product)
-        # freq_matrix = torch.outer(positions, freqs)
-        freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0)
-
-        # Cosine and sine matrices
-        self.register_buffer('cos_matrix', torch.cos(freq_matrix))
-        self.register_buffer('sin_matrix', torch.sin(freq_matrix))
-
-    def forward(self, x: torch.Tensor):  # expects a float tensor of shape [batch_size × seq_len × head_size]
-        seq_len = x.size(1)
-        # Take the needed slice of the matrices and cast to x's dtype
-        cos = self.cos_matrix[:seq_len].to(x.dtype)  # [seq_len, head_size//2]
-        sin = self.sin_matrix[:seq_len].to(x.dtype)  # [seq_len, head_size//2]
-
-        # Split into even and odd feature indices
-        x_even = x[:, :, 0::2]  # [batch_size, seq_len, head_size//2]
-        x_odd = x[:, :, 1::2]   # [batch_size, seq_len, head_size//2]
-
-        # Apply the rotation
-        x_rotated_even = x_even * cos - x_odd * sin
-        x_rotated_odd = x_even * sin + x_odd * cos
-
-        # Interleave back together
-        x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)
-        x_rotated = x_rotated.flatten(-2)  # [batch_size, seq_len, head_size]
-
-        return x_rotated
-
+from llm.core.rope import RoPE
+from llm.core.swi_glu import SwiGLU
+from llm.core.rms_norm import RMSNorm
+from llm.core.gelu import GELU


 class HeadAttention(nn.Module):
@@ -133,10 +61,7 @@ class HeadAttention(nn.Module):
             return (x_out, (k, v))
         else:
             return (x_out, None)

-from torch import nn
-import torch
-import math
-

 class MultiHeadAttention(nn.Module):
     def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE, dropout: float = 0.1):
@@ -176,17 +101,6 @@ class MultiHeadAttention(nn.Module):
         else:
             return (final_output, None)


-class GELU(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return 0.5 * x * (1 + torch.tanh(
-            self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))
-        ))
-
-
 class Decoder(nn.Module):
     def __init__(self,