mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
feat(mixtral): initial implementation of Mixtral MoE model, configs, and tests
- Add Mixtral architecture implementation with MoE support (llm/src/llm/models/mixtral/mixtral.py) - Introduce generic Mixture-of-Experts (MoE) block (llm/src/llm/core/moe.py) - Create dedicated configuration files for Mixtral training and generation experiments - Register and test Mixtral support in experiment runner (run_llm_experiment.py) - Add unit tests for Mixtral API including forward, caching, and generation modes - Include Jupyter notebook mixstral.ipynb for architectural exploration and research - Ensure correct handling of torch bool masks in sampling (top-k, top-p) during generation BREAKING CHANGE: Adds new model code and test coverage, modifying experiment runner logic to register Mixtral.
This commit is contained in:
92
llm/src/llm/core/moe.py
Normal file
92
llm/src/llm/core/moe.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from llm.core.swi_glu import SwiGLU
|
||||
|
||||
class MoE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
emb_size: int,
|
||||
num_experts: int,
|
||||
top_k_experts: int,
|
||||
dropout: float = 0.1,
|
||||
):
|
||||
super().__init__()
|
||||
self._num_experts = num_experts
|
||||
self._top_k_experts = top_k_experts
|
||||
|
||||
self._router = nn.Linear(emb_size, num_experts)
|
||||
self._experts = nn.ModuleList([SwiGLU(
|
||||
emb_size=emb_size,
|
||||
dropout=dropout,
|
||||
) for _ in range(num_experts)])
|
||||
self._dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
batch_size, seq_len, emb_size = x.shape
|
||||
|
||||
# 1. Пропускаем через роутер
|
||||
router_logits = self._router(x) # [batch_size, seq_len, num_experts]
|
||||
|
||||
# 2. Отбираем топ-k экспертов для каждого токена
|
||||
topk_logits, topk_indices = torch.topk(
|
||||
router_logits,
|
||||
k=self._top_k_experts,
|
||||
dim=-1
|
||||
) # topk_logits: [batch_size, seq_len, top_k]
|
||||
# topk_indices: [batch_size, seq_len, top_k]
|
||||
|
||||
# 3. Получаем веса через softmax и нормируем
|
||||
topk_weights = F.softmax(topk_logits, dim=-1) # [batch_size, seq_len, top_k]
|
||||
|
||||
# 4. Создаём нулевой тензор для результата
|
||||
output = torch.zeros_like(x) # [batch_size, seq_len, emb_size]
|
||||
|
||||
# 5. Проходим по всем экспертам
|
||||
for expert_id in range(self._num_experts):
|
||||
# Шаг 1: Создаём маску - где находится текущий эксперт в топ-k
|
||||
expert_mask = (topk_indices == expert_id) # [batch_size, seq_len, top_k]
|
||||
# Шаг 2: Проверяем, выбран ли эксперт хотя бы одним токеном
|
||||
if not expert_mask.any():
|
||||
continue # Эксперт никем не выбран, переходим к следующему
|
||||
|
||||
# Шаг 3: Находим токены, которые выбрали этого эксперта
|
||||
# (хотя бы в одной из top_k позиций)
|
||||
token_mask = expert_mask.any(dim=-1) # [batch_size, seq_len]
|
||||
|
||||
# Шаг 4: Отбираем токены из x
|
||||
# Отбираем токены для этого эксперта
|
||||
expert_input = x[token_mask]
|
||||
|
||||
# Пропускаем через эксперта
|
||||
# Добавляем batch dimension для SwiGLU и затем убираем
|
||||
expert_output = self._experts[expert_id](
|
||||
expert_input.unsqueeze(0)
|
||||
).squeeze(0)
|
||||
|
||||
# Получаем веса для этого эксперта
|
||||
# Для каждого токена может быть несколько весов (если эксперт в топ-k несколько раз)
|
||||
# Но на практике каждый эксперт появляется максимум 1 раз в топ-k
|
||||
# Находим веса: где expert_mask == True, берём соответствующий вес
|
||||
weights_for_expert = torch.zeros(
|
||||
batch_size, seq_len, device=x.device
|
||||
)
|
||||
|
||||
# Для каждой позиции в топ-k
|
||||
for k in range(self._top_k_experts):
|
||||
mask_k = topk_indices[:, :, k] == expert_id
|
||||
weights_for_expert[mask_k] = topk_weights[:, :, k][mask_k]
|
||||
|
||||
# Отбираем только веса для выбранных токенов
|
||||
selected_weights = weights_for_expert[token_mask] # [num_selected_tokens]
|
||||
|
||||
|
||||
# Перемножьте выход эксперта на веса текущего эксперта.
|
||||
weighted_output = selected_weights.unsqueeze(-1) * expert_output
|
||||
|
||||
# Помещаем результат на своё место в выходном тензоре
|
||||
output[token_mask] += weighted_output
|
||||
|
||||
out = self._dropout(output)
|
||||
|
||||
return out
|
||||
3
llm/src/llm/models/mixtral/__init__.py
Normal file
3
llm/src/llm/models/mixtral/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .mixtral import Mixtral
|
||||
|
||||
__all__ = ["Mixtral"]
|
||||
295
llm/src/llm/models/mixtral/mixtral.py
Normal file
295
llm/src/llm/models/mixtral/mixtral.py
Normal file
@@ -0,0 +1,295 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch import Tensor
|
||||
import torch.nn.functional as F
|
||||
from math import sqrt
|
||||
from llm.core.base_model import BaseModel
|
||||
from llm.core.token_embeddings import TokenEmbeddings
|
||||
from llm.core.rope import RoPE
|
||||
from llm.core.rms_norm import RMSNorm
|
||||
from llm.core.moe import MoE
|
||||
from llm.core.group_query_attention import GroupedQueryAttention
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self,
|
||||
num_q_heads: int,
|
||||
num_kv_heads: int,
|
||||
emb_size: int,
|
||||
head_size: int,
|
||||
max_seq_len: int,
|
||||
num_experts: int,
|
||||
top_k_experts: int,
|
||||
window_size: int,
|
||||
rope: RoPE,
|
||||
dropout: float = 0.1
|
||||
):
|
||||
super().__init__()
|
||||
self._heads = GroupedQueryAttention(
|
||||
num_q_heads=num_q_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
emb_size=emb_size,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len,
|
||||
window_size=window_size,
|
||||
rope=rope,
|
||||
dropout=dropout
|
||||
)
|
||||
self._ff = MoE(
|
||||
emb_size=emb_size,
|
||||
num_experts=num_experts,
|
||||
top_k_experts=top_k_experts,
|
||||
dropout=dropout
|
||||
)
|
||||
self._norm1 = RMSNorm(emb_size)
|
||||
self._norm2 = RMSNorm(emb_size)
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
|
||||
norm1_out = self._norm1(x)
|
||||
attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
|
||||
out = attention + x
|
||||
|
||||
norm2_out = self._norm2(out)
|
||||
ffn_out = self._ff(norm2_out)
|
||||
|
||||
if use_cache is True:
|
||||
return (ffn_out + out, kv_caches)
|
||||
else:
|
||||
return (ffn_out + out, None)
|
||||
|
||||
|
||||
|
||||
from torch import nn
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
class Mixtral(BaseModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self._max_seq_len = config["max_position_embeddings"]
|
||||
|
||||
# Инициализация слоев
|
||||
self._token_embeddings = TokenEmbeddings(
|
||||
vocab_size=config["vocab_size"],
|
||||
emb_size=config["embed_dim"]
|
||||
)
|
||||
self._position_embeddings = RoPE(
|
||||
head_size=config["embed_dim"] // config["num_q_heads"],
|
||||
max_seq_len=config["max_position_embeddings"]
|
||||
)
|
||||
#self._position_embeddings = PositionalEmbeddings(
|
||||
# max_seq_len=max_seq_len,
|
||||
# emb_size=emb_size
|
||||
#)
|
||||
self._dropout = nn.Dropout(config["dropout"])
|
||||
self._decoders = nn.ModuleList([Decoder(
|
||||
num_q_heads=config["num_q_heads"],
|
||||
num_kv_heads=config["num_kv_heads"],
|
||||
emb_size=config["embed_dim"],
|
||||
head_size=config["embed_dim"] // config["num_q_heads"],
|
||||
max_seq_len=config["max_position_embeddings"],
|
||||
num_experts=config["num_experts"],
|
||||
top_k_experts=config["top_k_experts"],
|
||||
window_size=config["window_size"],
|
||||
rope=self._position_embeddings,
|
||||
dropout=config["dropout"]
|
||||
) for _ in range(config["num_layers"])])
|
||||
self._norm = RMSNorm(config["embed_dim"])
|
||||
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
|
||||
|
||||
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
|
||||
# Проверка длины последовательности (только при отсутствии кэша)
|
||||
if cache is None and x.size(1) > self._max_seq_len:
|
||||
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
|
||||
|
||||
# Эмбеддинги токенов и позиций
|
||||
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
|
||||
#pos_out = self._position_embeddings(x) # [batch, seq_len, emb_size]
|
||||
|
||||
# Комбинирование
|
||||
out = self._dropout(tok_out) # [batch, seq_len, emb_size]
|
||||
|
||||
# Стек декодеров с передачей кэша
|
||||
new_cache = []
|
||||
for i, decoder in enumerate(self._decoders):
|
||||
decoder_cache = cache[i] if cache is not None else None
|
||||
decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)
|
||||
|
||||
# Извлекаем результат из кортежа
|
||||
if use_cache:
|
||||
out, decoder_new_cache = decoder_result
|
||||
new_cache.append(decoder_new_cache)
|
||||
else:
|
||||
out = decoder_result[0]
|
||||
|
||||
out = self._norm(out)
|
||||
logits = self._linear(out)
|
||||
|
||||
# Возвращаем результат с учетом use_cache
|
||||
if use_cache:
|
||||
return (logits, new_cache)
|
||||
else:
|
||||
return (logits, None)
|
||||
|
||||
def generate(self,
|
||||
x: torch.Tensor,
|
||||
max_new_tokens: int,
|
||||
do_sample: bool,
|
||||
temperature: float = 1.0,
|
||||
top_k: int = None,
|
||||
top_p: float = None,
|
||||
use_cache: bool = True
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k/top-p sampling
|
||||
и ускорением через attention-кэш (KV-cache, важно для inference на длинных текстах).
|
||||
|
||||
Аргументы:
|
||||
x (torch.Tensor): Входной тензор с токенами shape [batch_size, seq_len].
|
||||
max_new_tokens (int): Максимальное количество новых токенов для генерации.
|
||||
do_sample (bool): Если True — вероятность/случайность (random sampling); если False — жадная генерация (argmax).
|
||||
temperature (float): Температура (>0, по умолчанию 1.0); >1.0 — более случайные выборы, <1.0 — более строгие.
|
||||
top_k (int, optional): top-k sampling; при сэмплировании выбираются только top_k наиболее вероятных токенов.
|
||||
top_p (float, optional): nucleus (top-p) sampling; выбираются токены с накопленной вероятностью ≤ top_p.
|
||||
use_cache (bool, по умолчанию True): Использовать ускорение через KV attention cache для autoregressive режима.
|
||||
|
||||
Возвращает:
|
||||
torch.Tensor: Последовательность индексов токенов shape [batch_size, seq_len + max_new_tokens].
|
||||
|
||||
Исключения:
|
||||
ValueError: Если x длиннее max_seq_len модели.
|
||||
ValueError: Если temperature ≤ 0.
|
||||
ValueError: Если одновременно заданы top_k и top_p.
|
||||
ValueError: Если top_k ≤ 0.
|
||||
ValueError: Если top_p не в диапазоне (0, 1].
|
||||
|
||||
Примеры:
|
||||
>>> # Жадная генерация
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=False)
|
||||
>>> # Сэмплирование с температурой
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=0.8)
|
||||
>>> # Top-k sampling
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_k=50)
|
||||
>>> # Top-p (nucleus) sampling
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_p=0.92)
|
||||
>>> # Температура + top-k
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=1.0, top_k=100)
|
||||
|
||||
Примечания:
|
||||
- Одновременно использовать top_k и top_p нельзя.
|
||||
- Параметры temperature, top_k, top_p работают только при do_sample=True.
|
||||
- Для полного воспроизведения результата зафиксируйте seed через torch.manual_seed.
|
||||
- Метод всегда возвращает только индексы токенов; для получения логитов используйте forward.
|
||||
|
||||
Ссылки:
|
||||
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus/top-p sampling): https://arxiv.org/abs/1904.09751
|
||||
- Mistral: https://arxiv.org/abs/2310.06825
|
||||
"""
|
||||
cache = None
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
if use_cache and cache is not None:
|
||||
# Используем кэш - передаем только последний токен
|
||||
x_input = x[:, -1:] # [batch_size, 1]
|
||||
else:
|
||||
# Первая итерация или кэш отключен - передаем всю последовательность
|
||||
x_input = x
|
||||
|
||||
# Прямой проход с кэшем
|
||||
logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)
|
||||
|
||||
# Обновляем кэш для следующей итерации
|
||||
if use_cache:
|
||||
cache = new_cache
|
||||
|
||||
last_logits = logits[:, -1, :] # [batch_size, vocab_size]
|
||||
|
||||
# Масштабируем логиты температурой
|
||||
if temperature > 0:
|
||||
logits_scaled = last_logits / temperature
|
||||
else:
|
||||
logits_scaled = last_logits
|
||||
|
||||
if do_sample == True and top_k != None:
|
||||
_, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)
|
||||
|
||||
# # Заменим все НЕ top-k логиты на -inf
|
||||
masked_logits = logits_scaled.clone()
|
||||
vocab_size = logits_scaled.size(-1)
|
||||
|
||||
# создаём маску: 1, если токен НЕ в topk_indices
|
||||
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||
mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы
|
||||
masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf')
|
||||
|
||||
logits_scaled = masked_logits
|
||||
|
||||
if do_sample == True and top_p != None:
|
||||
# 1. Применим softmax, чтобы получить вероятности:
|
||||
probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]
|
||||
# 2. Отсортируем токены по убыванию вероятностей:
|
||||
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
|
||||
# 3. Посчитаем кумулятивную сумму вероятностей:
|
||||
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
|
||||
# 4. Определим маску: оставить токены, пока сумма < top_p
|
||||
sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size]
|
||||
# Гарантируем, что хотя бы первый токен останется
|
||||
sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1
|
||||
# 5. Преобразуем маску обратно в оригинальный порядок:
|
||||
# Создаём полную маску из 0
|
||||
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||
# Устанавливаем 1 в местах нужных токенов
|
||||
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
||||
# 6. Зануляем логиты токенов вне топ-p:
|
||||
logits_scaled[~mask] = float('-inf')
|
||||
|
||||
# 4. Применяем Softmax
|
||||
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
|
||||
|
||||
|
||||
if do_sample == True:
|
||||
# 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial
|
||||
next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]
|
||||
else:
|
||||
# 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью
|
||||
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
|
||||
|
||||
# 6. Добавляем его к последовательности
|
||||
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
|
||||
return x
|
||||
|
||||
|
||||
def save(self, path):
|
||||
torch.save({
|
||||
'model_state_dict': self.state_dict(),
|
||||
'vocab_size': self._vocab_size,
|
||||
'max_seq_len': self._max_seq_len,
|
||||
'emb_size': self._emb_size,
|
||||
'num_heads': self._num_heads,
|
||||
'head_size': self._head_size,
|
||||
'num_layers': self._num_layers
|
||||
}, path)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path, device):
|
||||
checkpoint = torch.load(path, map_location=device)
|
||||
model = cls(
|
||||
vocab_size=checkpoint['vocab_size'],
|
||||
max_seq_len=checkpoint['max_seq_len'],
|
||||
emb_size=checkpoint['emb_size'],
|
||||
num_heads=checkpoint['num_heads'],
|
||||
head_size=checkpoint['head_size'],
|
||||
num_layers=checkpoint['num_layers']
|
||||
)
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
model.to(device)
|
||||
return model
|
||||
|
||||
@property
|
||||
def max_seq_len(self) -> int:
|
||||
return self._max_seq_len
|
||||
|
||||
|
||||
|
||||
|
||||
57
llm/tests/models/test_mixtral.py
Normal file
57
llm/tests/models/test_mixtral.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import torch
|
||||
import pytest
|
||||
from llm.models.mixtral.mixtral import Mixtral
|
||||
|
||||
@pytest.fixture
|
||||
def config():
|
||||
return {
|
||||
"vocab_size": 100,
|
||||
"embed_dim": 32,
|
||||
"num_q_heads": 4,
|
||||
"num_kv_heads": 2,
|
||||
"num_layers": 2,
|
||||
"max_position_embeddings": 16,
|
||||
"window_size": 8,
|
||||
"dropout": 0.0,
|
||||
"num_experts": 4,
|
||||
"top_k_experts": 2,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def model(config):
|
||||
return Mixtral(config)
|
||||
|
||||
def test_forward_basic(model):
|
||||
x = torch.randint(0, 100, (2, 8))
|
||||
logits, cache = model(x)
|
||||
assert logits.shape == (2, 8, 100)
|
||||
assert isinstance(cache, list)
|
||||
assert len(cache) == model._decoders.__len__()
|
||||
|
||||
def test_forward_with_cache(model):
|
||||
x = torch.randint(0, 100, (2, 4))
|
||||
logits, cache = model(x, use_cache=True)
|
||||
x2 = torch.randint(0, 100, (2, 1))
|
||||
logits2, cache2 = model(x2, use_cache=True, cache=cache)
|
||||
assert logits2.shape == (2, 1, 100)
|
||||
assert isinstance(cache2, list)
|
||||
|
||||
def test_generate_and_shape(model):
|
||||
x = torch.randint(0, 100, (1, 5))
|
||||
result = model.generate(x, max_new_tokens=3, do_sample=False)
|
||||
assert result.shape == (1, 8)
|
||||
|
||||
def test_forward_sequence_too_long(model, config):
|
||||
x = torch.randint(0, 100, (1, config["max_position_embeddings"] + 1))
|
||||
with pytest.raises(ValueError):
|
||||
model(x)
|
||||
|
||||
def test_generate_with_sampling_topk(model):
|
||||
x = torch.randint(0, 100, (1, 3))
|
||||
out = model.generate(x, max_new_tokens=2, do_sample=True, top_k=5)
|
||||
assert out.shape == (1, 5)
|
||||
|
||||
def test_generate_with_sampling_topp(model):
|
||||
x = torch.randint(0, 100, (1, 3))
|
||||
out = model.generate(x, max_new_tokens=2, do_sample=True, top_p=0.8)
|
||||
assert out.shape == (1, 5)
|
||||
Reference in New Issue
Block a user