Merge pull request #4 from pese-git/feature/mixtral

Feature/mixtral
This commit is contained in:
Sergey Penkovsky
2025-10-20 16:36:39 +03:00
committed by GitHub
11 changed files with 2559 additions and 0 deletions

View File

@@ -0,0 +1,19 @@
{
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
"test_prompts": [
"Open weights",
"The Llama model is",
"Efficient transformers"
],
"model_config_path": "checkpoints/mixtral-bpe/config.json",
"model_weights": "checkpoints/mixtral-bpe/model.pt",
"generation": {
"max_new_tokens": 40,
"temperature": 0.8,
"do_sample": true,
"top_k": null,
"top_p": null
},
"log_path": "checkpoints/mixtral_only_generation_logs.json"
}

View File

@@ -0,0 +1,28 @@
{
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
"bpe_vocab_size": 1000,
"bpe_special_tokens": ["<pad>", "<unk>", "<bos>", "<eos>"],
"test_prompts": ["Open source AI", "What is Llama?"],
"model_config": {
"vocab_size": null,
"embed_dim": 256,
"num_q_heads": 4,
"num_kv_heads": 2,
"head_size": 64,
"num_layers": 4,
"max_position_embeddings": 512,
"num_experts": 8,
"top_k_experts": 2,
"window_size": 16,
"dropout": 0.1
},
"model_weights": "checkpoints/mixtral-bpe/model.pt",
"model_config_path": "checkpoints/mixtral-bpe/config.json",
"training": {
"learning_rate": 0.0003,
"batch_size": 2,
"num_epochs": 3,
"warmup_steps": 50
},
"log_path": "checkpoints/mixtral_only_training_logs.json"
}

View File

@@ -45,6 +45,9 @@ def load_model_class(model_name):
elif model_name.lower() == 'mistral':
from llm.models.mistral import Mistral
return Mistral
elif model_name.lower() == 'mixtral':
from llm.models.mixtral import Mixtral
return Mixtral
else:
raise ValueError(f"Модель '{model_name}' не поддерживается.")

View File

@@ -0,0 +1,211 @@
from torch import nn
import torch
import torch.nn.functional as F
from llm.core.rope import RoPE
from llm.core.group_query_attention import GroupedQueryAttention
from llm.core.moe import MoE
from llm.core.rms_norm import RMSNorm
class MixtralDecoder(nn.Module):
"""
MixtralDecoder — декодерный блок для Mixtral/MoE-трансформеров (см. Mixtral 8x7B, Mistral v0.2 и др.).
Назначение:
-----------
MixtralDecoder реализует один модульный слой глубокой трансформерной архитектуры с Mixture-of-Experts (MoE) Feed-Forward Network и Grouped Query Attention (GQA).
Поддерживает разреженную активацию и масштабируемое количество экспертов, оптимально для больших LLM.
Архитектура блока:
------------------
- RMSNorm -> Grouped Query Attention (GQA)
- skip-connection
- RMSNorm -> MoE (SwiGLU-эксперты)
- skip-connection
Для входа `x` проходит:
1. norm1_out = RMSNorm(x)
2. attention, kv_caches = GQA(norm1_out, ...)
3. out = attention + x # residual connection
4. norm2_out = RMSNorm(out)
5. ffn_out = MoE(norm2_out)
6. return (ffn_out + out, kv_caches)
Теоретическая мотивация:
------------------------
- Использование MoE (см. https://arxiv.org/abs/1701.06538) позволяет кратно увеличивать capacity без роста затрат на ff-часть.
- Grouped Query Attention эффективно масштабирует self-attention для больших моделей (см. Mistral, Llama 2/3).
- RMSNorm (Root Mean Square LayerNorm) стабилизирует градиенты и память.
- Является строительным блоком для стека декодеров в Mixtral-моделях (см. Mixtral, Mistral, LLaMA).
Аргументы конструктора:
----------------------
num_q_heads : int
Число query-голов в attention.
num_kv_heads : int
Число key-value голов (группировка ключей/values).
emb_size : int
Скрытый размер эмбеддинга.
head_size : int
Размерность одной головы (emb_size // num_q_heads).
max_seq_len : int
Максимальная поддерживаемая длина последовательности.
num_experts : int
Количество «экспертов» (MoE).
top_k_experts : int
Сколько одновременно экспертов активируется для одного токена.
window_size : int
Размер окна внимания (используется для efficient attention).
rope : RoPE
Реализация позиционного кодирования RoPE.
dropout : float
Вероятность Dropout для регуляризации.
Пример использования:
---------------------
>>> decoder = MixtralDecoder(... параметры ...)
>>> x = torch.randn(batch, seq, emb_size)
>>> out, cache = decoder(x, mask=None, use_cache=True)
>>> out.shape
Литература и ссылки:
--------------------
- Mixtral 8x7B: https://mistral.ai/news/mixtral-of-experts/
- Shazeer et al., “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer”, 2017. https://arxiv.org/abs/1701.06538
- Mistral paper: https://arxiv.org/abs/2310.06825
- GQA: https://arxiv.org/abs/2305.14236
- RMSNorm: https://arxiv.org/abs/1910.07467
"""
def __init__(self,
num_q_heads: int,
num_kv_heads: int,
emb_size: int,
head_size: int,
max_seq_len: int,
num_experts: int,
top_k_experts: int,
window_size: int,
rope: RoPE,
dropout: float = 0.1
):
"""
Конструктор декодерного блока MixtralDecoder.
Осуществляет инициализацию всех под-компонентов слоя: Attention (Grouped Query Attention), MoE (Mixture-of-Experts, SwiGLU)
и нормализации (RMSNorm). Позволяет гибко настраивать архитектуру под специфику задач и размеры LLM.
Аргументы:
----------
num_q_heads : int
Количество голов внимания (queries) в механизме GroupedQueryAttention.
Чем больше — тем тоньше дискретизация внимания по подпространствам признаков.
num_kv_heads : int
Количество групп ключей/значений (key-value heads) для GQA.
Позволяет балансировать производительность и память.
emb_size : int
Размерность эмбеддингового пространства внутри слоя (hidden).
head_size : int
Размерность одной attention-головы. Обычно emb_size // num_q_heads.
max_seq_len : int
Максимально поддерживаемая длина токенизированной последовательности.
num_experts : int
Количество экспертов в слое MoE (размер пула SwiGLU-экспертов).
top_k_experts : int
Сколько экспертов по роутингу активируется на 1 токен (разреженность — эффективная экономия вычислений).
window_size : int
Размер окна для attention (может использоваться для ограничения receptive field, как в Mistral).
rope : RoPE
Объект позиционного кодирования RoPE (Rotary Positional Embedding), необходим для архитектуры внимания.
dropout : float, по умолчанию 0.1
Вероятность зануляции выходных значений для регуляризации и борьбы с переобучением.
Пример:
-------
>>> decoder = MixtralDecoder(
... num_q_heads=8,
... num_kv_heads=2,
... emb_size=256,
... head_size=32,
... max_seq_len=1024,
... num_experts=4,
... top_k_experts=2,
... window_size=128,
... rope=rope_module,
... dropout=0.05
... )
"""
super().__init__()
self._heads = GroupedQueryAttention(
num_q_heads=num_q_heads,
num_kv_heads=num_kv_heads,
emb_size=emb_size,
head_size=head_size,
max_seq_len=max_seq_len,
window_size=window_size,
rope=rope,
dropout=dropout
)
self._ff = MoE(
emb_size=emb_size,
num_experts=num_experts,
top_k_experts=top_k_experts,
dropout=dropout
)
self._norm1 = RMSNorm(emb_size)
self._norm2 = RMSNorm(emb_size)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
"""
Прямой проход (forward) через декодерный блок MixtralDecoder.
Данный метод реализует последовательную обработку входных скрытых состояний (x) через:
- нормализацию (RMSNorm),
- attention-модуль (Grouped Query Attention) с опциональным применением маски и кэша ключей/значений для ускорения инференса,
- остаточное сложение (residual connection),
- повторную нормализацию,
- feed-forward блок на основе Mixture-of-Experts (MoE),
- финальное остаточное сложение.
Аргументы:
----------
x : torch.Tensor
Входной скрытый тензор формы [batch_size, seq_len, emb_size] — результат эмбеддинга токенов либо предыдущего слоя.
mask : torch.Tensor, optional
(Необязательно) Маска внимания для ограничения области self-attention (например, для автоперемешивания или causal-LLM-моделей).
use_cache : bool, по умолчанию True
Если True — сохраняет кэш ключей/значений attention для ускорения авторегрессии (инференса).
cache : list, optional
(Необязательно) Предварительно вычисленный кеш attention (для ускорения генерации длинного текста).
Возвращает:
-----------
Tuple[torch.Tensor, Any]:
- Первый элемент: скрытый тензор выхода слоя с той же формой, что вход (последовательный residual из attention и MoE-блока).
- Второй элемент: обновлённый кэш attention (если use_cache=True), иначе None.
Пример:
-------
>>> out, cache = decoder(x, mask=att_mask, use_cache=True, cache=old_cache)
>>> out.shape # [batch_size, seq_len, emb_size]
Примечания:
-----------
- Для autoregressive-генерации (GPT-like режимов) следует передавать mask и использовать use_cache=True.
- Реализация поддерживает произвольные батчи и длины последовательностей, в пределах max_seq_len слоя.
- Модуль MixtralDecoder обычно используется в виде стека (несколько подряд) внутри крупной LLM.
"""
norm1_out = self._norm1(x)
attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
out = attention + x
norm2_out = self._norm2(out)
ffn_out = self._ff(norm2_out)
if use_cache is True:
return (ffn_out + out, kv_caches)
else:
return (ffn_out + out, None)

229
llm/src/llm/core/moe.py Normal file
View File

@@ -0,0 +1,229 @@
import torch
from torch import nn
import torch.nn.functional as F
from llm.core.swi_glu import SwiGLU
class MoE(nn.Module):
"""
MoE (Mixture of Experts) — слой «смеси экспертов» для современных трансформерных архитектур с разреженной активацией.
Назначение:
-----------
Класс реализует слой разреженного условного вычисления для увеличения capacity трансформеров без роста вычислительных затрат.
Для каждого токена из последовательности выбирается (с помощью роутера) наиболее подходящее подмножество экспертов (малых нейросетей).
Итоговый выход формируется как взвешенная сумма откликов экспертов, выбранных для данного токена.
Архитектурная схема:
---------------------
- Для каждого входного токена `x` роутер (обычно один Linear-слой) предсказывает skor, насколько каждый из `num_experts` релевантен.
- Для каждого токена выбираются top_k_experts с максимальными skor; только они обрабатывают этот токен.
- Каждый эксперт здесь представлен отдельным экземпляром блока `SwiGLU` (может быть любая небольшая feed-forward сеть).
- Выход каждого эксперта умножается на индивидуальный вес (softmax по skor) — агрегируется взвешенная сумма.
- Dropout применяется к итоговому выходу.
Математика (коротко):
---------------------
Пусть X ∈ R^{BxSxD} — вход,
E — число экспертов,
K — число активируемых экспертов на токен.
r(x) = softmax(W_r x) — роутинг-логиты, top-K берём индексы и веса.
Для каждого токена:
y_j = Expert_j(x)
y = sum_j(w_j * y_j), где j пробегает по выбранным экспертам
Output: Y ∈ R^{BxSxD}
Аргументы конструктора:
----------------------
emb_size : int
Размерность входных/выходных векторов (обычно совпадает с embedding модели).
num_experts : int
Общее число экспертов внутри слоя MoE.
top_k_experts : int
Сколько экспертов активировать и агрегировать на каждом токене (обычно 2-8).
dropout : float, по умолчанию 0.1
Dropout к выходу агрегатора.
Пример использования:
---------------------
>>> moe = MoE(emb_size=512, num_experts=8, top_k_experts=2, dropout=0.1)
>>> x = torch.randn(4, 16, 512)
>>> y = moe(x)
>>> y.shape # torch.Size([4, 16, 512])
Литература:
-----------
- Shazeer, N. et al. “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer”, 2017. https://arxiv.org/abs/1701.06538
- Fedus, W., Zoph, B., & Shazeer, N. “Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity”, 2021. https://arxiv.org/abs/2101.03961
- Mistral/Mixtral: https://mistral.ai/news/mixtral-of-experts/
"""
def __init__(
self,
emb_size: int,
num_experts: int,
top_k_experts: int,
dropout: float = 0.1,
):
"""
Конструктор слоя MoE (Mixture of Experts).
Позволяет создать слой, состоящий из набора экспертов (например, отдельных небольших feedforward-нейросетей) и роутера,
который будет для каждого токена определять наиболее релевантных экспертов.
Часть экспертов (top_k_experts) активируется для каждого токена, остальные — пропускаются.
Аргументы:
----------
emb_size : int
Размерность входных и выходных векторов (embedding size).
Определяет, над каким пространством признаков будет работать роутер и эксперты.
Например, если скрытый размер слоя трансформера 512, сюда нужно передать 512.
num_experts : int
Общее количество экспертов в слое MoE.
Чем больше экспертов — тем больше capacity у модели, но тем выше требования к RAM/VRAM при обучении.
Пример: 8, 16, 32, 64.
top_k_experts : int
Сколько экспертов одновременно будет обрабатывать каждый токен.
Обычно 28. Меньшее значение — выше разреженность, больше экономия вычислений.
dropout : float, по умолчанию 0.1
Вероятность зануления значений на выходе после агрегации откликов экспертов.
Используется для регуляризации (борьбы с переобучением).
Пример:
-------
>>> moe = MoE(emb_size=256, num_experts=8, top_k_experts=2, dropout=0.1)
>>> print(moe)
MoE( ... )
Теория:
-------
Слой строит:
- Линейный роутер (Linear(emb_size, num_experts)): выдает «важность» каждого эксперта для токена.
- Список из num_experts экспертов (в данной реализации — SwiGLU-блоки).
При каждом проходе для каждого токена выбираются top_k_experts наиболее релевантных экспертов,
их ответы агрегируются взвешенной суммой (softmax по роутерным логитам).
"""
super().__init__()
if top_k_experts > num_experts:
raise ValueError(f"top_k_experts ({top_k_experts}) должен быть меньше или равен num_experts ({num_experts})!")
self._num_experts = num_experts
self._top_k_experts = top_k_experts
self._router = nn.Linear(emb_size, num_experts)
self._experts = nn.ModuleList([SwiGLU(
emb_size=emb_size,
dropout=dropout,
) for _ in range(num_experts)])
self._dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor):
"""
Прямой проход (forward) через слой MoE.
Для входной последовательности скрытых состояний (обычно из предыдущего слоя трансформера)
данный метод динамически выбирает для каждого токена топ-k наиболее релевантных экспертов с помощью роутера,
пропускает соответствующие токены через выбранных экспертов и агрегирует их результаты.
Математически:
--------------
1. Для каждого токена вычисляются логиты маршрутизатора (роутера):
router_logits = Linear(x) ∈ ^{batch, seq, num_experts}
2. Выбираются top_k экспертов (topk_indices) и соответствующие им softmax-веса (topk_weights).
3. Каждый эксперт обрабатывает только свой поднабор токенов.
4. Результат агрегируется — отклик эксперта умножается на вес, ответы суммируются для каждого токена.
5. На результат применяется dropout для регуляризации.
Аргументы:
----------
x : torch.Tensor
Трёхмерный входной тензор формы [batch_size, seq_length, emb_size],
где batch_size — размер батча, seq_length — длина последовательности, emb_size — размерность эмбеддинга.
Возвращает:
-----------
torch.Tensor :
Тензор той же формы [batch_size, seq_length, emb_size] — результат комбинирования выходов выбранных экспертов
с учетом softmax-весов маршрутизатора и dropout'а.
Пример:
-------
>>> y = moe(x)
>>> print(y.shape)
torch.Size([batch_size, seq_length, emb_size])
Примечание:
-----------
- Каждый токен чаще всего активирует только подмножество экспертов.
- Остальные эксперты вычислительно “спят”, что позволяет строить очень большие (по параметрам) модели с малым ростом затрат.
- Работа с распределением топ-к экспертов и агрегирование с весами реализовано автоматически.
"""
batch_size, seq_len, emb_size = x.shape
# 1. Пропускаем через роутер
router_logits = self._router(x) # [batch_size, seq_len, num_experts]
# 2. Отбираем топ-k экспертов для каждого токена
topk_logits, topk_indices = torch.topk(
router_logits,
k=self._top_k_experts,
dim=-1
) # topk_logits: [batch_size, seq_len, top_k]
# topk_indices: [batch_size, seq_len, top_k]
# 3. Получаем веса через softmax и нормируем
topk_weights = F.softmax(topk_logits, dim=-1) # [batch_size, seq_len, top_k]
# 4. Создаём нулевой тензор для результата
output = torch.zeros_like(x) # [batch_size, seq_len, emb_size]
# 5. Проходим по всем экспертам
for expert_id in range(self._num_experts):
# Шаг 1: Создаём маску - где находится текущий эксперт в топ-k
expert_mask = (topk_indices == expert_id) # [batch_size, seq_len, top_k]
# Шаг 2: Проверяем, выбран ли эксперт хотя бы одним токеном
if not expert_mask.any():
continue # Эксперт никем не выбран, переходим к следующему
# Шаг 3: Находим токены, которые выбрали этого эксперта
# (хотя бы в одной из top_k позиций)
token_mask = expert_mask.any(dim=-1) # [batch_size, seq_len]
# Шаг 4: Отбираем токены из x
# Отбираем токены для этого эксперта
expert_input = x[token_mask]
# Пропускаем через эксперта
# Добавляем batch dimension для SwiGLU и затем убираем
expert_output = self._experts[expert_id](
expert_input.unsqueeze(0)
).squeeze(0)
# Получаем веса для этого эксперта
# Для каждого токена может быть несколько весов (если эксперт в топ-k несколько раз)
# Но на практике каждый эксперт появляется максимум 1 раз в топ-k
# Находим веса: где expert_mask == True, берём соответствующий вес
weights_for_expert = torch.zeros(
batch_size, seq_len, device=x.device
)
# Для каждой позиции в топ-k
for k in range(self._top_k_experts):
mask_k = topk_indices[:, :, k] == expert_id
weights_for_expert[mask_k] = topk_weights[:, :, k][mask_k]
# Отбираем только веса для выбранных токенов
selected_weights = weights_for_expert[token_mask] # [num_selected_tokens]
# Перемножьте выход эксперта на веса текущего эксперта.
weighted_output = selected_weights.unsqueeze(-1) * expert_output
# Помещаем результат на своё место в выходном тензоре
output[token_mask] += weighted_output
out = self._dropout(output)
return out

View File

@@ -0,0 +1,3 @@
from .mixtral import Mixtral
__all__ = ["Mixtral"]

View File

@@ -0,0 +1,358 @@
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
from math import sqrt
from llm.core.base_model import BaseModel
from llm.core.token_embeddings import TokenEmbeddings
from llm.core.rope import RoPE
from llm.core.rms_norm import RMSNorm
from llm.core.mixtral_decoder import MixtralDecoder
class Mixtral(BaseModel):
"""
Mixtral — языковая модель с архитектурой Mixture-of-Experts на основе современных трансформеров (см. Mixtral 8x7B).
Описание:
---------
Данный класс реализует полностью функциональную LLM с блоками MixtralDecoder, которые используют разреженные Feed-Forward сети MoE (Mixture-of-Experts)
и Grouped Query Attention (GQA). Позволяет масштабировать количество параметров без экспоненциального роста вычислительных затрат благодаря активации лишь части экспертов на каждый токен.
Mixtral поддерживает автотекстогенерацию с caching, position encoding через RoPE и всё необходимое для работы и тренировки современных LLM.
Архитектурные особенности:
--------------------------
- Stack из N слоёв MixtralDecoder (каждый — MoE-блок + attention + RMSNorm).
- Dropout для регуляризации на уровне эмбеддингов и слоёв.
- Позиционные эмбеддинги реализованы через RoPE (Rotary Positional Embeddings).
- Финальная RMSNorm плюс Linear-проекция к словарю токенов.
- Поддержка автогенерации с sampling (greedy, top-k, top-p), temperature и KV-cache.
Аргументы конструктора:
----------------------
config : dict
Словарь-конфиг с основными гиперпараметрами модели:
- vocab_size : int — размер словаря токенов
- embed_dim : int — размер скрытого пространства
- max_position_embeddings : int — макс. длина последовательности
- num_layers : int — количество декодерных блоков в стеке
- num_q_heads : int — число query-голов в attention
- num_kv_heads : int — число kv-голов в attention
- num_experts : int — число MoE-экспертов
- top_k_experts : int — сколько экспертов активировать на токен
- dropout : float — вероятность Dropout
- window_size : int — размер окна внимания
Основные методы:
----------------
- forward(x, use_cache=True, cache=None) — прямой проход, поддерживает batched вход, caching.
- generate(...) — авторегрессивная генерация с разными стратегиями sampling и ускорением через cache.
- save(path)/load(path, device) — сохранение и восстановление обученной модели.
Пример:
-------
>>> config = {...} # dict с параметрами
>>> model = Mixtral(config)
>>> x = torch.randint(0, config["vocab_size"], (2, 16))
>>> logits, cache = model(x, use_cache=True)
>>> print(logits.shape) # [2, 16, vocab_size]
>>> # Генерация
>>> out = model.generate(x, max_new_tokens=20, do_sample=True, top_k=10, temperature=0.9)
Литература:
-----------
- Mixtral 8x7B: https://mistral.ai/news/mixtral-of-experts/
- Switch Transformer: https://arxiv.org/abs/2101.03961
- GShard: https://arxiv.org/abs/2006.16668
- RoPE: https://arxiv.org/abs/2104.09864
- Grouped Query Attention: https://arxiv.org/abs/2305.14236
- RMSNorm: https://arxiv.org/abs/1910.07467
"""
def __init__(self, config):
"""
Конструктор класса Mixtral.
Осуществляет инициализацию всех модулей и внутренних параметров большой языковой модели с архитектурой Mixtral/MoE.
Использует параметры из конфиг-словаря `config` для гибкой настройки модели.
Аргументы:
----------
config : dict
Словарь с основными гиперпараметрами архитектуры. Должен содержать ключи:
vocab_size (int): Размер словаря токенов.
embed_dim (int): Размер скрытого пространства (эмбеддингов).
max_position_embeddings (int): Максимальная длина токенной последовательности.
num_layers (int): Количество декодерных блоков (слоёв) в модели.
num_q_heads (int): Число query-голов (attention heads).
num_kv_heads (int): Число key-value голов (attention heads).
num_experts (int): Количество экспертов в каждом MoE-блоке.
top_k_experts (int): Сколько экспертов активируется для одного токена.
dropout (float): Dropout для регуляризации.
window_size (int): Размер окна внимания (Attention Window).
Внутри:
-------
- Инициализируются эмбеддинги токенов, позиционные эмбеддинги RoPE, Dropout.
- Строится стек из num_layers модулей MixtralDecoder с заданным количеством attention heads и экспертов.
- Финальный слой нормализации и проекция к логитам словаря (linear layer).
Пример:
-------
>>> config = {
... "vocab_size": 32000,
... "embed_dim": 512,
... "max_position_embeddings": 2048,
... "num_layers": 24,
... "num_q_heads": 8,
... "num_kv_heads": 8,
... "num_experts": 8,
... "top_k_experts": 2,
... "dropout": 0.1,
... "window_size": 256,
... }
>>> model = Mixtral(config)
Примечания:
-----------
- Конфиг модели должен быть согласован: размеры должны делиться на число голов, число экспертов и top_k_experts корректно выбраны.
- Все параметры, необходимые для построения MixtralDecoder, attention и MoE, берутся из config.
"""
super().__init__(config)
self._max_seq_len = config["max_position_embeddings"]
# Инициализация слоев
self._token_embeddings = TokenEmbeddings(
vocab_size=config["vocab_size"],
emb_size=config["embed_dim"]
)
self._position_embeddings = RoPE(
head_size=config["embed_dim"] // config["num_q_heads"],
max_seq_len=config["max_position_embeddings"]
)
#self._position_embeddings = PositionalEmbeddings(
# max_seq_len=max_seq_len,
# emb_size=emb_size
#)
self._dropout = nn.Dropout(config["dropout"])
self._decoders = nn.ModuleList([MixtralDecoder(
num_q_heads=config["num_q_heads"],
num_kv_heads=config["num_kv_heads"],
emb_size=config["embed_dim"],
head_size=config["embed_dim"] // config["num_q_heads"],
max_seq_len=config["max_position_embeddings"],
num_experts=config["num_experts"],
top_k_experts=config["top_k_experts"],
window_size=config["window_size"],
rope=self._position_embeddings,
dropout=config["dropout"]
) for _ in range(config["num_layers"])])
self._norm = RMSNorm(config["embed_dim"])
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
"""
Прямой проход (forward) через всю модель Mixtral.
Данный метод реализует трансформацию входной последовательности токенов в логиты (предсказания вероятностей токенов словаря)
с поддержкой эффективного инференса с использованием cache (KV-кэш attention для автогенерации).
Аргументы:
----------
x : torch.Tensor
Двумерный входной тензор shape [batch_size, seq_len], где каждое значение — ID токена.
use_cache : bool, по умолчанию True
Если True — в режиме генерации модель возвращает обновлённый список кэшей attention для ускорения последовательного инференса.
Если False — attention cache не используется.
cache : list, optional
(Необязательно) Список (или None) с кэшем KV attention для каждого слоя. Используется для автогенерации текста.
Возвращает:
-----------
tuple:
- logits : torch.Tensor — выходной тензор shape [batch_size, seq_len, vocab_size] — массив логитов по токенам и словарю.
- new_cache : list или None — обновлённый cache, если используется.
Пример:
-------
>>> logits, new_cache = model(x, use_cache=True, cache=None)
>>> logits.shape # [batch_size, seq_len, vocab_size]
Примечания:
-----------
- Если используется cache — эффективно для авторегрессионной генерации (token-by-token), например, при диалогах или длинной генерации.
- Если входная последовательность длиннее max_seq_len — будет выброшено исключение.
- Если нужен только логит последнего токена — используйте slice: logits[:, -1, :]
"""
# Проверка длины последовательности (только при отсутствии кэша)
if cache is None and x.size(1) > self._max_seq_len:
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
# Эмбеддинги токенов и позиций
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
#pos_out = self._position_embeddings(x) # [batch, seq_len, emb_size]
# Комбинирование
out = self._dropout(tok_out) # [batch, seq_len, emb_size]
# Стек декодеров с передачей кэша
new_cache = []
for i, decoder in enumerate(self._decoders):
decoder_cache = cache[i] if cache is not None else None
decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)
# Извлекаем результат из кортежа
if use_cache:
out, decoder_new_cache = decoder_result
new_cache.append(decoder_new_cache)
else:
out = decoder_result[0]
out = self._norm(out)
logits = self._linear(out)
# Возвращаем результат с учетом use_cache
if use_cache:
return (logits, new_cache)
else:
return (logits, None)
def generate(self,
x: torch.Tensor,
max_new_tokens: int,
do_sample: bool,
temperature: float = 1.0,
top_k: int = None,
top_p: float = None,
use_cache: bool = True
) -> torch.Tensor:
"""
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k/top-p sampling
и ускорением через attention-кэш (KV-cache, важно для inference на длинных текстах).
Аргументы:
x (torch.Tensor): Входной тензор с токенами shape [batch_size, seq_len].
max_new_tokens (int): Максимальное количество новых токенов для генерации.
do_sample (bool): Если True — вероятность/случайность (random sampling); если False — жадная генерация (argmax).
temperature (float): Температура (>0, по умолчанию 1.0); >1.0 — более случайные выборы, <1.0 — более строгие.
top_k (int, optional): top-k sampling; при сэмплировании выбираются только top_k наиболее вероятных токенов.
top_p (float, optional): nucleus (top-p) sampling; выбираются токены с накопленной вероятностью ≤ top_p.
use_cache (bool, по умолчанию True): Использовать ускорение через KV attention cache для autoregressive режима.
Возвращает:
torch.Tensor: Последовательность индексов токенов shape [batch_size, seq_len + max_new_tokens].
Исключения:
ValueError: Если x длиннее max_seq_len модели.
ValueError: Если temperature ≤ 0.
ValueError: Если одновременно заданы top_k и top_p.
ValueError: Если top_k ≤ 0.
ValueError: Если top_p не в диапазоне (0, 1].
Примеры:
>>> # Жадная генерация
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=False)
>>> # Сэмплирование с температурой
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=0.8)
>>> # Top-k sampling
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_k=50)
>>> # Top-p (nucleus) sampling
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_p=0.92)
>>> # Температура + top-k
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=1.0, top_k=100)
Примечания:
- Одновременно использовать top_k и top_p нельзя.
- Параметры temperature, top_k, top_p работают только при do_sample=True.
- Для полного воспроизведения результата зафиксируйте seed через torch.manual_seed.
- Метод всегда возвращает только индексы токенов; для получения логитов используйте forward.
Ссылки:
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus/top-p sampling): https://arxiv.org/abs/1904.09751
- Mistral: https://arxiv.org/abs/2310.06825
"""
cache = None
for _ in range(max_new_tokens):
if use_cache and cache is not None:
# Используем кэш - передаем только последний токен
x_input = x[:, -1:] # [batch_size, 1]
else:
# Первая итерация или кэш отключен - передаем всю последовательность
x_input = x
# Прямой проход с кэшем
logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)
# Обновляем кэш для следующей итерации
if use_cache:
cache = new_cache
last_logits = logits[:, -1, :] # [batch_size, vocab_size]
# Масштабируем логиты температурой
if temperature > 0:
logits_scaled = last_logits / temperature
else:
logits_scaled = last_logits
if do_sample == True and top_k != None:
_, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)
# # Заменим все НЕ top-k логиты на -inf
masked_logits = logits_scaled.clone()
vocab_size = logits_scaled.size(-1)
# создаём маску: 1, если токен НЕ в topk_indices
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы
masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf')
logits_scaled = masked_logits
if do_sample == True and top_p != None:
# 1. Применим softmax, чтобы получить вероятности:
probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]
# 2. Отсортируем токены по убыванию вероятностей:
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
# 3. Посчитаем кумулятивную сумму вероятностей:
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
# 4. Определим маску: оставить токены, пока сумма < top_p
sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size]
# Гарантируем, что хотя бы первый токен останется
sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1
# 5. Преобразуем маску обратно в оригинальный порядок:
# Создаём полную маску из 0
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
# Устанавливаем 1 в местах нужных токенов
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
# 6. Зануляем логиты токенов вне топ-p:
logits_scaled[~mask] = float('-inf')
# 4. Применяем Softmax
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
if do_sample == True:
# 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial
next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]
else:
# 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
# 6. Добавляем его к последовательности
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
return x
@property
def max_seq_len(self) -> int:
return self._max_seq_len

View File

@@ -0,0 +1,80 @@
import torch
import pytest
from llm.core.mixtral_decoder import MixtralDecoder
from llm.core.rope import RoPE
@pytest.fixture
def basic_decoder():
emb_size = 16
num_q_heads = 4
num_kv_heads = 2
head_size = 4
max_seq_len = 32
num_experts = 4
top_k_experts = 2
window_size = 8
rope = RoPE(head_size=head_size, max_seq_len=max_seq_len)
return MixtralDecoder(
num_q_heads=num_q_heads,
num_kv_heads=num_kv_heads,
emb_size=emb_size,
head_size=head_size,
max_seq_len=max_seq_len,
num_experts=num_experts,
top_k_experts=top_k_experts,
window_size=window_size,
rope=rope,
dropout=0.0,
)
def test_forward_shape(basic_decoder):
x = torch.randn(2, 10, 16)
out, cache = basic_decoder(x)
assert out.shape == (2, 10, 16)
assert cache is None or isinstance(cache, (tuple, list))
def test_forward_masked(basic_decoder):
x = torch.randn(3, 7, 16)
mask = torch.ones(3, 7, 7, dtype=torch.bool)
out, cache = basic_decoder(x, mask=mask)
assert out.shape == (3, 7, 16)
def test_forward_with_cache_flag(basic_decoder):
x = torch.randn(2, 8, 16)
out, cache = basic_decoder(x, use_cache=True, cache=None)
assert out.shape == (2, 8, 16)
assert isinstance(cache, (tuple, list)) or cache is None
def test_backprop_pass(basic_decoder):
x = torch.randn(2, 5, 16, requires_grad=True)
out, _ = basic_decoder(x)
y = out.sum()
y.backward()
assert x.grad is not None
assert x.grad.shape == x.shape
def test_seq_too_long_raises(basic_decoder):
x = torch.randn(1, 40, 16) # seq_len > max_seq_len
with pytest.raises(Exception):
basic_decoder(x)
def test_different_config():
rope = RoPE(head_size=2, max_seq_len=12)
decoder = MixtralDecoder(
num_q_heads=2, num_kv_heads=2, emb_size=4, head_size=2,
max_seq_len=12, num_experts=2, top_k_experts=1, window_size=4, rope=rope, dropout=0.1
)
x = torch.randn(1, 8, 4)
out, cache = decoder(x)
assert out.shape == x.shape
def test_forward_no_dropout():
# Проверка на корректность shape при отсутствии Dropout
rope = RoPE(head_size=2, max_seq_len=12)
decoder = MixtralDecoder(
num_q_heads=2, num_kv_heads=1, emb_size=4, head_size=2,
max_seq_len=12, num_experts=2, top_k_experts=1, window_size=3, rope=rope, dropout=0.0
)
x = torch.randn(2, 3, 4)
out, cache = decoder(x)
assert out.shape == x.shape

View File

@@ -0,0 +1,61 @@
import torch
import pytest
from llm.core.moe import MoE
@pytest.fixture
def moe():
# Базовая MoE для коротких тестов
return MoE(emb_size=16, num_experts=4, top_k_experts=2, dropout=0.0)
def test_forward_shape(moe):
x = torch.randn(3, 5, 16) # [batch, seq, emb]
y = moe(x)
assert y.shape == x.shape
def test_forward_grad(moe):
x = torch.randn(2, 4, 16, requires_grad=True)
y = moe(x)
(y.sum()).backward()
assert x.grad is not None
assert x.grad.shape == x.shape
def test_top_k_larger_than_experts():
# top_k_experts > num_experts должно падать
with pytest.raises(ValueError):
MoE(emb_size=8, num_experts=2, top_k_experts=4)
def test_single_expert_no_error():
# один эксперт, один топ-к — модель всё ещё валидна
moe = MoE(emb_size=8, num_experts=1, top_k_experts=1)
x = torch.randn(2, 2, 8)
y = moe(x)
assert y.shape == x.shape
def test_forward_trivial_weights():
"""Проверяет, что при одинаковых весах роутера MoE возвращает усреднённое по экспертам."""
class DummyMoE(MoE):
def forward(self, x):
# Роутер отдаёт всегда единичные логиты = softmax -> uniform
self._router = torch.nn.Linear(x.size(-1), self._num_experts, bias=False)
torch.nn.init.constant_(self._router.weight, 0.0)
return super().forward(x)
moe = DummyMoE(emb_size=4, num_experts=2, top_k_experts=2)
x = torch.zeros(1, 2, 4)
y = moe(x)
assert y.shape == x.shape
def test_forward_deterministic_seed(moe):
torch.manual_seed(42)
x = torch.randn(2, 3, 16)
y1 = moe(x)
torch.manual_seed(42)
y2 = moe(x)
assert torch.allclose(y1, y2, atol=1e-5)
def test_forward_no_dropout():
"""Без dropout MoE не меняет shape и не даёт NaN."""
moe = MoE(emb_size=5, num_experts=3, top_k_experts=2, dropout=0.0)
x = torch.randn(2, 7, 5)
y = moe(x)
assert y.shape == x.shape
assert not torch.isnan(y).any()

View File

@@ -0,0 +1,57 @@
import torch
import pytest
from llm.models.mixtral.mixtral import Mixtral
@pytest.fixture
def config():
return {
"vocab_size": 100,
"embed_dim": 32,
"num_q_heads": 4,
"num_kv_heads": 2,
"num_layers": 2,
"max_position_embeddings": 16,
"window_size": 8,
"dropout": 0.0,
"num_experts": 4,
"top_k_experts": 2,
}
@pytest.fixture
def model(config):
return Mixtral(config)
def test_forward_basic(model):
x = torch.randint(0, 100, (2, 8))
logits, cache = model(x)
assert logits.shape == (2, 8, 100)
assert isinstance(cache, list)
assert len(cache) == model._decoders.__len__()
def test_forward_with_cache(model):
x = torch.randint(0, 100, (2, 4))
logits, cache = model(x, use_cache=True)
x2 = torch.randint(0, 100, (2, 1))
logits2, cache2 = model(x2, use_cache=True, cache=cache)
assert logits2.shape == (2, 1, 100)
assert isinstance(cache2, list)
def test_generate_and_shape(model):
x = torch.randint(0, 100, (1, 5))
result = model.generate(x, max_new_tokens=3, do_sample=False)
assert result.shape == (1, 8)
def test_forward_sequence_too_long(model, config):
x = torch.randint(0, 100, (1, config["max_position_embeddings"] + 1))
with pytest.raises(ValueError):
model(x)
def test_generate_with_sampling_topk(model):
x = torch.randint(0, 100, (1, 3))
out = model.generate(x, max_new_tokens=2, do_sample=True, top_k=5)
assert out.shape == (1, 5)
def test_generate_with_sampling_topp(model):
x = torch.randint(0, 100, (1, 3))
out = model.generate(x, max_new_tokens=2, do_sample=True, top_p=0.8)
assert out.shape == (1, 5)

1510
notebooks/mixstral.ipynb Normal file

File diff suppressed because it is too large Load Diff