feat(mixtral): initial implementation of Mixtral MoE model, configs, and tests

- Add Mixtral architecture implementation with MoE support (llm/src/llm/models/mixtral/mixtral.py)
- Introduce generic Mixture-of-Experts (MoE) block (llm/src/llm/core/moe.py)
- Create dedicated configuration files for Mixtral training and generation experiments
- Register and test Mixtral support in experiment runner (run_llm_experiment.py)
- Add unit tests for Mixtral API including forward, caching, and generation modes
- Include Jupyter notebook mixstral.ipynb for architectural exploration and research
- Ensure correct handling of torch bool masks in sampling (top-k, top-p) during generation

BREAKING CHANGE: Adds new model code and test coverage, modifying experiment runner logic to register Mixtral.
This commit is contained in:
Sergey Penkovsky
2025-10-20 08:12:11 +03:00
parent 1aba02cab9
commit b1737bbce2
8 changed files with 2007 additions and 0 deletions

92
llm/src/llm/core/moe.py Normal file
View File

@@ -0,0 +1,92 @@
import torch
from torch import nn
import torch.nn.functional as F
from llm.core.swi_glu import SwiGLU
class MoE(nn.Module):
def __init__(
self,
emb_size: int,
num_experts: int,
top_k_experts: int,
dropout: float = 0.1,
):
super().__init__()
self._num_experts = num_experts
self._top_k_experts = top_k_experts
self._router = nn.Linear(emb_size, num_experts)
self._experts = nn.ModuleList([SwiGLU(
emb_size=emb_size,
dropout=dropout,
) for _ in range(num_experts)])
self._dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor):
batch_size, seq_len, emb_size = x.shape
# 1. Пропускаем через роутер
router_logits = self._router(x) # [batch_size, seq_len, num_experts]
# 2. Отбираем топ-k экспертов для каждого токена
topk_logits, topk_indices = torch.topk(
router_logits,
k=self._top_k_experts,
dim=-1
) # topk_logits: [batch_size, seq_len, top_k]
# topk_indices: [batch_size, seq_len, top_k]
# 3. Получаем веса через softmax и нормируем
topk_weights = F.softmax(topk_logits, dim=-1) # [batch_size, seq_len, top_k]
# 4. Создаём нулевой тензор для результата
output = torch.zeros_like(x) # [batch_size, seq_len, emb_size]
# 5. Проходим по всем экспертам
for expert_id in range(self._num_experts):
# Шаг 1: Создаём маску - где находится текущий эксперт в топ-k
expert_mask = (topk_indices == expert_id) # [batch_size, seq_len, top_k]
# Шаг 2: Проверяем, выбран ли эксперт хотя бы одним токеном
if not expert_mask.any():
continue # Эксперт никем не выбран, переходим к следующему
# Шаг 3: Находим токены, которые выбрали этого эксперта
# (хотя бы в одной из top_k позиций)
token_mask = expert_mask.any(dim=-1) # [batch_size, seq_len]
# Шаг 4: Отбираем токены из x
# Отбираем токены для этого эксперта
expert_input = x[token_mask]
# Пропускаем через эксперта
# Добавляем batch dimension для SwiGLU и затем убираем
expert_output = self._experts[expert_id](
expert_input.unsqueeze(0)
).squeeze(0)
# Получаем веса для этого эксперта
# Для каждого токена может быть несколько весов (если эксперт в топ-k несколько раз)
# Но на практике каждый эксперт появляется максимум 1 раз в топ-k
# Находим веса: где expert_mask == True, берём соответствующий вес
weights_for_expert = torch.zeros(
batch_size, seq_len, device=x.device
)
# Для каждой позиции в топ-k
for k in range(self._top_k_experts):
mask_k = topk_indices[:, :, k] == expert_id
weights_for_expert[mask_k] = topk_weights[:, :, k][mask_k]
# Отбираем только веса для выбранных токенов
selected_weights = weights_for_expert[token_mask] # [num_selected_tokens]
# Перемножьте выход эксперта на веса текущего эксперта.
weighted_output = selected_weights.unsqueeze(-1) * expert_output
# Помещаем результат на своё место в выходном тензоре
output[token_mask] += weighted_output
out = self._dropout(output)
return out

View File

@@ -0,0 +1,3 @@
from .mixtral import Mixtral
__all__ = ["Mixtral"]

View File

@@ -0,0 +1,295 @@
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
from math import sqrt
from llm.core.base_model import BaseModel
from llm.core.token_embeddings import TokenEmbeddings
from llm.core.rope import RoPE
from llm.core.rms_norm import RMSNorm
from llm.core.moe import MoE
from llm.core.group_query_attention import GroupedQueryAttention
class Decoder(nn.Module):
def __init__(self,
num_q_heads: int,
num_kv_heads: int,
emb_size: int,
head_size: int,
max_seq_len: int,
num_experts: int,
top_k_experts: int,
window_size: int,
rope: RoPE,
dropout: float = 0.1
):
super().__init__()
self._heads = GroupedQueryAttention(
num_q_heads=num_q_heads,
num_kv_heads=num_kv_heads,
emb_size=emb_size,
head_size=head_size,
max_seq_len=max_seq_len,
window_size=window_size,
rope=rope,
dropout=dropout
)
self._ff = MoE(
emb_size=emb_size,
num_experts=num_experts,
top_k_experts=top_k_experts,
dropout=dropout
)
self._norm1 = RMSNorm(emb_size)
self._norm2 = RMSNorm(emb_size)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
norm1_out = self._norm1(x)
attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
out = attention + x
norm2_out = self._norm2(out)
ffn_out = self._ff(norm2_out)
if use_cache is True:
return (ffn_out + out, kv_caches)
else:
return (ffn_out + out, None)
from torch import nn
import torch
import torch.nn.functional as F
class Mixtral(BaseModel):
def __init__(self, config):
super().__init__(config)
self._max_seq_len = config["max_position_embeddings"]
# Инициализация слоев
self._token_embeddings = TokenEmbeddings(
vocab_size=config["vocab_size"],
emb_size=config["embed_dim"]
)
self._position_embeddings = RoPE(
head_size=config["embed_dim"] // config["num_q_heads"],
max_seq_len=config["max_position_embeddings"]
)
#self._position_embeddings = PositionalEmbeddings(
# max_seq_len=max_seq_len,
# emb_size=emb_size
#)
self._dropout = nn.Dropout(config["dropout"])
self._decoders = nn.ModuleList([Decoder(
num_q_heads=config["num_q_heads"],
num_kv_heads=config["num_kv_heads"],
emb_size=config["embed_dim"],
head_size=config["embed_dim"] // config["num_q_heads"],
max_seq_len=config["max_position_embeddings"],
num_experts=config["num_experts"],
top_k_experts=config["top_k_experts"],
window_size=config["window_size"],
rope=self._position_embeddings,
dropout=config["dropout"]
) for _ in range(config["num_layers"])])
self._norm = RMSNorm(config["embed_dim"])
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
# Проверка длины последовательности (только при отсутствии кэша)
if cache is None and x.size(1) > self._max_seq_len:
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
# Эмбеддинги токенов и позиций
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
#pos_out = self._position_embeddings(x) # [batch, seq_len, emb_size]
# Комбинирование
out = self._dropout(tok_out) # [batch, seq_len, emb_size]
# Стек декодеров с передачей кэша
new_cache = []
for i, decoder in enumerate(self._decoders):
decoder_cache = cache[i] if cache is not None else None
decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)
# Извлекаем результат из кортежа
if use_cache:
out, decoder_new_cache = decoder_result
new_cache.append(decoder_new_cache)
else:
out = decoder_result[0]
out = self._norm(out)
logits = self._linear(out)
# Возвращаем результат с учетом use_cache
if use_cache:
return (logits, new_cache)
else:
return (logits, None)
def generate(self,
x: torch.Tensor,
max_new_tokens: int,
do_sample: bool,
temperature: float = 1.0,
top_k: int = None,
top_p: float = None,
use_cache: bool = True
) -> torch.Tensor:
"""
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k/top-p sampling
и ускорением через attention-кэш (KV-cache, важно для inference на длинных текстах).
Аргументы:
x (torch.Tensor): Входной тензор с токенами shape [batch_size, seq_len].
max_new_tokens (int): Максимальное количество новых токенов для генерации.
do_sample (bool): Если True — вероятность/случайность (random sampling); если False — жадная генерация (argmax).
temperature (float): Температура (>0, по умолчанию 1.0); >1.0 — более случайные выборы, <1.0 — более строгие.
top_k (int, optional): top-k sampling; при сэмплировании выбираются только top_k наиболее вероятных токенов.
top_p (float, optional): nucleus (top-p) sampling; выбираются токены с накопленной вероятностью ≤ top_p.
use_cache (bool, по умолчанию True): Использовать ускорение через KV attention cache для autoregressive режима.
Возвращает:
torch.Tensor: Последовательность индексов токенов shape [batch_size, seq_len + max_new_tokens].
Исключения:
ValueError: Если x длиннее max_seq_len модели.
ValueError: Если temperature ≤ 0.
ValueError: Если одновременно заданы top_k и top_p.
ValueError: Если top_k ≤ 0.
ValueError: Если top_p не в диапазоне (0, 1].
Примеры:
>>> # Жадная генерация
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=False)
>>> # Сэмплирование с температурой
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=0.8)
>>> # Top-k sampling
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_k=50)
>>> # Top-p (nucleus) sampling
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_p=0.92)
>>> # Температура + top-k
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=1.0, top_k=100)
Примечания:
- Одновременно использовать top_k и top_p нельзя.
- Параметры temperature, top_k, top_p работают только при do_sample=True.
- Для полного воспроизведения результата зафиксируйте seed через torch.manual_seed.
- Метод всегда возвращает только индексы токенов; для получения логитов используйте forward.
Ссылки:
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus/top-p sampling): https://arxiv.org/abs/1904.09751
- Mistral: https://arxiv.org/abs/2310.06825
"""
cache = None
for _ in range(max_new_tokens):
if use_cache and cache is not None:
# Используем кэш - передаем только последний токен
x_input = x[:, -1:] # [batch_size, 1]
else:
# Первая итерация или кэш отключен - передаем всю последовательность
x_input = x
# Прямой проход с кэшем
logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)
# Обновляем кэш для следующей итерации
if use_cache:
cache = new_cache
last_logits = logits[:, -1, :] # [batch_size, vocab_size]
# Масштабируем логиты температурой
if temperature > 0:
logits_scaled = last_logits / temperature
else:
logits_scaled = last_logits
if do_sample == True and top_k != None:
_, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)
# # Заменим все НЕ top-k логиты на -inf
masked_logits = logits_scaled.clone()
vocab_size = logits_scaled.size(-1)
# создаём маску: 1, если токен НЕ в topk_indices
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы
masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf')
logits_scaled = masked_logits
if do_sample == True and top_p != None:
# 1. Применим softmax, чтобы получить вероятности:
probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]
# 2. Отсортируем токены по убыванию вероятностей:
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
# 3. Посчитаем кумулятивную сумму вероятностей:
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
# 4. Определим маску: оставить токены, пока сумма < top_p
sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size]
# Гарантируем, что хотя бы первый токен останется
sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1
# 5. Преобразуем маску обратно в оригинальный порядок:
# Создаём полную маску из 0
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
# Устанавливаем 1 в местах нужных токенов
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
# 6. Зануляем логиты токенов вне топ-p:
logits_scaled[~mask] = float('-inf')
# 4. Применяем Softmax
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
if do_sample == True:
# 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial
next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]
else:
# 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
# 6. Добавляем его к последовательности
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
return x
def save(self, path):
torch.save({
'model_state_dict': self.state_dict(),
'vocab_size': self._vocab_size,
'max_seq_len': self._max_seq_len,
'emb_size': self._emb_size,
'num_heads': self._num_heads,
'head_size': self._head_size,
'num_layers': self._num_layers
}, path)
@classmethod
def load(cls, path, device):
checkpoint = torch.load(path, map_location=device)
model = cls(
vocab_size=checkpoint['vocab_size'],
max_seq_len=checkpoint['max_seq_len'],
emb_size=checkpoint['emb_size'],
num_heads=checkpoint['num_heads'],
head_size=checkpoint['head_size'],
num_layers=checkpoint['num_layers']
)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
return model
@property
def max_seq_len(self) -> int:
return self._max_seq_len

View File

@@ -0,0 +1,57 @@
import torch
import pytest
from llm.models.mixtral.mixtral import Mixtral
@pytest.fixture
def config():
return {
"vocab_size": 100,
"embed_dim": 32,
"num_q_heads": 4,
"num_kv_heads": 2,
"num_layers": 2,
"max_position_embeddings": 16,
"window_size": 8,
"dropout": 0.0,
"num_experts": 4,
"top_k_experts": 2,
}
@pytest.fixture
def model(config):
return Mixtral(config)
def test_forward_basic(model):
x = torch.randint(0, 100, (2, 8))
logits, cache = model(x)
assert logits.shape == (2, 8, 100)
assert isinstance(cache, list)
assert len(cache) == model._decoders.__len__()
def test_forward_with_cache(model):
x = torch.randint(0, 100, (2, 4))
logits, cache = model(x, use_cache=True)
x2 = torch.randint(0, 100, (2, 1))
logits2, cache2 = model(x2, use_cache=True, cache=cache)
assert logits2.shape == (2, 1, 100)
assert isinstance(cache2, list)
def test_generate_and_shape(model):
x = torch.randint(0, 100, (1, 5))
result = model.generate(x, max_new_tokens=3, do_sample=False)
assert result.shape == (1, 8)
def test_forward_sequence_too_long(model, config):
x = torch.randint(0, 100, (1, config["max_position_embeddings"] + 1))
with pytest.raises(ValueError):
model(x)
def test_generate_with_sampling_topk(model):
x = torch.randint(0, 100, (1, 3))
out = model.generate(x, max_new_tokens=2, do_sample=True, top_k=5)
assert out.shape == (1, 5)
def test_generate_with_sampling_topp(model):
x = torch.randint(0, 100, (1, 3))
out = model.generate(x, max_new_tokens=2, do_sample=True, top_p=0.8)
assert out.shape == (1, 5)