Mirror of https://github.com/pese-git/llm-arch-research.git (synced 2026-01-23 21:10:54 +00:00)
refactor: extract LLaMA components to separate modules in core directory
- Moved GELU, RMSNorm, RoPE, SiLU, and SwiGLU implementations from llama.py to dedicated files in core/
- Updated feed_forward.py to use the new modular components
- Modified llama.py to import the components from core modules instead of local definitions
- Improved code organization and reusability of the activation functions and normalization layers

This refactor enables better code reuse across different model architectures and follows the single-responsibility principle.
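After this commit the shared building blocks can be imported from the core package instead of being redefined in each architecture file. A minimal sketch of the intended usage, based only on the file paths added in this diff (the package's public API beyond these modules is not shown here):

    from llm.core.gelu import GELU
    from llm.core.rms_norm import RMSNorm
    from llm.core.rope import RoPE
    from llm.core.silu import SiLU
    from llm.core.swi_glu import SwiGLU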
llm/src/llm/core/feed_forward.py (changed)

@@ -1,16 +1,8 @@
 from torch import nn
 import torch
 import math
+from .gelu import GELU
-class GELU(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return 0.5 * x * (1 + torch.tanh(
-            self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))
-        ))
 
 
 class FeedForward(nn.Module):
     """
llm/src/llm/core/gelu.py (new file)

import math   # required for math.pi below; this import is missing in the committed file
import torch
from torch import nn


class GELU(nn.Module):
    def __init__(self):
        super().__init__()
        self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 0.5 * x * (1 + torch.tanh(
            self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))
        ))
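The class above implements the tanh approximation of GELU. As an illustrative sanity check (not part of the commit), it should agree with PyTorch's built-in tanh-approximate GELU, assuming a PyTorch version (1.12 or later) where F.gelu accepts the approximate argument:

    import torch
    import torch.nn.functional as F

    from llm.core.gelu import GELU

    x = torch.randn(2, 8, 16)                    # [batch_size, seq_len, emb_size]
    ours = GELU()(x)
    reference = F.gelu(x, approximate="tanh")    # same 0.5 * x * (1 + tanh(...)) formula
    print(torch.allclose(ours, reference, atol=1e-6))   # expected: True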
llm/src/llm/core/rms_norm.py (new file)

import torch
from torch import nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self._eps = eps
        self._w = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor):  # [batch_size × seq_len × emb_size]
        rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5
        norm_x = x / rms
        return self._w * norm_x
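Unlike LayerNorm, this normalization only rescales by the root mean square of the features; there is no mean subtraction and no bias. A small illustrative check of that property with the default weight of ones (hypothetical sizes, not part of the commit):

    import torch

    from llm.core.rms_norm import RMSNorm

    x = torch.randn(2, 8, 16)        # [batch_size, seq_len, emb_size]
    y = RMSNorm(dim=16)(x)

    # Each feature vector of y should have RMS close to 1 (up to the eps term).
    rms = y.pow(2).mean(-1).sqrt()
    print(torch.allclose(rms, torch.ones_like(rms), atol=1e-3))   # expected: True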
llm/src/llm/core/rope.py (new file)

import torch
from torch import nn


class RoPE(nn.Module):
    def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):
        super().__init__()
        assert head_size % 2 == 0, "head_size must be even"

        # Inverse frequencies
        freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size))

        # Positions
        positions = torch.arange(max_seq_len).float()

        # Frequency matrix (outer product)
        # freq_matrix = torch.outer(positions, freqs)
        freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0)

        # Cosine and sine matrices
        self.register_buffer('cos_matrix', torch.cos(freq_matrix))
        self.register_buffer('sin_matrix', torch.sin(freq_matrix))

    def forward(self, x: torch.Tensor):  # takes a float tensor x of shape [batch_size × seq_len × head_size]
        seq_len = x.size(1)
        # Take the slice of the matrices we need and cast to x's dtype
        cos = self.cos_matrix[:seq_len].to(x.dtype)  # [seq_len, head_size//2]
        sin = self.sin_matrix[:seq_len].to(x.dtype)  # [seq_len, head_size//2]

        # Split into even- and odd-indexed channels
        x_even = x[:, :, 0::2]  # [batch_size, seq_len, head_size//2]
        x_odd = x[:, :, 1::2]   # [batch_size, seq_len, head_size//2]

        # Apply the rotation
        x_rotated_even = x_even * cos - x_odd * sin
        x_rotated_odd = x_even * sin + x_odd * cos

        # Interleave the halves back together
        x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)
        x_rotated = x_rotated.flatten(-2)  # [batch_size, seq_len, head_size]

        return x_rotated
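RoPE rotates each (even, odd) pair of channels by an angle that depends on the token position, so it changes directions but not lengths. An illustrative usage and norm-preservation check (hypothetical sizes, not part of the commit):

    import torch

    from llm.core.rope import RoPE

    rope = RoPE(head_size=64, max_seq_len=128)
    q = torch.randn(2, 16, 64)                   # [batch_size, seq_len, head_size]
    q_rot = rope(q)

    print(q_rot.shape)                                                    # torch.Size([2, 16, 64])
    print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # rotations preserve norms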
llm/src/llm/core/silu.py (new file)

import torch
from torch import nn


class SiLU(nn.Module):
    def forward(self, x: torch.Tensor):  # [batch_size × seq_len × emb_size]
        return torch.sigmoid(x) * x
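This is the same function PyTorch already ships as torch.nn.functional.silu (x * sigmoid(x)); a short illustrative equivalence check:

    import torch
    import torch.nn.functional as F

    from llm.core.silu import SiLU

    x = torch.randn(2, 8, 16)
    print(torch.allclose(SiLU()(x), F.silu(x)))   # expected: True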
llm/src/llm/core/swi_glu.py (new file)

import torch
from torch import nn
from .silu import SiLU


class SwiGLU(nn.Module):
    def __init__(self, emb_size: int, dropout: float = 0.1):
        super().__init__()

        self._gate = nn.Linear(emb_size, 4 * emb_size)
        self._up = nn.Linear(emb_size, 4 * emb_size)
        self._down = nn.Linear(4 * emb_size, emb_size)
        self._activation = SiLU()
        self._dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):  # [batch_size × seq_len × emb_size]
        gate_out = self._gate(x)                     # [batch, seq, 4*emb]
        activation_out = self._activation(gate_out)  # [batch, seq, 4*emb]
        up_out = self._up(x)                         # [batch, seq, 4*emb]
        out = up_out * activation_out                # element-wise product!
        out = self._down(out)                        # [batch, seq, emb]
        return self._dropout(out)
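A shape-level usage sketch (hypothetical sizes; note that the 4x hidden expansion is hard-coded in this implementation and the Linear layers keep their default biases):

    import torch

    from llm.core.swi_glu import SwiGLU

    ffn = SwiGLU(emb_size=512, dropout=0.1)
    x = torch.randn(2, 16, 512)      # [batch_size, seq_len, emb_size]
    out = ffn(x)
    print(out.shape)                 # torch.Size([2, 16, 512]); hidden width inside is 4 * 512 = 2048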
llama.py (changed)

@@ -3,88 +3,16 @@ from torch import nn
 from torch import Tensor
 import torch.nn.functional as F
 from math import sqrt
+from torch import nn
+import torch
+import math
 from llm.core.base_model import BaseModel
 from llm.core.token_embeddings import TokenEmbeddings
+from llm.core.rope import RoPE
+from llm.core.swi_glu import SwiGLU
+from llm.core.rms_norm import RMSNorm
+from llm.core.gelu import GELU
-class SiLU(nn.Module):
-    def forward(self, x: torch.Tensor):  # [batch_size × seq_len × emb_size]
-        return torch.sigmoid(x) * x
-
-
-class RMSNorm(nn.Module):
-    [... removed body, identical to llm/src/llm/core/rms_norm.py above ...]
-
-
-class SwiGLU(nn.Module):
-    [... removed body, identical to llm/src/llm/core/swi_glu.py above ...]
-
-
-class RoPE(nn.Module):
-    [... removed body, identical to llm/src/llm/core/rope.py above ...]
 class HeadAttention(nn.Module):

@@ -134,9 +62,6 @@ class HeadAttention(nn.Module):
         else:
             return (x_out, None)
-
-from torch import nn
-import torch
-import math
-
 class MultiHeadAttention(nn.Module):
     def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE, dropout: float = 0.1):

@@ -177,17 +102,6 @@ class MultiHeadAttention(nn.Module):
             return (final_output, None)
-
-
-class GELU(nn.Module):
-    [... removed body, identical to the GELU class in llm/src/llm/core/gelu.py above ...]
-
-
 class Decoder(nn.Module):
     def __init__(self,
                  num_heads: int,