mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-24 05:21:16 +00:00
Compare commits
7 Commits
feature/mi
...
92a34551b8
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
92a34551b8 | ||
|
|
ea932a36f3 | ||
|
|
cfb4b6dfb1 | ||
|
|
58c4a00b48 | ||
|
|
c9da4c841b | ||
|
|
b1737bbce2 | ||
|
|
1aba02cab9 |
19
experiments/llm_only/configs/gemma_generate.json
Normal file
19
experiments/llm_only/configs/gemma_generate.json
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
{
|
||||||
|
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
|
||||||
|
"test_prompts": [
|
||||||
|
"Open weights",
|
||||||
|
"The Llama model is",
|
||||||
|
"Efficient transformers"
|
||||||
|
],
|
||||||
|
"model_config_path": "checkpoints/gemma-bpe/config.json",
|
||||||
|
"model_weights": "checkpoints/gemma-bpe/model.pt",
|
||||||
|
"generation": {
|
||||||
|
"max_new_tokens": 40,
|
||||||
|
"temperature": 0.8,
|
||||||
|
"do_sample": true,
|
||||||
|
"top_k": null,
|
||||||
|
"top_p": null
|
||||||
|
},
|
||||||
|
"log_path": "checkpoints/gemma_only_generation_logs.json"
|
||||||
|
}
|
||||||
|
|
||||||
28
experiments/llm_only/configs/gemma_train.json
Normal file
28
experiments/llm_only/configs/gemma_train.json
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
{
|
||||||
|
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
|
||||||
|
"bpe_vocab_size": 1000,
|
||||||
|
"bpe_special_tokens": ["<pad>", "<unk>", "<bos>", "<eos>"],
|
||||||
|
"test_prompts": ["Open source AI", "What is Llama?"],
|
||||||
|
"model_config": {
|
||||||
|
"vocab_size": null,
|
||||||
|
"embed_dim": 256,
|
||||||
|
"num_q_heads": 4,
|
||||||
|
"num_kv_heads": 2,
|
||||||
|
"head_size": 64,
|
||||||
|
"num_layers": 4,
|
||||||
|
"max_position_embeddings": 512,
|
||||||
|
"num_experts": 8,
|
||||||
|
"top_k_experts": 2,
|
||||||
|
"window_size": 16,
|
||||||
|
"dropout": 0.1
|
||||||
|
},
|
||||||
|
"model_weights": "checkpoints/gemma-bpe/model.pt",
|
||||||
|
"model_config_path": "checkpoints/gemma-bpe/config.json",
|
||||||
|
"training": {
|
||||||
|
"learning_rate": 0.0003,
|
||||||
|
"batch_size": 2,
|
||||||
|
"num_epochs": 3,
|
||||||
|
"warmup_steps": 50
|
||||||
|
},
|
||||||
|
"log_path": "checkpoints/gemma_only_training_logs.json"
|
||||||
|
}
|
||||||
19
experiments/llm_only/configs/mixtral_generate.json
Normal file
19
experiments/llm_only/configs/mixtral_generate.json
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
{
|
||||||
|
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
|
||||||
|
"test_prompts": [
|
||||||
|
"Open weights",
|
||||||
|
"The Llama model is",
|
||||||
|
"Efficient transformers"
|
||||||
|
],
|
||||||
|
"model_config_path": "checkpoints/mixtral-bpe/config.json",
|
||||||
|
"model_weights": "checkpoints/mixtral-bpe/model.pt",
|
||||||
|
"generation": {
|
||||||
|
"max_new_tokens": 40,
|
||||||
|
"temperature": 0.8,
|
||||||
|
"do_sample": true,
|
||||||
|
"top_k": null,
|
||||||
|
"top_p": null
|
||||||
|
},
|
||||||
|
"log_path": "checkpoints/mixtral_only_generation_logs.json"
|
||||||
|
}
|
||||||
|
|
||||||
28
experiments/llm_only/configs/mixtral_train.json
Normal file
28
experiments/llm_only/configs/mixtral_train.json
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
{
|
||||||
|
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
|
||||||
|
"bpe_vocab_size": 1000,
|
||||||
|
"bpe_special_tokens": ["<pad>", "<unk>", "<bos>", "<eos>"],
|
||||||
|
"test_prompts": ["Open source AI", "What is Llama?"],
|
||||||
|
"model_config": {
|
||||||
|
"vocab_size": null,
|
||||||
|
"embed_dim": 256,
|
||||||
|
"num_q_heads": 4,
|
||||||
|
"num_kv_heads": 2,
|
||||||
|
"head_size": 64,
|
||||||
|
"num_layers": 4,
|
||||||
|
"max_position_embeddings": 512,
|
||||||
|
"num_experts": 8,
|
||||||
|
"top_k_experts": 2,
|
||||||
|
"window_size": 16,
|
||||||
|
"dropout": 0.1
|
||||||
|
},
|
||||||
|
"model_weights": "checkpoints/mixtral-bpe/model.pt",
|
||||||
|
"model_config_path": "checkpoints/mixtral-bpe/config.json",
|
||||||
|
"training": {
|
||||||
|
"learning_rate": 0.0003,
|
||||||
|
"batch_size": 2,
|
||||||
|
"num_epochs": 3,
|
||||||
|
"warmup_steps": 50
|
||||||
|
},
|
||||||
|
"log_path": "checkpoints/mixtral_only_training_logs.json"
|
||||||
|
}
|
||||||
@@ -45,6 +45,12 @@ def load_model_class(model_name):
|
|||||||
elif model_name.lower() == 'mistral':
|
elif model_name.lower() == 'mistral':
|
||||||
from llm.models.mistral import Mistral
|
from llm.models.mistral import Mistral
|
||||||
return Mistral
|
return Mistral
|
||||||
|
elif model_name.lower() == 'mixtral':
|
||||||
|
from llm.models.mixtral import Mixtral
|
||||||
|
return Mixtral
|
||||||
|
elif model_name.lower() == 'gemma':
|
||||||
|
from llm.models.gemma import Gemma
|
||||||
|
return Gemma
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Модель '{model_name}' не поддерживается.")
|
raise ValueError(f"Модель '{model_name}' не поддерживается.")
|
||||||
|
|
||||||
|
|||||||
140
llm/src/llm/core/geglu.py
Normal file
140
llm/src/llm/core/geglu.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from llm.core.gelu import GELU
|
||||||
|
|
||||||
|
class GeGLU(nn.Module):
|
||||||
|
"""
|
||||||
|
GeGLU (Gated GELU Linear Unit) — эффективная нелинейность для feed-forward блоков в современных трансформерах.
|
||||||
|
|
||||||
|
Назначение:
|
||||||
|
-----------
|
||||||
|
GeGLU — это вариант GLU (Gated Linear Unit), где «шлюз» реализован через GELU-активацию,
|
||||||
|
а затем поэлементно перемножается с другим линейным преобразованием. Такой gating-механизм позволяет повысить
|
||||||
|
выразительность MLP-блока и ускорить обучение, что подтверждено экспериментами на LLM (см. PaLM, LLaMA, T5).
|
||||||
|
|
||||||
|
Формула:
|
||||||
|
--------
|
||||||
|
GeGLU(x) = GELU(W_g x + b_g) ⊙ (W_u x + b_u) W_d + b_d
|
||||||
|
(здесь W_g, W_u, W_d — матрицы весов; GELU применяется к одной ветке, ⊙ — поэлементное умножение)
|
||||||
|
|
||||||
|
Структура блока:
|
||||||
|
----------------
|
||||||
|
1. gate = GELU(Linear_gate(x)) # ветка gating-а, shape [batch, seq, 4×emb]
|
||||||
|
2. up = Linear_up(x) # ветка передачи, shape [batch, seq, 4×emb]
|
||||||
|
3. out = gate * up # поэлементно, реализует динамическую фильтрацию информации
|
||||||
|
4. out = Linear_down(out) # проекция обратно в исходное пространство
|
||||||
|
5. out = Dropout(out) # регуляризация
|
||||||
|
|
||||||
|
Основные преимущества:
|
||||||
|
----------------------
|
||||||
|
- Позволяет эффективно обучать глубокие трансформеры (см. PaLM, LLaMA).
|
||||||
|
- Обеспечивает плавные градиенты за счёт GELU и gating-эффекта.
|
||||||
|
- Используется во многих современных LLM вместо обычных FFN или простых GLU.
|
||||||
|
|
||||||
|
Аргументы конструктора:
|
||||||
|
-----------------------
|
||||||
|
emb_size : int
|
||||||
|
Размер эмбеддинга (input и output).
|
||||||
|
dropout : float, по умолчанию 0.1
|
||||||
|
Dropout к финальному выходу (примерно 0.1-0.2 для регуляризации).
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> geglu = GeGLU(emb_size=512, dropout=0.1)
|
||||||
|
>>> x = torch.randn(8, 16, 512)
|
||||||
|
>>> y = geglu(x)
|
||||||
|
>>> print(y.shape) # torch.Size([8, 16, 512])
|
||||||
|
|
||||||
|
Литература:
|
||||||
|
-----------
|
||||||
|
- Shazeer N., "GLU Variants Improve Transformer", 2020: https://arxiv.org/abs/2002.05202
|
||||||
|
- PaLM: https://arxiv.org/abs/2204.02311
|
||||||
|
- LLaMA: https://arxiv.org/abs/2302.13971
|
||||||
|
- T5: https://arxiv.org/abs/1910.10683
|
||||||
|
"""
|
||||||
|
def __init__(self, emb_size: int, dropout: float = 0.1):
|
||||||
|
"""
|
||||||
|
Инициализация блока GeGLU.
|
||||||
|
|
||||||
|
Создаёт три последовательных линейных слоя и задаёт GELU в качестве активации для ветки gating,
|
||||||
|
а также финальный dropout. Все размеры согласованы так, чтобы реализовать формулу GeGLU (см. описание класса).
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
emb_size : int
|
||||||
|
Размерность входного и выходного скрытого пространства (hidden size).
|
||||||
|
Данная величина определяет размерность эмбеддинга для всех внутренних вычислений.
|
||||||
|
Обычно равна размеру скрытого слоя трансформера.
|
||||||
|
|
||||||
|
dropout : float, по умолчанию 0.1
|
||||||
|
Вероятность отключения нейронов после выхода из блока (регуляризация).
|
||||||
|
Рекомендуемое значение: 0.1 (или чуть больше для небольших моделей).
|
||||||
|
|
||||||
|
Внутри:
|
||||||
|
-------
|
||||||
|
- self._gate: Linear слой размерности [emb_size, 4 * emb_size], ветка gating (проходит через GELU)
|
||||||
|
- self._up: Linear слой размерности [emb_size, 4 * emb_size], ветка передачи ("пропускная")
|
||||||
|
- self._down: Linear слой сжатия обратно к emb_size
|
||||||
|
- self._activation: Активация GELU для gating-ветки
|
||||||
|
- self._dropout: Dropout для выходного тензора
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> block = GeGLU(emb_size=256, dropout=0.1)
|
||||||
|
>>> print(block)
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self._gate = nn.Linear(emb_size, 4 * emb_size)
|
||||||
|
self._up = nn.Linear(emb_size, 4 * emb_size)
|
||||||
|
self._down = nn.Linear(4 * emb_size, emb_size)
|
||||||
|
self._activation = GELU()
|
||||||
|
self._dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Прямой проход (forward) через блок GeGLU.
|
||||||
|
|
||||||
|
Для входного тензора скрытых состояний x реализует последовательность операций:
|
||||||
|
1. Gating-ветка: линейное преобразование → GELU-активация
|
||||||
|
2. Пропускная ветка: линейное преобразование
|
||||||
|
3. Поэлементное умножение результатов обеих веток (gating)
|
||||||
|
4. Проекция через Linear обратно к emb_size
|
||||||
|
5. Dropout результата для регуляризации
|
||||||
|
|
||||||
|
Математически:
|
||||||
|
--------------
|
||||||
|
gate = GELU(W_g·x + b_g)
|
||||||
|
up = W_u·x + b_u
|
||||||
|
out = gate * up
|
||||||
|
out = W_d·out + b_d
|
||||||
|
out = Dropout(out)
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
x : torch.Tensor
|
||||||
|
Входной тензор формы [batch_size, seq_len, emb_size]
|
||||||
|
(или любой совместимой формы, где последняя ось — emb_size).
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
-----------
|
||||||
|
torch.Tensor :
|
||||||
|
Тензор той же формы [batch_size, seq_len, emb_size], прошедший через структуру GeGLU.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> y = geglu(x)
|
||||||
|
>>> print(y.shape) # [batch_size, seq_len, emb_size]
|
||||||
|
|
||||||
|
Примечания:
|
||||||
|
-----------
|
||||||
|
- Ветка gating строит masк для динамической фильтрации информации.
|
||||||
|
- Такой тип блока эффективно используется как замена обычного FFN в современных LLM.
|
||||||
|
"""
|
||||||
|
gate_out = self._gate(x) # [batch, seq, 4*emb]
|
||||||
|
activation_out = self._activation(gate_out) # [batch, seq, 4*emb]
|
||||||
|
up_out = self._up(x) # [batch, seq, 4*emb]
|
||||||
|
out = up_out * activation_out # поэлементное!
|
||||||
|
out = self._down(out) # [batch, seq, emb]
|
||||||
|
return self._dropout(out)
|
||||||
|
|
||||||
188
llm/src/llm/core/gemma_decoder.py
Normal file
188
llm/src/llm/core/gemma_decoder.py
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from llm.core.rope import RoPE
|
||||||
|
from llm.core.multi_query_attention import MultiQueryAttention
|
||||||
|
from llm.core.rms_norm import RMSNorm
|
||||||
|
from llm.core.geglu import GeGLU
|
||||||
|
|
||||||
|
class GemmaDecoder(nn.Module):
|
||||||
|
"""
|
||||||
|
GemmaDecoder — декодерный блок архитектуры Gemma (Google DeepMind, 2024).
|
||||||
|
|
||||||
|
Назначение:
|
||||||
|
-----------
|
||||||
|
Данный блок реализует одну «ячейку» декодерного стека в модели Gemma. Архитектура схожа с современными LLM (Llama/Mistral),
|
||||||
|
но имеет уникальные особенности attention и feed-forward слоёв, соответствующие спецификации Gemma.
|
||||||
|
|
||||||
|
Архитектурные компоненты:
|
||||||
|
-------------------------
|
||||||
|
- LayerNorm или RMSNorm
|
||||||
|
- Multi-head self-attention (обычно Multi-Query Attention)
|
||||||
|
- Skip connection (остаточное сложение)
|
||||||
|
- Feed-forward блок (может включать SwiGLU, GeGLU или классический FFN)
|
||||||
|
- Повторная нормализация
|
||||||
|
- Dropout (регуляризация на уровне attention и feed-forward)
|
||||||
|
|
||||||
|
Алгоритм прямого прохода:
|
||||||
|
-------------------------
|
||||||
|
1. norm1_out = LayerNorm(x)
|
||||||
|
2. attention_out = Attention(norm1_out, ...)
|
||||||
|
3. resid1 = attention_out + x
|
||||||
|
4. norm2_out = LayerNorm(resid1)
|
||||||
|
5. ffn_out = FeedForward(norm2_out)
|
||||||
|
6. output = ffn_out + resid1
|
||||||
|
|
||||||
|
Теоретические детали:
|
||||||
|
---------------------
|
||||||
|
- В Gemma используются техники оптимизации памяти и ускорения инференса (например, shared K/V-головы, Rope, кастомные FFN).
|
||||||
|
- Поддержка кэширования attention для ускорения генерации (KV cache).
|
||||||
|
- Блок проектирован для использования в стеке, повторяется N раз во всей LLM.
|
||||||
|
|
||||||
|
Аргументы конструктора:
|
||||||
|
----------------------
|
||||||
|
num_q_heads : int
|
||||||
|
Число голов query (Query Heads) для attention.
|
||||||
|
num_kv_heads : int
|
||||||
|
Число ключевых/значенческих голов (Key/Value Heads).
|
||||||
|
emb_size : int
|
||||||
|
Размерность скрытого пространства (embedding dim).
|
||||||
|
head_size : int
|
||||||
|
Размерность одной attention-головы.
|
||||||
|
max_seq_len : int
|
||||||
|
Максимальная длина последовательности (ограничение на causal mask).
|
||||||
|
dropout : float, optional
|
||||||
|
Dropout для регуляризации (примерно 0.0–0.1).
|
||||||
|
rope : RoPE, optional
|
||||||
|
Позиционное кодирование Rotary Position Embedding.
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> decoder = GemmaDecoder(
|
||||||
|
... num_q_heads=8,
|
||||||
|
... num_kv_heads=2,
|
||||||
|
... emb_size=256,
|
||||||
|
... head_size=32,
|
||||||
|
... max_seq_len=1024,
|
||||||
|
... dropout=0.1,
|
||||||
|
... rope=rope_obj
|
||||||
|
... )
|
||||||
|
>>> x = torch.randn(2, 24, 256)
|
||||||
|
>>> out, cache = decoder(x, mask=None, use_cache=True, cache=None)
|
||||||
|
>>> print(out.shape) # torch.Size([2, 24, 256])
|
||||||
|
|
||||||
|
Литература и ссылки:
|
||||||
|
--------------------
|
||||||
|
- Gemma (официальный релиз): https://ai.google.dev/gemma
|
||||||
|
- Gemma paper: https://arxiv.org/abs/2403.07794
|
||||||
|
- Rotary Embedding: https://arxiv.org/abs/2104.09864
|
||||||
|
- Multi-Query Attention: https://arxiv.org/abs/1911.02150
|
||||||
|
- Llama: https://arxiv.org/abs/2302.13971
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
num_q_heads: int,
|
||||||
|
emb_size: int,
|
||||||
|
head_size: int,
|
||||||
|
max_seq_len: int,
|
||||||
|
rope: RoPE,
|
||||||
|
dropout: float = 0.1
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Конструктор слоя GemmaDecoder.
|
||||||
|
|
||||||
|
Производит инициализацию всех подслоёв (нормализация, multi-head или multi-query attention, feed-forward блок, Dropout)
|
||||||
|
согласно архитектуре декодера Gemma. Обеспечивает поддержку rotary-позиционирования, обучения и inference с caching.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
num_q_heads : int
|
||||||
|
Количество query-голов в attention (определяет степень параллелизма внимания).
|
||||||
|
emb_size : int
|
||||||
|
Размер пространства эмбеддинга (embedding dim, input/output размерность слоя).
|
||||||
|
head_size : int
|
||||||
|
Размерность одной attention-головы. Обычно emb_size // num_q_heads.
|
||||||
|
max_seq_len : int
|
||||||
|
Максимальная длина последовательности, для которой поддерживается attention и маскирование.
|
||||||
|
rope : RoPE
|
||||||
|
Объект для rotary positional encoding (позиционное кодирование для attention).
|
||||||
|
dropout : float, default=0.1
|
||||||
|
Dropout после attention и feed-forward для регуляризации (обычно 0.0–0.1).
|
||||||
|
|
||||||
|
Внутри:
|
||||||
|
-------
|
||||||
|
- Инициализируются все слои norm, attention, rope, FFN, остаточные соединения.
|
||||||
|
- Строится causal-маска автоагрессивного attention (если требуется).
|
||||||
|
- Гибко поддерживает работу как на training, так и для быстрых inference/генерации.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> decoder = GemmaDecoder(
|
||||||
|
... num_q_heads=8, emb_size=512, head_size=64, max_seq_len=1024, rope=rope_obj, dropout=0.05
|
||||||
|
... )
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self._heads = MultiQueryAttention(
|
||||||
|
num_q_heads=num_q_heads,
|
||||||
|
emb_size=emb_size,
|
||||||
|
head_size=head_size,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
rope=rope,
|
||||||
|
dropout=dropout
|
||||||
|
)
|
||||||
|
self._ff = GeGLU(emb_size=emb_size, dropout=dropout)
|
||||||
|
self._norm1 = RMSNorm(emb_size)
|
||||||
|
self._norm2 = RMSNorm(emb_size)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Прямой проход (forward) через GemmaDecoder.
|
||||||
|
|
||||||
|
Последовательно реализует:
|
||||||
|
- Нормализацию входа (обычно RMSNorm или LayerNorm)
|
||||||
|
- Self-attention (multi-query или multi-head, с опциональной маской и кэшем)
|
||||||
|
- Остаточное сложение (skip connection)
|
||||||
|
- Вторую нормализацию
|
||||||
|
- Feed-Forward-блок (например, GeGLU/SwiGLU)
|
||||||
|
- Ещё одно residual сложение
|
||||||
|
|
||||||
|
Поддерживает autoregressive режим с caching (KV-слоты attention для ускорения генерации).
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
x : torch.Tensor
|
||||||
|
Входной скрытый тензор формы [batch_size, seq_length, emb_size].
|
||||||
|
mask : torch.Tensor, optional
|
||||||
|
Attention mask (например, causal или padding mask). Если None, используется встроенная causal mask.
|
||||||
|
use_cache : bool, по умолчанию True
|
||||||
|
Если True — возвращается кэш KV для ускорения autoregressive генерации.
|
||||||
|
cache : list, optional
|
||||||
|
Кэш предыдущих ключей/значений attention (если используется при инференсе).
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
-----------
|
||||||
|
Tuple[torch.Tensor, cache]:
|
||||||
|
- Выход декодера с той же формой [batch_size, seq_length, emb_size]
|
||||||
|
- Кэш attention (если use_cache=True), иначе None
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> out, new_cache = decoder(x, mask=att_mask, use_cache=True, cache=old_cache)
|
||||||
|
>>> out.shape # [batch_size, seq_len, emb_size]
|
||||||
|
|
||||||
|
Примечания:
|
||||||
|
-----------
|
||||||
|
- mask используется для ограничения внимания (напр., каузальный режим GPT/LLM).
|
||||||
|
- Для ускорения в режиме генерации рекомендуется использовать use_cache=True + передавать cache.
|
||||||
|
|
||||||
|
"""
|
||||||
|
norm1_out = self._norm1(x)
|
||||||
|
attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
|
||||||
|
out = attention + x
|
||||||
|
|
||||||
|
norm2_out = self._norm2(out)
|
||||||
|
ffn_out = self._ff(norm2_out)
|
||||||
|
|
||||||
|
if use_cache is True:
|
||||||
|
return (ffn_out + out, kv_caches)
|
||||||
|
else:
|
||||||
|
return (ffn_out + out, None)
|
||||||
211
llm/src/llm/core/mixtral_decoder.py
Normal file
211
llm/src/llm/core/mixtral_decoder.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from llm.core.rope import RoPE
|
||||||
|
from llm.core.group_query_attention import GroupedQueryAttention
|
||||||
|
from llm.core.moe import MoE
|
||||||
|
from llm.core.rms_norm import RMSNorm
|
||||||
|
|
||||||
|
class MixtralDecoder(nn.Module):
|
||||||
|
"""
|
||||||
|
MixtralDecoder — декодерный блок для Mixtral/MoE-трансформеров (см. Mixtral 8x7B, Mistral v0.2 и др.).
|
||||||
|
|
||||||
|
Назначение:
|
||||||
|
-----------
|
||||||
|
MixtralDecoder реализует один модульный слой глубокой трансформерной архитектуры с Mixture-of-Experts (MoE) Feed-Forward Network и Grouped Query Attention (GQA).
|
||||||
|
Поддерживает разреженную активацию и масштабируемое количество экспертов, оптимально для больших LLM.
|
||||||
|
|
||||||
|
Архитектура блока:
|
||||||
|
------------------
|
||||||
|
- RMSNorm -> Grouped Query Attention (GQA)
|
||||||
|
- skip-connection
|
||||||
|
- RMSNorm -> MoE (SwiGLU-эксперты)
|
||||||
|
- skip-connection
|
||||||
|
|
||||||
|
Для входа `x` проходит:
|
||||||
|
1. norm1_out = RMSNorm(x)
|
||||||
|
2. attention, kv_caches = GQA(norm1_out, ...)
|
||||||
|
3. out = attention + x # residual connection
|
||||||
|
4. norm2_out = RMSNorm(out)
|
||||||
|
5. ffn_out = MoE(norm2_out)
|
||||||
|
6. return (ffn_out + out, kv_caches)
|
||||||
|
|
||||||
|
Теоретическая мотивация:
|
||||||
|
------------------------
|
||||||
|
- Использование MoE (см. https://arxiv.org/abs/1701.06538) позволяет кратно увеличивать capacity без роста затрат на ff-часть.
|
||||||
|
- Grouped Query Attention эффективно масштабирует self-attention для больших моделей (см. Mistral, Llama 2/3).
|
||||||
|
- RMSNorm (Root Mean Square LayerNorm) стабилизирует градиенты и память.
|
||||||
|
- Является строительным блоком для стека декодеров в Mixtral-моделях (см. Mixtral, Mistral, LLaMA).
|
||||||
|
|
||||||
|
Аргументы конструктора:
|
||||||
|
----------------------
|
||||||
|
num_q_heads : int
|
||||||
|
Число query-голов в attention.
|
||||||
|
num_kv_heads : int
|
||||||
|
Число key-value голов (группировка ключей/values).
|
||||||
|
emb_size : int
|
||||||
|
Скрытый размер эмбеддинга.
|
||||||
|
head_size : int
|
||||||
|
Размерность одной головы (emb_size // num_q_heads).
|
||||||
|
max_seq_len : int
|
||||||
|
Максимальная поддерживаемая длина последовательности.
|
||||||
|
num_experts : int
|
||||||
|
Количество «экспертов» (MoE).
|
||||||
|
top_k_experts : int
|
||||||
|
Сколько одновременно экспертов активируется для одного токена.
|
||||||
|
window_size : int
|
||||||
|
Размер окна внимания (используется для efficient attention).
|
||||||
|
rope : RoPE
|
||||||
|
Реализация позиционного кодирования RoPE.
|
||||||
|
dropout : float
|
||||||
|
Вероятность Dropout для регуляризации.
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> decoder = MixtralDecoder(... параметры ...)
|
||||||
|
>>> x = torch.randn(batch, seq, emb_size)
|
||||||
|
>>> out, cache = decoder(x, mask=None, use_cache=True)
|
||||||
|
>>> out.shape
|
||||||
|
|
||||||
|
Литература и ссылки:
|
||||||
|
--------------------
|
||||||
|
- Mixtral 8x7B: https://mistral.ai/news/mixtral-of-experts/
|
||||||
|
- Shazeer et al., “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer”, 2017. https://arxiv.org/abs/1701.06538
|
||||||
|
- Mistral paper: https://arxiv.org/abs/2310.06825
|
||||||
|
- GQA: https://arxiv.org/abs/2305.14236
|
||||||
|
- RMSNorm: https://arxiv.org/abs/1910.07467
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
num_q_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
emb_size: int,
|
||||||
|
head_size: int,
|
||||||
|
max_seq_len: int,
|
||||||
|
num_experts: int,
|
||||||
|
top_k_experts: int,
|
||||||
|
window_size: int,
|
||||||
|
rope: RoPE,
|
||||||
|
dropout: float = 0.1
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Конструктор декодерного блока MixtralDecoder.
|
||||||
|
|
||||||
|
Осуществляет инициализацию всех под-компонентов слоя: Attention (Grouped Query Attention), MoE (Mixture-of-Experts, SwiGLU)
|
||||||
|
и нормализации (RMSNorm). Позволяет гибко настраивать архитектуру под специфику задач и размеры LLM.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
num_q_heads : int
|
||||||
|
Количество голов внимания (queries) в механизме GroupedQueryAttention.
|
||||||
|
Чем больше — тем тоньше дискретизация внимания по подпространствам признаков.
|
||||||
|
num_kv_heads : int
|
||||||
|
Количество групп ключей/значений (key-value heads) для GQA.
|
||||||
|
Позволяет балансировать производительность и память.
|
||||||
|
emb_size : int
|
||||||
|
Размерность эмбеддингового пространства внутри слоя (hidden).
|
||||||
|
head_size : int
|
||||||
|
Размерность одной attention-головы. Обычно emb_size // num_q_heads.
|
||||||
|
max_seq_len : int
|
||||||
|
Максимально поддерживаемая длина токенизированной последовательности.
|
||||||
|
num_experts : int
|
||||||
|
Количество экспертов в слое MoE (размер пула SwiGLU-экспертов).
|
||||||
|
top_k_experts : int
|
||||||
|
Сколько экспертов по роутингу активируется на 1 токен (разреженность — эффективная экономия вычислений).
|
||||||
|
window_size : int
|
||||||
|
Размер окна для attention (может использоваться для ограничения receptive field, как в Mistral).
|
||||||
|
rope : RoPE
|
||||||
|
Объект позиционного кодирования RoPE (Rotary Positional Embedding), необходим для архитектуры внимания.
|
||||||
|
dropout : float, по умолчанию 0.1
|
||||||
|
Вероятность зануляции выходных значений для регуляризации и борьбы с переобучением.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> decoder = MixtralDecoder(
|
||||||
|
... num_q_heads=8,
|
||||||
|
... num_kv_heads=2,
|
||||||
|
... emb_size=256,
|
||||||
|
... head_size=32,
|
||||||
|
... max_seq_len=1024,
|
||||||
|
... num_experts=4,
|
||||||
|
... top_k_experts=2,
|
||||||
|
... window_size=128,
|
||||||
|
... rope=rope_module,
|
||||||
|
... dropout=0.05
|
||||||
|
... )
|
||||||
|
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self._heads = GroupedQueryAttention(
|
||||||
|
num_q_heads=num_q_heads,
|
||||||
|
num_kv_heads=num_kv_heads,
|
||||||
|
emb_size=emb_size,
|
||||||
|
head_size=head_size,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
window_size=window_size,
|
||||||
|
rope=rope,
|
||||||
|
dropout=dropout
|
||||||
|
)
|
||||||
|
self._ff = MoE(
|
||||||
|
emb_size=emb_size,
|
||||||
|
num_experts=num_experts,
|
||||||
|
top_k_experts=top_k_experts,
|
||||||
|
dropout=dropout
|
||||||
|
)
|
||||||
|
self._norm1 = RMSNorm(emb_size)
|
||||||
|
self._norm2 = RMSNorm(emb_size)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Прямой проход (forward) через декодерный блок MixtralDecoder.
|
||||||
|
|
||||||
|
Данный метод реализует последовательную обработку входных скрытых состояний (x) через:
|
||||||
|
- нормализацию (RMSNorm),
|
||||||
|
- attention-модуль (Grouped Query Attention) с опциональным применением маски и кэша ключей/значений для ускорения инференса,
|
||||||
|
- остаточное сложение (residual connection),
|
||||||
|
- повторную нормализацию,
|
||||||
|
- feed-forward блок на основе Mixture-of-Experts (MoE),
|
||||||
|
- финальное остаточное сложение.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
x : torch.Tensor
|
||||||
|
Входной скрытый тензор формы [batch_size, seq_len, emb_size] — результат эмбеддинга токенов либо предыдущего слоя.
|
||||||
|
mask : torch.Tensor, optional
|
||||||
|
(Необязательно) Маска внимания для ограничения области self-attention (например, для автоперемешивания или causal-LLM-моделей).
|
||||||
|
use_cache : bool, по умолчанию True
|
||||||
|
Если True — сохраняет кэш ключей/значений attention для ускорения авторегрессии (инференса).
|
||||||
|
cache : list, optional
|
||||||
|
(Необязательно) Предварительно вычисленный кеш attention (для ускорения генерации длинного текста).
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
-----------
|
||||||
|
Tuple[torch.Tensor, Any]:
|
||||||
|
- Первый элемент: скрытый тензор выхода слоя с той же формой, что вход (последовательный residual из attention и MoE-блока).
|
||||||
|
- Второй элемент: обновлённый кэш attention (если use_cache=True), иначе None.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> out, cache = decoder(x, mask=att_mask, use_cache=True, cache=old_cache)
|
||||||
|
>>> out.shape # [batch_size, seq_len, emb_size]
|
||||||
|
|
||||||
|
Примечания:
|
||||||
|
-----------
|
||||||
|
- Для autoregressive-генерации (GPT-like режимов) следует передавать mask и использовать use_cache=True.
|
||||||
|
- Реализация поддерживает произвольные батчи и длины последовательностей, в пределах max_seq_len слоя.
|
||||||
|
- Модуль MixtralDecoder обычно используется в виде стека (несколько подряд) внутри крупной LLM.
|
||||||
|
|
||||||
|
"""
|
||||||
|
norm1_out = self._norm1(x)
|
||||||
|
attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
|
||||||
|
out = attention + x
|
||||||
|
|
||||||
|
norm2_out = self._norm2(out)
|
||||||
|
ffn_out = self._ff(norm2_out)
|
||||||
|
|
||||||
|
if use_cache is True:
|
||||||
|
return (ffn_out + out, kv_caches)
|
||||||
|
else:
|
||||||
|
return (ffn_out + out, None)
|
||||||
229
llm/src/llm/core/moe.py
Normal file
229
llm/src/llm/core/moe.py
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from llm.core.swi_glu import SwiGLU
|
||||||
|
|
||||||
|
class MoE(nn.Module):
|
||||||
|
"""
|
||||||
|
MoE (Mixture of Experts) — слой «смеси экспертов» для современных трансформерных архитектур с разреженной активацией.
|
||||||
|
|
||||||
|
Назначение:
|
||||||
|
-----------
|
||||||
|
Класс реализует слой разреженного условного вычисления для увеличения capacity трансформеров без роста вычислительных затрат.
|
||||||
|
Для каждого токена из последовательности выбирается (с помощью роутера) наиболее подходящее подмножество экспертов (малых нейросетей).
|
||||||
|
Итоговый выход формируется как взвешенная сумма откликов экспертов, выбранных для данного токена.
|
||||||
|
|
||||||
|
Архитектурная схема:
|
||||||
|
---------------------
|
||||||
|
- Для каждого входного токена `x` роутер (обычно один Linear-слой) предсказывает skor, насколько каждый из `num_experts` релевантен.
|
||||||
|
- Для каждого токена выбираются top_k_experts с максимальными skor; только они обрабатывают этот токен.
|
||||||
|
- Каждый эксперт здесь представлен отдельным экземпляром блока `SwiGLU` (может быть любая небольшая feed-forward сеть).
|
||||||
|
- Выход каждого эксперта умножается на индивидуальный вес (softmax по skor) — агрегируется взвешенная сумма.
|
||||||
|
- Dropout применяется к итоговому выходу.
|
||||||
|
|
||||||
|
Математика (коротко):
|
||||||
|
---------------------
|
||||||
|
Пусть X ∈ R^{BxSxD} — вход,
|
||||||
|
E — число экспертов,
|
||||||
|
K — число активируемых экспертов на токен.
|
||||||
|
r(x) = softmax(W_r x) — роутинг-логиты, top-K берём индексы и веса.
|
||||||
|
Для каждого токена:
|
||||||
|
y_j = Expert_j(x)
|
||||||
|
y = sum_j(w_j * y_j), где j пробегает по выбранным экспертам
|
||||||
|
Output: Y ∈ R^{BxSxD}
|
||||||
|
|
||||||
|
Аргументы конструктора:
|
||||||
|
----------------------
|
||||||
|
emb_size : int
|
||||||
|
Размерность входных/выходных векторов (обычно совпадает с embedding модели).
|
||||||
|
num_experts : int
|
||||||
|
Общее число экспертов внутри слоя MoE.
|
||||||
|
top_k_experts : int
|
||||||
|
Сколько экспертов активировать и агрегировать на каждом токене (обычно 2-8).
|
||||||
|
dropout : float, по умолчанию 0.1
|
||||||
|
Dropout к выходу агрегатора.
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> moe = MoE(emb_size=512, num_experts=8, top_k_experts=2, dropout=0.1)
|
||||||
|
>>> x = torch.randn(4, 16, 512)
|
||||||
|
>>> y = moe(x)
|
||||||
|
>>> y.shape # torch.Size([4, 16, 512])
|
||||||
|
|
||||||
|
Литература:
|
||||||
|
-----------
|
||||||
|
- Shazeer, N. et al. “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer”, 2017. https://arxiv.org/abs/1701.06538
|
||||||
|
- Fedus, W., Zoph, B., & Shazeer, N. “Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity”, 2021. https://arxiv.org/abs/2101.03961
|
||||||
|
- Mistral/Mixtral: https://mistral.ai/news/mixtral-of-experts/
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
emb_size: int,
|
||||||
|
num_experts: int,
|
||||||
|
top_k_experts: int,
|
||||||
|
dropout: float = 0.1,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Конструктор слоя MoE (Mixture of Experts).
|
||||||
|
|
||||||
|
Позволяет создать слой, состоящий из набора экспертов (например, отдельных небольших feedforward-нейросетей) и роутера,
|
||||||
|
который будет для каждого токена определять наиболее релевантных экспертов.
|
||||||
|
Часть экспертов (top_k_experts) активируется для каждого токена, остальные — пропускаются.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
emb_size : int
|
||||||
|
Размерность входных и выходных векторов (embedding size).
|
||||||
|
Определяет, над каким пространством признаков будет работать роутер и эксперты.
|
||||||
|
Например, если скрытый размер слоя трансформера 512, сюда нужно передать 512.
|
||||||
|
|
||||||
|
num_experts : int
|
||||||
|
Общее количество экспертов в слое MoE.
|
||||||
|
Чем больше экспертов — тем больше capacity у модели, но тем выше требования к RAM/VRAM при обучении.
|
||||||
|
Пример: 8, 16, 32, 64.
|
||||||
|
|
||||||
|
top_k_experts : int
|
||||||
|
Сколько экспертов одновременно будет обрабатывать каждый токен.
|
||||||
|
Обычно 2–8. Меньшее значение — выше разреженность, больше экономия вычислений.
|
||||||
|
|
||||||
|
dropout : float, по умолчанию 0.1
|
||||||
|
Вероятность зануления значений на выходе после агрегации откликов экспертов.
|
||||||
|
Используется для регуляризации (борьбы с переобучением).
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> moe = MoE(emb_size=256, num_experts=8, top_k_experts=2, dropout=0.1)
|
||||||
|
>>> print(moe)
|
||||||
|
MoE( ... )
|
||||||
|
|
||||||
|
Теория:
|
||||||
|
-------
|
||||||
|
Слой строит:
|
||||||
|
- Линейный роутер (Linear(emb_size, num_experts)): выдает «важность» каждого эксперта для токена.
|
||||||
|
- Список из num_experts экспертов (в данной реализации — SwiGLU-блоки).
|
||||||
|
|
||||||
|
При каждом проходе для каждого токена выбираются top_k_experts наиболее релевантных экспертов,
|
||||||
|
их ответы агрегируются взвешенной суммой (softmax по роутерным логитам).
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
if top_k_experts > num_experts:
|
||||||
|
raise ValueError(f"top_k_experts ({top_k_experts}) должен быть меньше или равен num_experts ({num_experts})!")
|
||||||
|
self._num_experts = num_experts
|
||||||
|
self._top_k_experts = top_k_experts
|
||||||
|
|
||||||
|
self._router = nn.Linear(emb_size, num_experts)
|
||||||
|
self._experts = nn.ModuleList([SwiGLU(
|
||||||
|
emb_size=emb_size,
|
||||||
|
dropout=dropout,
|
||||||
|
) for _ in range(num_experts)])
|
||||||
|
self._dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Прямой проход (forward) через слой MoE.
|
||||||
|
|
||||||
|
Для входной последовательности скрытых состояний (обычно из предыдущего слоя трансформера)
|
||||||
|
данный метод динамически выбирает для каждого токена топ-k наиболее релевантных экспертов с помощью роутера,
|
||||||
|
пропускает соответствующие токены через выбранных экспертов и агрегирует их результаты.
|
||||||
|
|
||||||
|
Математически:
|
||||||
|
--------------
|
||||||
|
1. Для каждого токена вычисляются логиты маршрутизатора (роутера):
|
||||||
|
router_logits = Linear(x) ∈ ℝ^{batch, seq, num_experts}
|
||||||
|
2. Выбираются top_k экспертов (topk_indices) и соответствующие им softmax-веса (topk_weights).
|
||||||
|
3. Каждый эксперт обрабатывает только свой поднабор токенов.
|
||||||
|
4. Результат агрегируется — отклик эксперта умножается на вес, ответы суммируются для каждого токена.
|
||||||
|
5. На результат применяется dropout для регуляризации.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
x : torch.Tensor
|
||||||
|
Трёхмерный входной тензор формы [batch_size, seq_length, emb_size],
|
||||||
|
где batch_size — размер батча, seq_length — длина последовательности, emb_size — размерность эмбеддинга.
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
-----------
|
||||||
|
torch.Tensor :
|
||||||
|
Тензор той же формы [batch_size, seq_length, emb_size] — результат комбинирования выходов выбранных экспертов
|
||||||
|
с учетом softmax-весов маршрутизатора и dropout'а.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> y = moe(x)
|
||||||
|
>>> print(y.shape)
|
||||||
|
torch.Size([batch_size, seq_length, emb_size])
|
||||||
|
|
||||||
|
Примечание:
|
||||||
|
-----------
|
||||||
|
- Каждый токен чаще всего активирует только подмножество экспертов.
|
||||||
|
- Остальные эксперты вычислительно “спят”, что позволяет строить очень большие (по параметрам) модели с малым ростом затрат.
|
||||||
|
- Работа с распределением топ-к экспертов и агрегирование с весами реализовано автоматически.
|
||||||
|
|
||||||
|
"""
|
||||||
|
batch_size, seq_len, emb_size = x.shape
|
||||||
|
|
||||||
|
# 1. Пропускаем через роутер
|
||||||
|
router_logits = self._router(x) # [batch_size, seq_len, num_experts]
|
||||||
|
|
||||||
|
# 2. Отбираем топ-k экспертов для каждого токена
|
||||||
|
topk_logits, topk_indices = torch.topk(
|
||||||
|
router_logits,
|
||||||
|
k=self._top_k_experts,
|
||||||
|
dim=-1
|
||||||
|
) # topk_logits: [batch_size, seq_len, top_k]
|
||||||
|
# topk_indices: [batch_size, seq_len, top_k]
|
||||||
|
|
||||||
|
# 3. Получаем веса через softmax и нормируем
|
||||||
|
topk_weights = F.softmax(topk_logits, dim=-1) # [batch_size, seq_len, top_k]
|
||||||
|
|
||||||
|
# 4. Создаём нулевой тензор для результата
|
||||||
|
output = torch.zeros_like(x) # [batch_size, seq_len, emb_size]
|
||||||
|
|
||||||
|
# 5. Проходим по всем экспертам
|
||||||
|
for expert_id in range(self._num_experts):
|
||||||
|
# Шаг 1: Создаём маску - где находится текущий эксперт в топ-k
|
||||||
|
expert_mask = (topk_indices == expert_id) # [batch_size, seq_len, top_k]
|
||||||
|
# Шаг 2: Проверяем, выбран ли эксперт хотя бы одним токеном
|
||||||
|
if not expert_mask.any():
|
||||||
|
continue # Эксперт никем не выбран, переходим к следующему
|
||||||
|
|
||||||
|
# Шаг 3: Находим токены, которые выбрали этого эксперта
|
||||||
|
# (хотя бы в одной из top_k позиций)
|
||||||
|
token_mask = expert_mask.any(dim=-1) # [batch_size, seq_len]
|
||||||
|
|
||||||
|
# Шаг 4: Отбираем токены из x
|
||||||
|
# Отбираем токены для этого эксперта
|
||||||
|
expert_input = x[token_mask]
|
||||||
|
|
||||||
|
# Пропускаем через эксперта
|
||||||
|
# Добавляем batch dimension для SwiGLU и затем убираем
|
||||||
|
expert_output = self._experts[expert_id](
|
||||||
|
expert_input.unsqueeze(0)
|
||||||
|
).squeeze(0)
|
||||||
|
|
||||||
|
# Получаем веса для этого эксперта
|
||||||
|
# Для каждого токена может быть несколько весов (если эксперт в топ-k несколько раз)
|
||||||
|
# Но на практике каждый эксперт появляется максимум 1 раз в топ-k
|
||||||
|
# Находим веса: где expert_mask == True, берём соответствующий вес
|
||||||
|
weights_for_expert = torch.zeros(
|
||||||
|
batch_size, seq_len, device=x.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Для каждой позиции в топ-k
|
||||||
|
for k in range(self._top_k_experts):
|
||||||
|
mask_k = topk_indices[:, :, k] == expert_id
|
||||||
|
weights_for_expert[mask_k] = topk_weights[:, :, k][mask_k]
|
||||||
|
|
||||||
|
# Отбираем только веса для выбранных токенов
|
||||||
|
selected_weights = weights_for_expert[token_mask] # [num_selected_tokens]
|
||||||
|
|
||||||
|
|
||||||
|
# Перемножьте выход эксперта на веса текущего эксперта.
|
||||||
|
weighted_output = selected_weights.unsqueeze(-1) * expert_output
|
||||||
|
|
||||||
|
# Помещаем результат на своё место в выходном тензоре
|
||||||
|
output[token_mask] += weighted_output
|
||||||
|
|
||||||
|
out = self._dropout(output)
|
||||||
|
|
||||||
|
return out
|
||||||
252
llm/src/llm/core/multi_query_attention.py
Normal file
252
llm/src/llm/core/multi_query_attention.py
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from llm.core.rope import RoPE
|
||||||
|
|
||||||
|
class MultiQueryAttention(nn.Module):
|
||||||
|
"""
|
||||||
|
Multi-Query Attention (MQA) — быстрый и экономичный вариант self-attention для LLM.
|
||||||
|
|
||||||
|
Назначение:
|
||||||
|
-----------
|
||||||
|
Класс реализует механизм внимания (self-attention), в котором для всех Query-голов используются одни и те же Key и Value.
|
||||||
|
В классическом MultiHeadAttention (MHA) на каждый Query используется свой Key/Value. В MQA набор Key/Value общий для всех голов,
|
||||||
|
что снижает требования к памяти и ускоряет работу, что особенно важно для больших LLM на inference.
|
||||||
|
|
||||||
|
Теоретическое преимущество:
|
||||||
|
--------------------------
|
||||||
|
- Существенно экономит память на матрицы Key и Value: количество KV-голов обычно в 4–8 раз меньше, чем число Query-голов.
|
||||||
|
- Позволяет достигать скорости почти обычной MHA при минимальной потере точности (см. Llama, Mistral).
|
||||||
|
- Является стандартом де-факто для deployment и inference современных LLM.
|
||||||
|
|
||||||
|
Архитектурная схема:
|
||||||
|
--------------------
|
||||||
|
- Для каждого токена во входе вычисляются Q_h (отдельные для каждой Query-головы), но K и V — общие для всех.
|
||||||
|
- Attention внутри каждой головы формируется через матричный продукт соответствующей Q_h и общего K.
|
||||||
|
- Выходные вектора голов конкатенируются и проецируются обратно в emb_size.
|
||||||
|
|
||||||
|
Формулы:
|
||||||
|
--------
|
||||||
|
Q = Wq·x, K = Wk·x, V = Wv·x
|
||||||
|
(Wq — отдельные для всех Query, Wk/Wv — общие для всех голов)
|
||||||
|
Attention_h(x) = softmax(Q_h·K^T / sqrt(d_k))·V
|
||||||
|
Output = Concat_h([Attention_h(x)])·W_o
|
||||||
|
|
||||||
|
Аргументы конструктора:
|
||||||
|
-----------------------
|
||||||
|
emb_size : int
|
||||||
|
Размерность скрытого пространства (hidden size, embedding dim).
|
||||||
|
num_heads : int
|
||||||
|
Число Query-голов (обычно 8–32 в LLM).
|
||||||
|
kv_heads : int
|
||||||
|
Число Key/Value-голов (обычно 1, 2, 4, 8).
|
||||||
|
head_size : int, optional
|
||||||
|
Размерность одной головы (обычно emb_size // num_heads).
|
||||||
|
dropout : float, optional
|
||||||
|
Вероятность Dropout для регуляризации внимания.
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> mqa = MultiQueryAttention(emb_size=512, num_heads=8, kv_heads=1)
|
||||||
|
>>> x = torch.randn(2, 16, 512)
|
||||||
|
>>> mask = torch.ones(2, 16, 16)
|
||||||
|
>>> out = mqa(x, mask)
|
||||||
|
>>> print(out.shape) # torch.Size([2, 16, 512])
|
||||||
|
|
||||||
|
Литература и статьи:
|
||||||
|
--------------------
|
||||||
|
- Shazeer, N., “Fast Transformer Decoding: One Write-Head Is All You Need” (MQA): https://arxiv.org/abs/1911.02150
|
||||||
|
- Llama: https://arxiv.org/abs/2302.13971
|
||||||
|
- Mistral: https://arxiv.org/abs/2310.06825
|
||||||
|
- PaLM/PaLM2, Mixtral, ChatGLM: практическое описание MQA.
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_q_heads: int,
|
||||||
|
emb_size: int,
|
||||||
|
head_size: int,
|
||||||
|
max_seq_len: int,
|
||||||
|
rope: RoPE = None,
|
||||||
|
dropout: float = 0.1,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Конструктор MultiQueryAttention.
|
||||||
|
|
||||||
|
Инициализирует все слои и буферы для реализации Multi-Query Attention с общими K/V-головами и индивидуальными Q-головами.
|
||||||
|
Позволяет существенно ускорять инференс и экономить память при работе с большими языковыми моделями.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
num_q_heads : int
|
||||||
|
Число query-голов (обычно совпадает с количеством attention heads в модели).
|
||||||
|
Определяет количество параллельных subspace для запроса.
|
||||||
|
emb_size : int
|
||||||
|
Размер скрытого пространства embedding (input/output размерность attention слоя).
|
||||||
|
head_size : int
|
||||||
|
Размерность одной attention-головы.
|
||||||
|
Обычно emb_size // num_q_heads.
|
||||||
|
max_seq_len : int
|
||||||
|
Максимально поддерживаемая длина последовательности (нужна для построения треугольной маски causal attention).
|
||||||
|
rope : RoPE, optional
|
||||||
|
Модуль для rotary positional encoding (позиционный энкодер, улучшает обобщающую способность attention).
|
||||||
|
Если None, positional encoding не применяется.
|
||||||
|
dropout : float, по умолчанию 0.1
|
||||||
|
Вероятность dropout для выходного слоя attention (регуляризация).
|
||||||
|
|
||||||
|
Внутри:
|
||||||
|
-------
|
||||||
|
- Насчитывает отдельные весовые слои для Q, общие для всех голов K/V.
|
||||||
|
- Строит causal маску для автогрессивной генерации.
|
||||||
|
- (Опционально) использует RoPE для позиционного кодирования.
|
||||||
|
- Dropout применяется после финального projection.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> mqa = MultiQueryAttention(emb_size=256, num_q_heads=8, head_size=32, max_seq_len=2048, rope=None, dropout=0.1)
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self._num_q_heads = num_q_heads
|
||||||
|
self._head_size = head_size
|
||||||
|
self._max_seq_len = max_seq_len
|
||||||
|
self._rope = rope
|
||||||
|
|
||||||
|
self._q = nn.Linear(emb_size, num_q_heads * head_size)
|
||||||
|
self._k = nn.Linear(emb_size, head_size)
|
||||||
|
self._v = nn.Linear(emb_size, head_size)
|
||||||
|
|
||||||
|
# Создание causal маски
|
||||||
|
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
|
||||||
|
self.register_buffer(
|
||||||
|
"_tril_mask", mask.bool() if hasattr(torch, "bool") else mask.byte()
|
||||||
|
)
|
||||||
|
|
||||||
|
self._layer = nn.Linear(num_q_heads * head_size, emb_size)
|
||||||
|
self._dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
mask: torch.Tensor = None,
|
||||||
|
use_cache: bool = True,
|
||||||
|
cache: list = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Прямой проход (forward) через слой MultiQueryAttention.
|
||||||
|
|
||||||
|
Реализует multi-query self-attention для входных последовательностей с оптимизацией памяти за счёт общих K/V-голов для всех Query.
|
||||||
|
Поддерживает работу с rotary positional encoding (RoPE), каузальной маской и кэшированием для ускорения генерации.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
x : torch.Tensor
|
||||||
|
Входной тензор формы [batch_size, seq_len, emb_size] — скрытые состояния после предыдущего слоя или эмбеддинга.
|
||||||
|
mask : torch.Tensor, optional
|
||||||
|
Необязательная маска внимания (например, для padding или custom-маскировки). По умолчанию используется встроенная causal mask.
|
||||||
|
use_cache : bool, по умолчанию True
|
||||||
|
Если True, возвращает кэш ключей/значений (для autoregressive inference/generation).
|
||||||
|
cache : list, optional
|
||||||
|
(K_cache, V_cache) — предварительный кэш KV (для ускоренного инференса). Если None, кэш не используется/создаётся заново.
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
-----------
|
||||||
|
если use_cache == True:
|
||||||
|
Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
|
||||||
|
- attention_out: [batch_size, seq_len, emb_size] — результат attention после проекции и dropout.
|
||||||
|
- (K, V): кэшированные ключи и значения (использовать для последующих forward'ов в autoregressive генерации)
|
||||||
|
если use_cache == False:
|
||||||
|
Tuple[torch.Tensor, None]
|
||||||
|
|
||||||
|
Математические шаги:
|
||||||
|
--------------------
|
||||||
|
1. Q = Wq·x; K = Wk·x; V = Wv·x # Q: индивидуальные для каждой головы, K/V — общие
|
||||||
|
2. [optional] Rotary positional encoding применяется к Q и K
|
||||||
|
3. (optional) concat c k/v cache (for autoregressive inference)
|
||||||
|
4. attention_scores = softmax(Q·K^T / sqrt(head_size), mask)
|
||||||
|
5. attention_out = attention_scores·V
|
||||||
|
6. heads сливаются и проецируются в emb_size; применяется dropout.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> out, cache = mqa(x, mask=attn_mask, use_cache=True, cache=prev_cache)
|
||||||
|
>>> print(out.shape) # torch.Size([batch_size, seq_len, emb_size])
|
||||||
|
|
||||||
|
Примечания:
|
||||||
|
-----------
|
||||||
|
- Для каузального режима используется треугольная маска (по умолчанию).
|
||||||
|
- Для генерации текста с cache передавайте кэш от предыдущих токенов — это ускоряет autoregressive inference.
|
||||||
|
- Внимание! Тензоры внутри cache должны иметь форму [batch, heads, seq_len, head_size].
|
||||||
|
"""
|
||||||
|
batch_size, seq_len, emb_size = x.shape
|
||||||
|
if seq_len > self._max_seq_len:
|
||||||
|
raise ValueError(
|
||||||
|
f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Пропустите тензор x через матрицы Wq, Wk , Wv, чтобы получить матрицы запроса, ключа и значения.
|
||||||
|
k = self._k(x) # [B, T, hs]
|
||||||
|
q = self._q(x) # [B, T, hs]
|
||||||
|
v = self._v(x) # [B, T, hs]
|
||||||
|
|
||||||
|
# Шаг 2: Изменение формы для multi-head
|
||||||
|
# [batch_size, seq_len, num_heads * head_size]
|
||||||
|
# -> [batch_size, seq_len, num_heads, head_size]
|
||||||
|
q = q.reshape(batch_size, seq_len, self._num_q_heads, self._head_size)
|
||||||
|
k = k.reshape(batch_size, seq_len, 1, self._head_size)
|
||||||
|
v = v.reshape(batch_size, seq_len, 1, self._head_size)
|
||||||
|
|
||||||
|
# 3. Transpose: [B, T, H, hs] -> [B, H, T, hs]
|
||||||
|
q = q.transpose(1, 2)
|
||||||
|
k = k.transpose(1, 2)
|
||||||
|
v = v.transpose(1, 2)
|
||||||
|
|
||||||
|
# Пропустите матрицы запроса и ключа через экземпляр rope, чтобы выполнить поворот.
|
||||||
|
if self._rope is not None:
|
||||||
|
# Применяем RoPE к Q и K (НЕ к V!)
|
||||||
|
q = self._rope(q) # [B, T, hs]
|
||||||
|
k = self._rope(k) # [B, T, hs]
|
||||||
|
|
||||||
|
|
||||||
|
# Если cache пришел, то объединяем кэш и одну строку из ключа и значения. Это будут новые key и value для последующих вычислений.
|
||||||
|
# 5. Кэширование (для autoregressive generation)
|
||||||
|
if cache is not None:
|
||||||
|
k_cache, v_cache = cache
|
||||||
|
k = torch.cat([k_cache, k], dim=2) # Concat по seq_len (dim=2)
|
||||||
|
v = torch.cat([v_cache, v], dim=2)
|
||||||
|
|
||||||
|
|
||||||
|
# Перемножим матрицы запроса и ключа (транспонированную), чтобы вычислить матрицу внимания.
|
||||||
|
# И разделить все значения в матрице внимания на корень из head_size.
|
||||||
|
scores = q @ k.transpose(-2, -1) / (self._head_size ** 0.5)
|
||||||
|
|
||||||
|
# Если cache пришел, то маску не накладываем. Иначе наложите на матрицу внимания треугольную маску, созданную при инициализации. Все скрытые значения должны быть приведены к минус бесконечности: float('-inf').
|
||||||
|
if cache is None:
|
||||||
|
scores = scores.masked_fill(
|
||||||
|
~self._tril_mask[:seq_len, :seq_len], float("-inf")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Применить к матрице внимания (построчно) функцию Softmax.
|
||||||
|
weights = F.softmax(scores, dim=-1)
|
||||||
|
|
||||||
|
# Перемножим матрицу внимания и матрицу значения.
|
||||||
|
x_out = weights @ v # [B, T, hs]
|
||||||
|
|
||||||
|
|
||||||
|
# Измените форму тензора на batch_size × seq_len × num_heads*head_size.
|
||||||
|
# Transpose обратно и concatenate heads
|
||||||
|
x_out = x_out.transpose(1, 2) # [B, T_q, H, hs]
|
||||||
|
x_out = x_out.contiguous() # Важно для reshape!
|
||||||
|
concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_q_heads * self._head_size)
|
||||||
|
|
||||||
|
|
||||||
|
# Пропустите получившийся тензор через последний линейный слой.
|
||||||
|
# 3. Проецируем в пространство эмбеддингов
|
||||||
|
projected_output = self._layer(concatenated_attention)
|
||||||
|
|
||||||
|
|
||||||
|
# 4. Применяем dropout для регуляризации
|
||||||
|
final_output = self._dropout(projected_output)
|
||||||
|
|
||||||
|
if use_cache is True:
|
||||||
|
return (final_output, (k, v))
|
||||||
|
else:
|
||||||
|
return (final_output, None)
|
||||||
3
llm/src/llm/models/gemma/__init__.py
Normal file
3
llm/src/llm/models/gemma/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .gemma import Gemma
|
||||||
|
|
||||||
|
__all__ = ["Gemma"]
|
||||||
343
llm/src/llm/models/gemma/gemma.py
Normal file
343
llm/src/llm/models/gemma/gemma.py
Normal file
@@ -0,0 +1,343 @@
|
|||||||
|
import torch
|
||||||
|
import math
|
||||||
|
from torch import nn
|
||||||
|
from torch import Tensor
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from math import sqrt
|
||||||
|
from llm.core.base_model import BaseModel
|
||||||
|
from llm.core.token_embeddings import TokenEmbeddings
|
||||||
|
from llm.core.rope import RoPE
|
||||||
|
from llm.core.rms_norm import RMSNorm
|
||||||
|
from llm.core.gemma_decoder import GemmaDecoder
|
||||||
|
|
||||||
|
|
||||||
|
class Gemma(BaseModel):
|
||||||
|
"""
|
||||||
|
Gemma — языковая трансформер-модель от Google, с архитектурой, оптимизированной для open-source и research-комьюнити.
|
||||||
|
|
||||||
|
Назначение:
|
||||||
|
-----------
|
||||||
|
Модель Gemma реализует стек современных декодерных блоков (GemmaDecoder), поддерживает rotary-позиционирование, multi-query self-attention,
|
||||||
|
эффективный режим генерации (KV-cache), dropout, compact residual connections, базируется на best-practice LLM-инженерии последних лет.
|
||||||
|
Поддерживает batched-тренировку и inference, генерацию с различными стратегиями выборки (greedy, top-k, top-p), автосохранение.
|
||||||
|
|
||||||
|
Архитектурные особенности:
|
||||||
|
--------------------------
|
||||||
|
- Stack из N слоёв GemmaDecoder (attention с Multi-Query либо Grouped heads, FFN с GeGLU/SwiGLU)
|
||||||
|
- RMSNorm или LayerNorm для стабилизации
|
||||||
|
- Dropout для регуляризации
|
||||||
|
- Rotary Position Embedding (RoPE) для позиционных кодов
|
||||||
|
- Выходная проекция (linear → logits) к словарю токенов
|
||||||
|
- Полная поддержка cache для ускорения autoregressive генерации
|
||||||
|
|
||||||
|
Конфиг/Параметры конструктора:
|
||||||
|
------------------------------
|
||||||
|
config : dict
|
||||||
|
Словарь c параметрами модели:
|
||||||
|
- vocab_size : int — размер словаря
|
||||||
|
- embed_dim : int — размер скрытого (hidden) пространства
|
||||||
|
- max_position_embeddings : int — максимальная длина последовательности
|
||||||
|
- num_layers : int — количество декодерных блоков
|
||||||
|
- num_q_heads : int — количество attention голов (Queries)
|
||||||
|
- num_kv_heads : int — количество ключевых/значенческих attention голов
|
||||||
|
- dropout : float — Dropout率
|
||||||
|
- ... (доп. гиперпараметры, требуемые GemmaDecoder'ами)
|
||||||
|
|
||||||
|
Основные методы:
|
||||||
|
----------------
|
||||||
|
- forward(x, use_cache=True, cache=None): выдает батч логитов по токенам, возвращает при необходимости обновленный cache.
|
||||||
|
- generate(...): автотекстогенерация с greedy, temperature, top-k/p sampling, поддержкой кэша (ускорение inference).
|
||||||
|
- save(path)/load(path, device): сохранение и загрузка предобученных весов, параметров и состояния.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> config = {...} # словарь с параметрами
|
||||||
|
>>> model = Gemma(config)
|
||||||
|
>>> x = torch.randint(0, config["vocab_size"], (4, 64))
|
||||||
|
>>> logits, cache = model(x, use_cache=True)
|
||||||
|
>>> print(logits.shape) # [4, 64, vocab_size]
|
||||||
|
>>> out = model.generate(x, max_new_tokens=20, do_sample=True, top_k=10, temperature=0.8)
|
||||||
|
|
||||||
|
Литература и ссылки:
|
||||||
|
--------------------
|
||||||
|
- Gemma: https://ai.google.dev/gemma (официальная страница)
|
||||||
|
- Разработка и архитектура: https://arxiv.org/abs/2403.07794
|
||||||
|
- Rotary Embedding: https://arxiv.org/abs/2104.09864
|
||||||
|
- Multi-Query Attention: https://arxiv.org/abs/1911.02150
|
||||||
|
- Llama: https://arxiv.org/abs/2302.13971
|
||||||
|
"""
|
||||||
|
def __init__(self, config):
|
||||||
|
"""
|
||||||
|
Конструктор класса Gemma.
|
||||||
|
|
||||||
|
Позволяет создать объект языковой модели с архитектурой Gemma и
|
||||||
|
произвольной конфигурацией (гибкая поддержка разных масштабов, ширин, глубин).
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
config : dict
|
||||||
|
Словарь со всеми необходимыми гиперпараметрами и архитектурными детальями модели Gemma.
|
||||||
|
Ожидаемые ключи (группы параметров):
|
||||||
|
- vocab_size : int — размер словаря токенов (размерность входа/выхода)
|
||||||
|
- embed_dim : int — скрытый размер эмбеддинга (hidden dim)
|
||||||
|
- max_position_embeddings : int — максимальная длина последовательности
|
||||||
|
- num_layers : int — количество декодерных блоков (глубина стека)
|
||||||
|
- num_q_heads : int — число attention голов (Query heads)
|
||||||
|
- num_kv_heads : int — число голов для Key/Value (MultiQuery Attention)
|
||||||
|
- dropout : float — Dropout для регуляризации
|
||||||
|
- остальные специфичные для GemmaDecoder'ов параметры
|
||||||
|
|
||||||
|
Внутри:
|
||||||
|
-------
|
||||||
|
- Инициализируются модули эмбеддинга токенов, позиционного кодирования (RoPE) и Dropout,
|
||||||
|
стек декодеров (GemmaDecoder(...)), слой финальной нормализации и выходная проекция (linear).
|
||||||
|
- Все архитектурные параметры напрямую берутся из config.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> config = {
|
||||||
|
... "vocab_size": 32000,
|
||||||
|
... "embed_dim": 512,
|
||||||
|
... "max_position_embeddings": 2048,
|
||||||
|
... "num_layers": 24,
|
||||||
|
... "num_q_heads": 8,
|
||||||
|
... "num_kv_heads": 4,
|
||||||
|
... "dropout": 0.1,
|
||||||
|
... }
|
||||||
|
>>> model = Gemma(config)
|
||||||
|
|
||||||
|
Примечание:
|
||||||
|
-----------
|
||||||
|
- Внимание: значения config должны быть согласованы друг с другом! Например, embed_dim должен быть кратным num_q_heads и т.д.
|
||||||
|
- Поддерживается дальнейшая кастомизация стека декодеров через ключи в config.
|
||||||
|
"""
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
self._max_seq_len = config["max_position_embeddings"]
|
||||||
|
|
||||||
|
# Инициализация слоев
|
||||||
|
self._token_embeddings = TokenEmbeddings(
|
||||||
|
vocab_size=config["vocab_size"],
|
||||||
|
emb_size=config["embed_dim"]
|
||||||
|
)
|
||||||
|
self._position_embeddings = RoPE(
|
||||||
|
head_size=config["embed_dim"] // config["num_q_heads"],
|
||||||
|
max_seq_len=config["max_position_embeddings"]
|
||||||
|
)
|
||||||
|
#self._position_embeddings = PositionalEmbeddings(
|
||||||
|
# max_seq_len=max_seq_len,
|
||||||
|
# emb_size=emb_size
|
||||||
|
#)
|
||||||
|
self._dropout = nn.Dropout(config["dropout"])
|
||||||
|
self._decoders = nn.ModuleList([GemmaDecoder(
|
||||||
|
num_q_heads=config["num_q_heads"],
|
||||||
|
emb_size=config["embed_dim"],
|
||||||
|
head_size=config["embed_dim"] // config["num_q_heads"],
|
||||||
|
max_seq_len=config["max_position_embeddings"],
|
||||||
|
rope=self._position_embeddings,
|
||||||
|
dropout=config["dropout"]
|
||||||
|
) for _ in range(config["num_layers"])])
|
||||||
|
self._norm = RMSNorm(config["embed_dim"])
|
||||||
|
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
|
||||||
|
"""
|
||||||
|
Прямой проход (forward) через полную модель Gemma.
|
||||||
|
|
||||||
|
Трансформирует входную последовательность токенов через стек из декодерных блоков GemmaDecoder.
|
||||||
|
Возвращает логиты по всем токенам и (при необходимости) кэш attention для быстрой autoregressive-генерации.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
x : torch.Tensor
|
||||||
|
Входной тензор shape [batch_size, seq_len], содержащий токен-IDs.
|
||||||
|
use_cache : bool, по умолчанию True
|
||||||
|
Если True — сохраняет и возвращает KV-кэш attention (ускоряет автогенерацию).
|
||||||
|
Если False — кэш не используется.
|
||||||
|
cache : list, optional
|
||||||
|
(Необязательно) Список/None: с кэшами KV-матриц для каждого слоя (для режима генерации статей/диalogов).
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
-----------
|
||||||
|
tuple:
|
||||||
|
- logits : torch.Tensor shape [batch_size, seq_len, vocab_size]
|
||||||
|
Логиты по словарю для каждого токена (input + сколь угодно новых).
|
||||||
|
- new_cache : list или None
|
||||||
|
Обновлённый cache (если use_cache=True).
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> logits, new_cache = model(x, use_cache=True, cache=None)
|
||||||
|
>>> logits.shape # [batch_size, seq_len, vocab_size]
|
||||||
|
|
||||||
|
Примечания:
|
||||||
|
-----------
|
||||||
|
- Используется при обучении и инференсе.
|
||||||
|
- Если нужно только инференс last-token — используйте logits[:, -1, :].
|
||||||
|
- При превышении x.shape[1] > max_seq_len выдаёт ValueError.
|
||||||
|
"""
|
||||||
|
# Проверка длины последовательности (только при отсутствии кэша)
|
||||||
|
if cache is None and x.size(1) > self._max_seq_len:
|
||||||
|
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
|
||||||
|
|
||||||
|
# Эмбеддинги токенов и позиций
|
||||||
|
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
|
||||||
|
#pos_out = self._position_embeddings(x) # [batch, seq_len, emb_size]
|
||||||
|
|
||||||
|
# Комбинирование
|
||||||
|
out = self._dropout(tok_out) # [batch, seq_len, emb_size]
|
||||||
|
|
||||||
|
# Стек декодеров с передачей кэша
|
||||||
|
new_cache = []
|
||||||
|
for i, decoder in enumerate(self._decoders):
|
||||||
|
decoder_cache = cache[i] if cache is not None else None
|
||||||
|
decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)
|
||||||
|
|
||||||
|
# Извлекаем результат из кортежа
|
||||||
|
if use_cache:
|
||||||
|
out, decoder_new_cache = decoder_result
|
||||||
|
new_cache.append(decoder_new_cache)
|
||||||
|
else:
|
||||||
|
out = decoder_result[0]
|
||||||
|
|
||||||
|
out = self._norm(out)
|
||||||
|
logits = self._linear(out)
|
||||||
|
|
||||||
|
# Возвращаем результат с учетом use_cache
|
||||||
|
if use_cache:
|
||||||
|
return (logits, new_cache)
|
||||||
|
else:
|
||||||
|
return (logits, None)
|
||||||
|
|
||||||
|
def generate(self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
max_new_tokens: int,
|
||||||
|
do_sample: bool,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
top_k: int = None,
|
||||||
|
top_p: float = None,
|
||||||
|
use_cache: bool = True
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Авторегрессивная генерация токенов с использованием greedy, temperature, top-k и top-p sampling.
|
||||||
|
Реализует generation-loop с обновлением attention-кэша для ускорения инференса.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
x : torch.Tensor
|
||||||
|
Входной тензор с последовательностью токенов (shape [batch_size, seq_len]), который необходимо продолжить.
|
||||||
|
max_new_tokens : int
|
||||||
|
Сколько новых токенов сгенерировать (максимум).
|
||||||
|
do_sample : bool
|
||||||
|
Если True — сэмплирует следующий токен согласно распределению вероятностей (stochastic), иначе выбирает токен с максимальной вероятностью (greedy).
|
||||||
|
temperature : float, default=1.0
|
||||||
|
Параметр для шкалирования распределения вероятностей логитов. Больше 1.0 — больше случайности, меньше 1.0 — более детерминированный (жёсткий) выбор.
|
||||||
|
top_k : int, optional
|
||||||
|
Если задано — для сэмплирования учитываются только top_k наиболее вероятных токенов.
|
||||||
|
top_p : float, optional
|
||||||
|
Если задано — работают nucleus sampling: учитываются токены, суммарная вероятность которых не превышает top_p.
|
||||||
|
use_cache : bool, default=True
|
||||||
|
Если True — для ускорения использует и обновляет attention-кэши (KV-cache).
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
-----------
|
||||||
|
torch.Tensor
|
||||||
|
Тензор shape [batch_size, seq_len + max_new_tokens] с исходными и сгенерированными токенами (token IDs).
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> out = model.generate(
|
||||||
|
... x, max_new_tokens=20, do_sample=True, temperature=0.8, top_k=50
|
||||||
|
... )
|
||||||
|
>>> print(out.shape) # [batch_size, seq_len+20]
|
||||||
|
|
||||||
|
Примечания:
|
||||||
|
-----------
|
||||||
|
- Нельзя указывать одновременно top_k и top_p (будет выброшено исключение).
|
||||||
|
- temperature <= 0 некорректно (будет выброшено исключение).
|
||||||
|
- Поддержка cache (use_cache=True) значительно ускоряет генерацию длинных последовательностей и позволяет использовать beam search/decoding.
|
||||||
|
- Для воспроизводимых результатов установите torch.manual_seed перед генерацией.
|
||||||
|
- Метод возвращает только token_ids, если нужны logits — используйте .forward напрямую.
|
||||||
|
|
||||||
|
Литература:
|
||||||
|
-----------
|
||||||
|
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus/top-p sampling): https://arxiv.org/abs/1904.09751
|
||||||
|
- Gemma: https://arxiv.org/abs/2403.07794
|
||||||
|
"""
|
||||||
|
|
||||||
|
cache = None
|
||||||
|
|
||||||
|
for _ in range(max_new_tokens):
|
||||||
|
if use_cache and cache is not None:
|
||||||
|
# Используем кэш - передаем только последний токен
|
||||||
|
x_input = x[:, -1:] # [batch_size, 1]
|
||||||
|
else:
|
||||||
|
# Первая итерация или кэш отключен - передаем всю последовательность
|
||||||
|
x_input = x
|
||||||
|
|
||||||
|
# Прямой проход с кэшем
|
||||||
|
logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)
|
||||||
|
|
||||||
|
# Обновляем кэш для следующей итерации
|
||||||
|
if use_cache:
|
||||||
|
cache = new_cache
|
||||||
|
|
||||||
|
last_logits = logits[:, -1, :] # [batch_size, vocab_size]
|
||||||
|
|
||||||
|
# Масштабируем логиты температурой
|
||||||
|
if temperature > 0:
|
||||||
|
logits_scaled = last_logits / temperature
|
||||||
|
else:
|
||||||
|
logits_scaled = last_logits
|
||||||
|
|
||||||
|
if do_sample == True and top_k != None:
|
||||||
|
_, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)
|
||||||
|
|
||||||
|
# # Заменим все НЕ top-k логиты на -inf
|
||||||
|
masked_logits = logits_scaled.clone()
|
||||||
|
vocab_size = logits_scaled.size(-1)
|
||||||
|
|
||||||
|
# создаём маску: 1, если токен НЕ в topk_indices
|
||||||
|
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||||
|
mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы
|
||||||
|
masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf')
|
||||||
|
|
||||||
|
logits_scaled = masked_logits
|
||||||
|
|
||||||
|
if do_sample == True and top_p != None:
|
||||||
|
# 1. Применим softmax, чтобы получить вероятности:
|
||||||
|
probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]
|
||||||
|
# 2. Отсортируем токены по убыванию вероятностей:
|
||||||
|
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
|
||||||
|
# 3. Посчитаем кумулятивную сумму вероятностей:
|
||||||
|
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
|
||||||
|
# 4. Определим маску: оставить токены, пока сумма < top_p
|
||||||
|
sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size]
|
||||||
|
# Гарантируем, что хотя бы первый токен останется
|
||||||
|
sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1
|
||||||
|
# 5. Преобразуем маску обратно в оригинальный порядок:
|
||||||
|
# Создаём полную маску из 0
|
||||||
|
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||||
|
# Устанавливаем 1 в местах нужных токенов
|
||||||
|
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
||||||
|
# 6. Зануляем логиты токенов вне топ-p:
|
||||||
|
logits_scaled[~mask] = float('-inf')
|
||||||
|
|
||||||
|
# 4. Применяем Softmax
|
||||||
|
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
|
||||||
|
|
||||||
|
|
||||||
|
if do_sample == True:
|
||||||
|
# 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial
|
||||||
|
next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]
|
||||||
|
else:
|
||||||
|
# 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью
|
||||||
|
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
|
||||||
|
|
||||||
|
# 6. Добавляем его к последовательности
|
||||||
|
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
|
||||||
|
return x
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_seq_len(self) -> int:
|
||||||
|
return self._max_seq_len
|
||||||
3
llm/src/llm/models/mixtral/__init__.py
Normal file
3
llm/src/llm/models/mixtral/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .mixtral import Mixtral
|
||||||
|
|
||||||
|
__all__ = ["Mixtral"]
|
||||||
358
llm/src/llm/models/mixtral/mixtral.py
Normal file
358
llm/src/llm/models/mixtral/mixtral.py
Normal file
@@ -0,0 +1,358 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch import Tensor
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from math import sqrt
|
||||||
|
from llm.core.base_model import BaseModel
|
||||||
|
from llm.core.token_embeddings import TokenEmbeddings
|
||||||
|
from llm.core.rope import RoPE
|
||||||
|
from llm.core.rms_norm import RMSNorm
|
||||||
|
from llm.core.mixtral_decoder import MixtralDecoder
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Mixtral(BaseModel):
|
||||||
|
"""
|
||||||
|
Mixtral — языковая модель с архитектурой Mixture-of-Experts на основе современных трансформеров (см. Mixtral 8x7B).
|
||||||
|
|
||||||
|
Описание:
|
||||||
|
---------
|
||||||
|
Данный класс реализует полностью функциональную LLM с блоками MixtralDecoder, которые используют разреженные Feed-Forward сети MoE (Mixture-of-Experts)
|
||||||
|
и Grouped Query Attention (GQA). Позволяет масштабировать количество параметров без экспоненциального роста вычислительных затрат благодаря активации лишь части экспертов на каждый токен.
|
||||||
|
Mixtral поддерживает автотекстогенерацию с caching, position encoding через RoPE и всё необходимое для работы и тренировки современных LLM.
|
||||||
|
|
||||||
|
Архитектурные особенности:
|
||||||
|
--------------------------
|
||||||
|
- Stack из N слоёв MixtralDecoder (каждый — MoE-блок + attention + RMSNorm).
|
||||||
|
- Dropout для регуляризации на уровне эмбеддингов и слоёв.
|
||||||
|
- Позиционные эмбеддинги реализованы через RoPE (Rotary Positional Embeddings).
|
||||||
|
- Финальная RMSNorm плюс Linear-проекция к словарю токенов.
|
||||||
|
- Поддержка автогенерации с sampling (greedy, top-k, top-p), temperature и KV-cache.
|
||||||
|
|
||||||
|
Аргументы конструктора:
|
||||||
|
----------------------
|
||||||
|
config : dict
|
||||||
|
Словарь-конфиг с основными гиперпараметрами модели:
|
||||||
|
- vocab_size : int — размер словаря токенов
|
||||||
|
- embed_dim : int — размер скрытого пространства
|
||||||
|
- max_position_embeddings : int — макс. длина последовательности
|
||||||
|
- num_layers : int — количество декодерных блоков в стеке
|
||||||
|
- num_q_heads : int — число query-голов в attention
|
||||||
|
- num_kv_heads : int — число kv-голов в attention
|
||||||
|
- num_experts : int — число MoE-экспертов
|
||||||
|
- top_k_experts : int — сколько экспертов активировать на токен
|
||||||
|
- dropout : float — вероятность Dropout
|
||||||
|
- window_size : int — размер окна внимания
|
||||||
|
|
||||||
|
Основные методы:
|
||||||
|
----------------
|
||||||
|
- forward(x, use_cache=True, cache=None) — прямой проход, поддерживает batched вход, caching.
|
||||||
|
- generate(...) — авторегрессивная генерация с разными стратегиями sampling и ускорением через cache.
|
||||||
|
- save(path)/load(path, device) — сохранение и восстановление обученной модели.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> config = {...} # dict с параметрами
|
||||||
|
>>> model = Mixtral(config)
|
||||||
|
>>> x = torch.randint(0, config["vocab_size"], (2, 16))
|
||||||
|
>>> logits, cache = model(x, use_cache=True)
|
||||||
|
>>> print(logits.shape) # [2, 16, vocab_size]
|
||||||
|
|
||||||
|
>>> # Генерация
|
||||||
|
>>> out = model.generate(x, max_new_tokens=20, do_sample=True, top_k=10, temperature=0.9)
|
||||||
|
|
||||||
|
Литература:
|
||||||
|
-----------
|
||||||
|
- Mixtral 8x7B: https://mistral.ai/news/mixtral-of-experts/
|
||||||
|
- Switch Transformer: https://arxiv.org/abs/2101.03961
|
||||||
|
- GShard: https://arxiv.org/abs/2006.16668
|
||||||
|
- RoPE: https://arxiv.org/abs/2104.09864
|
||||||
|
- Grouped Query Attention: https://arxiv.org/abs/2305.14236
|
||||||
|
- RMSNorm: https://arxiv.org/abs/1910.07467
|
||||||
|
"""
|
||||||
|
def __init__(self, config):
|
||||||
|
"""
|
||||||
|
Конструктор класса Mixtral.
|
||||||
|
|
||||||
|
Осуществляет инициализацию всех модулей и внутренних параметров большой языковой модели с архитектурой Mixtral/MoE.
|
||||||
|
Использует параметры из конфиг-словаря `config` для гибкой настройки модели.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
config : dict
|
||||||
|
Словарь с основными гиперпараметрами архитектуры. Должен содержать ключи:
|
||||||
|
vocab_size (int): Размер словаря токенов.
|
||||||
|
embed_dim (int): Размер скрытого пространства (эмбеддингов).
|
||||||
|
max_position_embeddings (int): Максимальная длина токенной последовательности.
|
||||||
|
num_layers (int): Количество декодерных блоков (слоёв) в модели.
|
||||||
|
num_q_heads (int): Число query-голов (attention heads).
|
||||||
|
num_kv_heads (int): Число key-value голов (attention heads).
|
||||||
|
num_experts (int): Количество экспертов в каждом MoE-блоке.
|
||||||
|
top_k_experts (int): Сколько экспертов активируется для одного токена.
|
||||||
|
dropout (float): Dropout для регуляризации.
|
||||||
|
window_size (int): Размер окна внимания (Attention Window).
|
||||||
|
|
||||||
|
Внутри:
|
||||||
|
-------
|
||||||
|
- Инициализируются эмбеддинги токенов, позиционные эмбеддинги RoPE, Dropout.
|
||||||
|
- Строится стек из num_layers модулей MixtralDecoder с заданным количеством attention heads и экспертов.
|
||||||
|
- Финальный слой нормализации и проекция к логитам словаря (linear layer).
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> config = {
|
||||||
|
... "vocab_size": 32000,
|
||||||
|
... "embed_dim": 512,
|
||||||
|
... "max_position_embeddings": 2048,
|
||||||
|
... "num_layers": 24,
|
||||||
|
... "num_q_heads": 8,
|
||||||
|
... "num_kv_heads": 8,
|
||||||
|
... "num_experts": 8,
|
||||||
|
... "top_k_experts": 2,
|
||||||
|
... "dropout": 0.1,
|
||||||
|
... "window_size": 256,
|
||||||
|
... }
|
||||||
|
>>> model = Mixtral(config)
|
||||||
|
|
||||||
|
Примечания:
|
||||||
|
-----------
|
||||||
|
- Конфиг модели должен быть согласован: размеры должны делиться на число голов, число экспертов и top_k_experts корректно выбраны.
|
||||||
|
- Все параметры, необходимые для построения MixtralDecoder, attention и MoE, берутся из config.
|
||||||
|
"""
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
self._max_seq_len = config["max_position_embeddings"]
|
||||||
|
|
||||||
|
# Инициализация слоев
|
||||||
|
self._token_embeddings = TokenEmbeddings(
|
||||||
|
vocab_size=config["vocab_size"],
|
||||||
|
emb_size=config["embed_dim"]
|
||||||
|
)
|
||||||
|
self._position_embeddings = RoPE(
|
||||||
|
head_size=config["embed_dim"] // config["num_q_heads"],
|
||||||
|
max_seq_len=config["max_position_embeddings"]
|
||||||
|
)
|
||||||
|
#self._position_embeddings = PositionalEmbeddings(
|
||||||
|
# max_seq_len=max_seq_len,
|
||||||
|
# emb_size=emb_size
|
||||||
|
#)
|
||||||
|
self._dropout = nn.Dropout(config["dropout"])
|
||||||
|
self._decoders = nn.ModuleList([MixtralDecoder(
|
||||||
|
num_q_heads=config["num_q_heads"],
|
||||||
|
num_kv_heads=config["num_kv_heads"],
|
||||||
|
emb_size=config["embed_dim"],
|
||||||
|
head_size=config["embed_dim"] // config["num_q_heads"],
|
||||||
|
max_seq_len=config["max_position_embeddings"],
|
||||||
|
num_experts=config["num_experts"],
|
||||||
|
top_k_experts=config["top_k_experts"],
|
||||||
|
window_size=config["window_size"],
|
||||||
|
rope=self._position_embeddings,
|
||||||
|
dropout=config["dropout"]
|
||||||
|
) for _ in range(config["num_layers"])])
|
||||||
|
self._norm = RMSNorm(config["embed_dim"])
|
||||||
|
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
|
||||||
|
"""
|
||||||
|
Прямой проход (forward) через всю модель Mixtral.
|
||||||
|
|
||||||
|
Данный метод реализует трансформацию входной последовательности токенов в логиты (предсказания вероятностей токенов словаря)
|
||||||
|
с поддержкой эффективного инференса с использованием cache (KV-кэш attention для автогенерации).
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
x : torch.Tensor
|
||||||
|
Двумерный входной тензор shape [batch_size, seq_len], где каждое значение — ID токена.
|
||||||
|
use_cache : bool, по умолчанию True
|
||||||
|
Если True — в режиме генерации модель возвращает обновлённый список кэшей attention для ускорения последовательного инференса.
|
||||||
|
Если False — attention cache не используется.
|
||||||
|
cache : list, optional
|
||||||
|
(Необязательно) Список (или None) с кэшем KV attention для каждого слоя. Используется для автогенерации текста.
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
-----------
|
||||||
|
tuple:
|
||||||
|
- logits : torch.Tensor — выходной тензор shape [batch_size, seq_len, vocab_size] — массив логитов по токенам и словарю.
|
||||||
|
- new_cache : list или None — обновлённый cache, если используется.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> logits, new_cache = model(x, use_cache=True, cache=None)
|
||||||
|
>>> logits.shape # [batch_size, seq_len, vocab_size]
|
||||||
|
|
||||||
|
Примечания:
|
||||||
|
-----------
|
||||||
|
- Если используется cache — эффективно для авторегрессионной генерации (token-by-token), например, при диалогах или длинной генерации.
|
||||||
|
- Если входная последовательность длиннее max_seq_len — будет выброшено исключение.
|
||||||
|
- Если нужен только логит последнего токена — используйте slice: logits[:, -1, :]
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Проверка длины последовательности (только при отсутствии кэша)
|
||||||
|
if cache is None and x.size(1) > self._max_seq_len:
|
||||||
|
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
|
||||||
|
|
||||||
|
# Эмбеддинги токенов и позиций
|
||||||
|
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
|
||||||
|
#pos_out = self._position_embeddings(x) # [batch, seq_len, emb_size]
|
||||||
|
|
||||||
|
# Комбинирование
|
||||||
|
out = self._dropout(tok_out) # [batch, seq_len, emb_size]
|
||||||
|
|
||||||
|
# Стек декодеров с передачей кэша
|
||||||
|
new_cache = []
|
||||||
|
for i, decoder in enumerate(self._decoders):
|
||||||
|
decoder_cache = cache[i] if cache is not None else None
|
||||||
|
decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)
|
||||||
|
|
||||||
|
# Извлекаем результат из кортежа
|
||||||
|
if use_cache:
|
||||||
|
out, decoder_new_cache = decoder_result
|
||||||
|
new_cache.append(decoder_new_cache)
|
||||||
|
else:
|
||||||
|
out = decoder_result[0]
|
||||||
|
|
||||||
|
out = self._norm(out)
|
||||||
|
logits = self._linear(out)
|
||||||
|
|
||||||
|
# Возвращаем результат с учетом use_cache
|
||||||
|
if use_cache:
|
||||||
|
return (logits, new_cache)
|
||||||
|
else:
|
||||||
|
return (logits, None)
|
||||||
|
|
||||||
|
def generate(self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
max_new_tokens: int,
|
||||||
|
do_sample: bool,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
top_k: int = None,
|
||||||
|
top_p: float = None,
|
||||||
|
use_cache: bool = True
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k/top-p sampling
|
||||||
|
и ускорением через attention-кэш (KV-cache, важно для inference на длинных текстах).
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
x (torch.Tensor): Входной тензор с токенами shape [batch_size, seq_len].
|
||||||
|
max_new_tokens (int): Максимальное количество новых токенов для генерации.
|
||||||
|
do_sample (bool): Если True — вероятность/случайность (random sampling); если False — жадная генерация (argmax).
|
||||||
|
temperature (float): Температура (>0, по умолчанию 1.0); >1.0 — более случайные выборы, <1.0 — более строгие.
|
||||||
|
top_k (int, optional): top-k sampling; при сэмплировании выбираются только top_k наиболее вероятных токенов.
|
||||||
|
top_p (float, optional): nucleus (top-p) sampling; выбираются токены с накопленной вероятностью ≤ top_p.
|
||||||
|
use_cache (bool, по умолчанию True): Использовать ускорение через KV attention cache для autoregressive режима.
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
torch.Tensor: Последовательность индексов токенов shape [batch_size, seq_len + max_new_tokens].
|
||||||
|
|
||||||
|
Исключения:
|
||||||
|
ValueError: Если x длиннее max_seq_len модели.
|
||||||
|
ValueError: Если temperature ≤ 0.
|
||||||
|
ValueError: Если одновременно заданы top_k и top_p.
|
||||||
|
ValueError: Если top_k ≤ 0.
|
||||||
|
ValueError: Если top_p не в диапазоне (0, 1].
|
||||||
|
|
||||||
|
Примеры:
|
||||||
|
>>> # Жадная генерация
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=False)
|
||||||
|
>>> # Сэмплирование с температурой
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=0.8)
|
||||||
|
>>> # Top-k sampling
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_k=50)
|
||||||
|
>>> # Top-p (nucleus) sampling
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_p=0.92)
|
||||||
|
>>> # Температура + top-k
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=1.0, top_k=100)
|
||||||
|
|
||||||
|
Примечания:
|
||||||
|
- Одновременно использовать top_k и top_p нельзя.
|
||||||
|
- Параметры temperature, top_k, top_p работают только при do_sample=True.
|
||||||
|
- Для полного воспроизведения результата зафиксируйте seed через torch.manual_seed.
|
||||||
|
- Метод всегда возвращает только индексы токенов; для получения логитов используйте forward.
|
||||||
|
|
||||||
|
Ссылки:
|
||||||
|
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus/top-p sampling): https://arxiv.org/abs/1904.09751
|
||||||
|
- Mistral: https://arxiv.org/abs/2310.06825
|
||||||
|
"""
|
||||||
|
cache = None
|
||||||
|
|
||||||
|
for _ in range(max_new_tokens):
|
||||||
|
if use_cache and cache is not None:
|
||||||
|
# Используем кэш - передаем только последний токен
|
||||||
|
x_input = x[:, -1:] # [batch_size, 1]
|
||||||
|
else:
|
||||||
|
# Первая итерация или кэш отключен - передаем всю последовательность
|
||||||
|
x_input = x
|
||||||
|
|
||||||
|
# Прямой проход с кэшем
|
||||||
|
logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)
|
||||||
|
|
||||||
|
# Обновляем кэш для следующей итерации
|
||||||
|
if use_cache:
|
||||||
|
cache = new_cache
|
||||||
|
|
||||||
|
last_logits = logits[:, -1, :] # [batch_size, vocab_size]
|
||||||
|
|
||||||
|
# Масштабируем логиты температурой
|
||||||
|
if temperature > 0:
|
||||||
|
logits_scaled = last_logits / temperature
|
||||||
|
else:
|
||||||
|
logits_scaled = last_logits
|
||||||
|
|
||||||
|
if do_sample == True and top_k != None:
|
||||||
|
_, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)
|
||||||
|
|
||||||
|
# # Заменим все НЕ top-k логиты на -inf
|
||||||
|
masked_logits = logits_scaled.clone()
|
||||||
|
vocab_size = logits_scaled.size(-1)
|
||||||
|
|
||||||
|
# создаём маску: 1, если токен НЕ в topk_indices
|
||||||
|
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||||
|
mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы
|
||||||
|
masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf')
|
||||||
|
|
||||||
|
logits_scaled = masked_logits
|
||||||
|
|
||||||
|
if do_sample == True and top_p != None:
|
||||||
|
# 1. Применим softmax, чтобы получить вероятности:
|
||||||
|
probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]
|
||||||
|
# 2. Отсортируем токены по убыванию вероятностей:
|
||||||
|
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
|
||||||
|
# 3. Посчитаем кумулятивную сумму вероятностей:
|
||||||
|
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
|
||||||
|
# 4. Определим маску: оставить токены, пока сумма < top_p
|
||||||
|
sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size]
|
||||||
|
# Гарантируем, что хотя бы первый токен останется
|
||||||
|
sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1
|
||||||
|
# 5. Преобразуем маску обратно в оригинальный порядок:
|
||||||
|
# Создаём полную маску из 0
|
||||||
|
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||||
|
# Устанавливаем 1 в местах нужных токенов
|
||||||
|
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
||||||
|
# 6. Зануляем логиты токенов вне топ-p:
|
||||||
|
logits_scaled[~mask] = float('-inf')
|
||||||
|
|
||||||
|
# 4. Применяем Softmax
|
||||||
|
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
|
||||||
|
|
||||||
|
|
||||||
|
if do_sample == True:
|
||||||
|
# 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial
|
||||||
|
next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]
|
||||||
|
else:
|
||||||
|
# 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью
|
||||||
|
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
|
||||||
|
|
||||||
|
# 6. Добавляем его к последовательности
|
||||||
|
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
|
||||||
|
return x
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_seq_len(self) -> int:
|
||||||
|
return self._max_seq_len
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
60
llm/tests/core/test_geglu.py
Normal file
60
llm/tests/core/test_geglu.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.core.geglu import GeGLU
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def geglu():
|
||||||
|
return GeGLU(emb_size=16, dropout=0.1)
|
||||||
|
|
||||||
|
def test_forward_shape(geglu):
|
||||||
|
x = torch.randn(2, 5, 16)
|
||||||
|
y = geglu(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
|
|
||||||
|
def test_forward_no_batch(geglu):
|
||||||
|
x = torch.randn(1, 16)
|
||||||
|
y = geglu(x.unsqueeze(0))
|
||||||
|
assert y.shape == (1, 1, 16)
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="float16 not supported without parameter casting")
|
||||||
|
def test_forward_dtype_fp16():
|
||||||
|
geglu = GeGLU(emb_size=8, dropout=0.0)
|
||||||
|
x = torch.randn(2, 4, 8).half()
|
||||||
|
y = geglu(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
|
assert y.dtype == torch.float16
|
||||||
|
|
||||||
|
def test_forward_no_dropout():
|
||||||
|
geglu = GeGLU(emb_size=4, dropout=0.0)
|
||||||
|
x = torch.randn(3, 2, 4)
|
||||||
|
y = geglu(x)
|
||||||
|
assert not torch.isnan(y).any()
|
||||||
|
assert not torch.isinf(y).any()
|
||||||
|
|
||||||
|
def test_gradient_flow(geglu):
|
||||||
|
x = torch.randn(3, 8, 16, requires_grad=True)
|
||||||
|
y = geglu(x)
|
||||||
|
y.sum().backward()
|
||||||
|
assert x.grad is not None
|
||||||
|
assert x.grad.shape == x.shape
|
||||||
|
|
||||||
|
def test_forward_repeatability():
|
||||||
|
torch.manual_seed(42)
|
||||||
|
geglu = GeGLU(emb_size=8, dropout=0.0)
|
||||||
|
x = torch.randn(3, 2, 8)
|
||||||
|
y1 = geglu(x)
|
||||||
|
torch.manual_seed(42)
|
||||||
|
geglu2 = GeGLU(emb_size=8, dropout=0.0)
|
||||||
|
x2 = torch.randn(3, 2, 8)
|
||||||
|
y2 = geglu2(x2)
|
||||||
|
assert torch.allclose(y1, y2, atol=1e-5)
|
||||||
|
|
||||||
|
def test_edge_small_large():
|
||||||
|
geglu = GeGLU(emb_size=2, dropout=0.0)
|
||||||
|
x = torch.randn(2, 2, 2)
|
||||||
|
y = geglu(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
|
geglu = GeGLU(emb_size=256, dropout=0.0)
|
||||||
|
x = torch.randn(1, 1, 256)
|
||||||
|
y = geglu(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
67
llm/tests/core/test_gemma_decoder.py
Normal file
67
llm/tests/core/test_gemma_decoder.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.core.gemma_decoder import GemmaDecoder
|
||||||
|
from llm.core.rope import RoPE
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def gemma_decoder():
|
||||||
|
rope = RoPE(head_size=4, max_seq_len=32)
|
||||||
|
return GemmaDecoder(
|
||||||
|
num_q_heads=4,
|
||||||
|
emb_size=16,
|
||||||
|
head_size=4,
|
||||||
|
max_seq_len=32,
|
||||||
|
rope=rope,
|
||||||
|
dropout=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_forward_shape(gemma_decoder):
|
||||||
|
x = torch.randn(2, 12, 16)
|
||||||
|
out, cache = gemma_decoder(x)
|
||||||
|
assert out.shape == (2, 12, 16)
|
||||||
|
assert isinstance(cache, tuple) or cache is None
|
||||||
|
|
||||||
|
def test_forward_masked(gemma_decoder):
|
||||||
|
x = torch.randn(1, 8, 16)
|
||||||
|
mask = torch.ones(1, 8, 8, dtype=torch.bool)
|
||||||
|
out, _ = gemma_decoder(x, mask=mask)
|
||||||
|
assert out.shape == x.shape
|
||||||
|
|
||||||
|
def test_forward_with_cache_flag(gemma_decoder):
|
||||||
|
x = torch.randn(2, 7, 16)
|
||||||
|
out, cache = gemma_decoder(x, use_cache=True, cache=None)
|
||||||
|
assert out.shape == (2, 7, 16)
|
||||||
|
|
||||||
|
def test_forward_wrong_seq_len_raises(gemma_decoder):
|
||||||
|
x = torch.randn(1, 100, 16)
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
gemma_decoder(x)
|
||||||
|
|
||||||
|
def test_gradient_flow(gemma_decoder):
|
||||||
|
x = torch.randn(3, 9, 16, requires_grad=True)
|
||||||
|
y, _ = gemma_decoder(x)
|
||||||
|
y.sum().backward()
|
||||||
|
assert x.grad is not None
|
||||||
|
assert x.grad.shape == x.shape
|
||||||
|
|
||||||
|
def test_various_shapes(gemma_decoder):
|
||||||
|
for b, s in [(1, 1), (2, 5), (2, 32)]:
|
||||||
|
x = torch.randn(b, s, 16)
|
||||||
|
y, _ = gemma_decoder(x)
|
||||||
|
assert y.shape == (b, s, 16)
|
||||||
|
|
||||||
|
def test_forward_repeatability():
|
||||||
|
torch.manual_seed(42)
|
||||||
|
rope = RoPE(head_size=4, max_seq_len=32)
|
||||||
|
decoder = GemmaDecoder(
|
||||||
|
num_q_heads=4, emb_size=16, head_size=4, max_seq_len=32, rope=rope, dropout=0.0,
|
||||||
|
)
|
||||||
|
x = torch.randn(2, 8, 16)
|
||||||
|
y1, _ = decoder(x)
|
||||||
|
torch.manual_seed(42)
|
||||||
|
decoder2 = GemmaDecoder(
|
||||||
|
num_q_heads=4, emb_size=16, head_size=4, max_seq_len=32, rope=rope, dropout=0.0,
|
||||||
|
)
|
||||||
|
x2 = torch.randn(2, 8, 16)
|
||||||
|
y2, _ = decoder2(x2)
|
||||||
|
assert torch.allclose(y1, y2, atol=1e-5)
|
||||||
80
llm/tests/core/test_mixtral_decoder.py
Normal file
80
llm/tests/core/test_mixtral_decoder.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.core.mixtral_decoder import MixtralDecoder
|
||||||
|
from llm.core.rope import RoPE
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def basic_decoder():
|
||||||
|
emb_size = 16
|
||||||
|
num_q_heads = 4
|
||||||
|
num_kv_heads = 2
|
||||||
|
head_size = 4
|
||||||
|
max_seq_len = 32
|
||||||
|
num_experts = 4
|
||||||
|
top_k_experts = 2
|
||||||
|
window_size = 8
|
||||||
|
rope = RoPE(head_size=head_size, max_seq_len=max_seq_len)
|
||||||
|
return MixtralDecoder(
|
||||||
|
num_q_heads=num_q_heads,
|
||||||
|
num_kv_heads=num_kv_heads,
|
||||||
|
emb_size=emb_size,
|
||||||
|
head_size=head_size,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
num_experts=num_experts,
|
||||||
|
top_k_experts=top_k_experts,
|
||||||
|
window_size=window_size,
|
||||||
|
rope=rope,
|
||||||
|
dropout=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_forward_shape(basic_decoder):
|
||||||
|
x = torch.randn(2, 10, 16)
|
||||||
|
out, cache = basic_decoder(x)
|
||||||
|
assert out.shape == (2, 10, 16)
|
||||||
|
assert cache is None or isinstance(cache, (tuple, list))
|
||||||
|
|
||||||
|
def test_forward_masked(basic_decoder):
|
||||||
|
x = torch.randn(3, 7, 16)
|
||||||
|
mask = torch.ones(3, 7, 7, dtype=torch.bool)
|
||||||
|
out, cache = basic_decoder(x, mask=mask)
|
||||||
|
assert out.shape == (3, 7, 16)
|
||||||
|
|
||||||
|
def test_forward_with_cache_flag(basic_decoder):
|
||||||
|
x = torch.randn(2, 8, 16)
|
||||||
|
out, cache = basic_decoder(x, use_cache=True, cache=None)
|
||||||
|
assert out.shape == (2, 8, 16)
|
||||||
|
assert isinstance(cache, (tuple, list)) or cache is None
|
||||||
|
|
||||||
|
def test_backprop_pass(basic_decoder):
|
||||||
|
x = torch.randn(2, 5, 16, requires_grad=True)
|
||||||
|
out, _ = basic_decoder(x)
|
||||||
|
y = out.sum()
|
||||||
|
y.backward()
|
||||||
|
assert x.grad is not None
|
||||||
|
assert x.grad.shape == x.shape
|
||||||
|
|
||||||
|
def test_seq_too_long_raises(basic_decoder):
|
||||||
|
x = torch.randn(1, 40, 16) # seq_len > max_seq_len
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
basic_decoder(x)
|
||||||
|
|
||||||
|
def test_different_config():
|
||||||
|
rope = RoPE(head_size=2, max_seq_len=12)
|
||||||
|
decoder = MixtralDecoder(
|
||||||
|
num_q_heads=2, num_kv_heads=2, emb_size=4, head_size=2,
|
||||||
|
max_seq_len=12, num_experts=2, top_k_experts=1, window_size=4, rope=rope, dropout=0.1
|
||||||
|
)
|
||||||
|
x = torch.randn(1, 8, 4)
|
||||||
|
out, cache = decoder(x)
|
||||||
|
assert out.shape == x.shape
|
||||||
|
|
||||||
|
def test_forward_no_dropout():
|
||||||
|
# Проверка на корректность shape при отсутствии Dropout
|
||||||
|
rope = RoPE(head_size=2, max_seq_len=12)
|
||||||
|
decoder = MixtralDecoder(
|
||||||
|
num_q_heads=2, num_kv_heads=1, emb_size=4, head_size=2,
|
||||||
|
max_seq_len=12, num_experts=2, top_k_experts=1, window_size=3, rope=rope, dropout=0.0
|
||||||
|
)
|
||||||
|
x = torch.randn(2, 3, 4)
|
||||||
|
out, cache = decoder(x)
|
||||||
|
assert out.shape == x.shape
|
||||||
61
llm/tests/core/test_moe.py
Normal file
61
llm/tests/core/test_moe.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.core.moe import MoE
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def moe():
|
||||||
|
# Базовая MoE для коротких тестов
|
||||||
|
return MoE(emb_size=16, num_experts=4, top_k_experts=2, dropout=0.0)
|
||||||
|
|
||||||
|
def test_forward_shape(moe):
|
||||||
|
x = torch.randn(3, 5, 16) # [batch, seq, emb]
|
||||||
|
y = moe(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
|
|
||||||
|
def test_forward_grad(moe):
|
||||||
|
x = torch.randn(2, 4, 16, requires_grad=True)
|
||||||
|
y = moe(x)
|
||||||
|
(y.sum()).backward()
|
||||||
|
assert x.grad is not None
|
||||||
|
assert x.grad.shape == x.shape
|
||||||
|
|
||||||
|
def test_top_k_larger_than_experts():
|
||||||
|
# top_k_experts > num_experts должно падать
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
MoE(emb_size=8, num_experts=2, top_k_experts=4)
|
||||||
|
|
||||||
|
def test_single_expert_no_error():
|
||||||
|
# один эксперт, один топ-к — модель всё ещё валидна
|
||||||
|
moe = MoE(emb_size=8, num_experts=1, top_k_experts=1)
|
||||||
|
x = torch.randn(2, 2, 8)
|
||||||
|
y = moe(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
|
|
||||||
|
def test_forward_trivial_weights():
|
||||||
|
"""Проверяет, что при одинаковых весах роутера MoE возвращает усреднённое по экспертам."""
|
||||||
|
class DummyMoE(MoE):
|
||||||
|
def forward(self, x):
|
||||||
|
# Роутер отдаёт всегда единичные логиты = softmax -> uniform
|
||||||
|
self._router = torch.nn.Linear(x.size(-1), self._num_experts, bias=False)
|
||||||
|
torch.nn.init.constant_(self._router.weight, 0.0)
|
||||||
|
return super().forward(x)
|
||||||
|
moe = DummyMoE(emb_size=4, num_experts=2, top_k_experts=2)
|
||||||
|
x = torch.zeros(1, 2, 4)
|
||||||
|
y = moe(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
|
|
||||||
|
def test_forward_deterministic_seed(moe):
|
||||||
|
torch.manual_seed(42)
|
||||||
|
x = torch.randn(2, 3, 16)
|
||||||
|
y1 = moe(x)
|
||||||
|
torch.manual_seed(42)
|
||||||
|
y2 = moe(x)
|
||||||
|
assert torch.allclose(y1, y2, atol=1e-5)
|
||||||
|
|
||||||
|
def test_forward_no_dropout():
|
||||||
|
"""Без dropout MoE не меняет shape и не даёт NaN."""
|
||||||
|
moe = MoE(emb_size=5, num_experts=3, top_k_experts=2, dropout=0.0)
|
||||||
|
x = torch.randn(2, 7, 5)
|
||||||
|
y = moe(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
|
assert not torch.isnan(y).any()
|
||||||
71
llm/tests/core/test_multi_query_attention.py
Normal file
71
llm/tests/core/test_multi_query_attention.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.core.multi_query_attention import MultiQueryAttention
|
||||||
|
from llm.core.rope import RoPE
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mqa_rope():
|
||||||
|
return MultiQueryAttention(
|
||||||
|
num_q_heads=4, emb_size=16, head_size=4, max_seq_len=32, rope=RoPE(head_size=4, max_seq_len=32), dropout=0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mqa_no_rope():
|
||||||
|
return MultiQueryAttention(
|
||||||
|
num_q_heads=2, emb_size=8, head_size=4, max_seq_len=16, rope=None, dropout=0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_forward_shape(mqa_rope):
|
||||||
|
x = torch.randn(2, 10, 16)
|
||||||
|
out, cache = mqa_rope(x)
|
||||||
|
assert out.shape == (2, 10, 16)
|
||||||
|
assert isinstance(cache, tuple) and len(cache) == 2
|
||||||
|
|
||||||
|
def test_forward_masked(mqa_rope):
|
||||||
|
x = torch.randn(2, 8, 16)
|
||||||
|
mask = torch.ones(2, 8, 8, dtype=torch.bool)
|
||||||
|
out, cache = mqa_rope(x, mask=mask)
|
||||||
|
assert out.shape == (2, 8, 16)
|
||||||
|
|
||||||
|
def test_forward_cache(mqa_rope):
|
||||||
|
x = torch.randn(1, 4, 16)
|
||||||
|
# Первый вызов — кэша нет
|
||||||
|
out1, cache1 = mqa_rope(x)
|
||||||
|
# Повторяем: подаем x второй раз — теперь добавим cache
|
||||||
|
out2, cache2 = mqa_rope(x, use_cache=True, cache=cache1)
|
||||||
|
assert out2.shape == (1, 4, 16)
|
||||||
|
assert isinstance(cache2, tuple) and len(cache2) == 2
|
||||||
|
# Проверка, что длина k_cache увеличилась
|
||||||
|
assert cache2[0].shape[2] == cache1[0].shape[2] + x.shape[1] # по длине seq
|
||||||
|
|
||||||
|
def test_forward_no_rope(mqa_no_rope):
|
||||||
|
x = torch.randn(3, 6, 8)
|
||||||
|
out, _ = mqa_no_rope(x)
|
||||||
|
assert out.shape == (3, 6, 8)
|
||||||
|
|
||||||
|
def test_forward_different_batch_seq(mqa_rope):
|
||||||
|
for batch, seq in [(1, 1), (2, 5), (3, 32)]:
|
||||||
|
x = torch.randn(batch, seq, 16)
|
||||||
|
out, _ = mqa_rope(x)
|
||||||
|
assert out.shape == (batch, seq, 16)
|
||||||
|
|
||||||
|
def test_forward_raise_on_long_seq(mqa_rope):
|
||||||
|
x = torch.randn(2, 40, 16) # seq_len > max_seq_len
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
mqa_rope(x)
|
||||||
|
|
||||||
|
def test_forward_grad(mqa_rope):
|
||||||
|
x = torch.randn(2, 7, 16, requires_grad=True)
|
||||||
|
out, _ = mqa_rope(x)
|
||||||
|
y = out.sum()
|
||||||
|
y.backward()
|
||||||
|
assert x.grad is not None
|
||||||
|
assert x.grad.shape == x.shape
|
||||||
|
|
||||||
|
def test_dropout_applied():
|
||||||
|
mqa = MultiQueryAttention(num_q_heads=2, emb_size=8, head_size=4, max_seq_len=12, rope=None, dropout=0.99)
|
||||||
|
x = torch.ones(1, 3, 8)
|
||||||
|
mqa.train()
|
||||||
|
y, _ = mqa(x)
|
||||||
|
# При очень большом dropout почти всё обнуляется
|
||||||
|
assert (torch.abs(y) < 1e-5).float().mean() > 0.6 or y.sum() < 1e-2
|
||||||
56
llm/tests/models/test_gemma.py
Normal file
56
llm/tests/models/test_gemma.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
# llm/tests/models/test_gemma.py
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.models.gemma.gemma import Gemma
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def config():
|
||||||
|
return {
|
||||||
|
"vocab_size": 100,
|
||||||
|
"embed_dim": 32,
|
||||||
|
"num_q_heads": 4,
|
||||||
|
"num_layers": 2,
|
||||||
|
"max_position_embeddings": 16,
|
||||||
|
"dropout": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def model(config):
|
||||||
|
return Gemma(config)
|
||||||
|
|
||||||
|
def test_forward_basic(model):
|
||||||
|
x = torch.randint(0, 100, (2, 8))
|
||||||
|
logits, cache = model(x)
|
||||||
|
assert logits.shape == (2, 8, 100)
|
||||||
|
assert isinstance(cache, list)
|
||||||
|
assert len(cache) == model._decoders.__len__()
|
||||||
|
|
||||||
|
def test_forward_with_cache(model):
|
||||||
|
x = torch.randint(0, 100, (2, 4))
|
||||||
|
logits, cache = model(x, use_cache=True)
|
||||||
|
# Второй проход с cache и одним новым токеном
|
||||||
|
x2 = torch.randint(0, 100, (2, 1))
|
||||||
|
logits2, cache2 = model(x2, use_cache=True, cache=cache)
|
||||||
|
assert logits2.shape == (2, 1, 100)
|
||||||
|
assert isinstance(cache2, list)
|
||||||
|
|
||||||
|
def test_generate_and_shape(model):
|
||||||
|
x = torch.randint(0, 100, (1, 5))
|
||||||
|
result = model.generate(x, max_new_tokens=3, do_sample=False)
|
||||||
|
assert result.shape == (1, 8)
|
||||||
|
|
||||||
|
def test_forward_sequence_too_long(model, config):
|
||||||
|
x = torch.randint(0, 100, (1, config["max_position_embeddings"] + 1))
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
model(x)
|
||||||
|
|
||||||
|
def test_generate_with_sampling_topk(model):
|
||||||
|
x = torch.randint(0, 100, (1, 3))
|
||||||
|
out = model.generate(x, max_new_tokens=2, do_sample=True, top_k=5)
|
||||||
|
assert out.shape == (1, 5)
|
||||||
|
|
||||||
|
def test_generate_with_sampling_topp(model):
|
||||||
|
x = torch.randint(0, 100, (1, 3))
|
||||||
|
out = model.generate(x, max_new_tokens=2, do_sample=True, top_p=0.8)
|
||||||
|
assert out.shape == (1, 5)
|
||||||
57
llm/tests/models/test_mixtral.py
Normal file
57
llm/tests/models/test_mixtral.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.models.mixtral.mixtral import Mixtral
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def config():
|
||||||
|
return {
|
||||||
|
"vocab_size": 100,
|
||||||
|
"embed_dim": 32,
|
||||||
|
"num_q_heads": 4,
|
||||||
|
"num_kv_heads": 2,
|
||||||
|
"num_layers": 2,
|
||||||
|
"max_position_embeddings": 16,
|
||||||
|
"window_size": 8,
|
||||||
|
"dropout": 0.0,
|
||||||
|
"num_experts": 4,
|
||||||
|
"top_k_experts": 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def model(config):
|
||||||
|
return Mixtral(config)
|
||||||
|
|
||||||
|
def test_forward_basic(model):
|
||||||
|
x = torch.randint(0, 100, (2, 8))
|
||||||
|
logits, cache = model(x)
|
||||||
|
assert logits.shape == (2, 8, 100)
|
||||||
|
assert isinstance(cache, list)
|
||||||
|
assert len(cache) == model._decoders.__len__()
|
||||||
|
|
||||||
|
def test_forward_with_cache(model):
|
||||||
|
x = torch.randint(0, 100, (2, 4))
|
||||||
|
logits, cache = model(x, use_cache=True)
|
||||||
|
x2 = torch.randint(0, 100, (2, 1))
|
||||||
|
logits2, cache2 = model(x2, use_cache=True, cache=cache)
|
||||||
|
assert logits2.shape == (2, 1, 100)
|
||||||
|
assert isinstance(cache2, list)
|
||||||
|
|
||||||
|
def test_generate_and_shape(model):
|
||||||
|
x = torch.randint(0, 100, (1, 5))
|
||||||
|
result = model.generate(x, max_new_tokens=3, do_sample=False)
|
||||||
|
assert result.shape == (1, 8)
|
||||||
|
|
||||||
|
def test_forward_sequence_too_long(model, config):
|
||||||
|
x = torch.randint(0, 100, (1, config["max_position_embeddings"] + 1))
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
model(x)
|
||||||
|
|
||||||
|
def test_generate_with_sampling_topk(model):
|
||||||
|
x = torch.randint(0, 100, (1, 3))
|
||||||
|
out = model.generate(x, max_new_tokens=2, do_sample=True, top_k=5)
|
||||||
|
assert out.shape == (1, 5)
|
||||||
|
|
||||||
|
def test_generate_with_sampling_topp(model):
|
||||||
|
x = torch.randint(0, 100, (1, 3))
|
||||||
|
out = model.generate(x, max_new_tokens=2, do_sample=True, top_p=0.8)
|
||||||
|
assert out.shape == (1, 5)
|
||||||
1344
notebooks/gemma.ipynb
Normal file
1344
notebooks/gemma.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
1510
notebooks/mixstral.ipynb
Normal file
1510
notebooks/mixstral.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user