From c9da4c841bd8e5d71060d49cd6718a3c8a9093ad Mon Sep 17 00:00:00 2001 From: Sergey Penkovsky Date: Mon, 20 Oct 2025 16:07:51 +0300 Subject: [PATCH] feat(mixtral): add MixtralDecoder, enhance MoE and Mixtral model docs, add unit tests - Implement new core module: MixtralDecoder (llm/core/mixtral_decoder.py) with full Russian scientific docstrings, formal math, and usage examples - Improve MoE: add Russian docstrings for class, __init__, forward; validate top_k_experts; explain theory and components - Refactor Mixtral model: switch stack to MixtralDecoder, add comprehensive documentation for class, constructor and forward, clarify config usage and architecture - Add thorough unit tests: * tests/core/test_mixtral_decoder.py: checks shapes, errors, mask, dropout, grads etc. * tests/core/test_moe.py: covers normal and edge-case logic, gradients, shape, params check - All code and tests in compliance with recent scientific and engineering standards. --- llm/src/llm/core/mixtral_decoder.py | 211 +++++++++++++++++++++++ llm/src/llm/core/moe.py | 137 +++++++++++++++ llm/src/llm/models/mixtral/mixtral.py | 223 ++++++++++++++++--------- llm/tests/core/test_mixtral_decoder.py | 80 +++++++++ llm/tests/core/test_moe.py | 61 +++++++ 5 files changed, 632 insertions(+), 80 deletions(-) create mode 100644 llm/src/llm/core/mixtral_decoder.py create mode 100644 llm/tests/core/test_mixtral_decoder.py create mode 100644 llm/tests/core/test_moe.py diff --git a/llm/src/llm/core/mixtral_decoder.py b/llm/src/llm/core/mixtral_decoder.py new file mode 100644 index 0000000..96ee84d --- /dev/null +++ b/llm/src/llm/core/mixtral_decoder.py @@ -0,0 +1,211 @@ + + +from torch import nn +import torch +import torch.nn.functional as F +from llm.core.rope import RoPE +from llm.core.group_query_attention import GroupedQueryAttention +from llm.core.moe import MoE +from llm.core.rms_norm import RMSNorm + +class MixtralDecoder(nn.Module): + """ + MixtralDecoder — декодерный блок для Mixtral/MoE-трансформеров (см. Mixtral 8x7B, Mistral v0.2 и др.). + + Назначение: + ----------- + MixtralDecoder реализует один модульный слой глубокой трансформерной архитектуры с Mixture-of-Experts (MoE) Feed-Forward Network и Grouped Query Attention (GQA). + Поддерживает разреженную активацию и масштабируемое количество экспертов, оптимально для больших LLM. + + Архитектура блока: + ------------------ + - RMSNorm -> Grouped Query Attention (GQA) + - skip-connection + - RMSNorm -> MoE (SwiGLU-эксперты) + - skip-connection + + Для входа `x` проходит: + 1. norm1_out = RMSNorm(x) + 2. attention, kv_caches = GQA(norm1_out, ...) + 3. out = attention + x # residual connection + 4. norm2_out = RMSNorm(out) + 5. ffn_out = MoE(norm2_out) + 6. return (ffn_out + out, kv_caches) + + Теоретическая мотивация: + ------------------------ + - Использование MoE (см. https://arxiv.org/abs/1701.06538) позволяет кратно увеличивать capacity без роста затрат на ff-часть. + - Grouped Query Attention эффективно масштабирует self-attention для больших моделей (см. Mistral, Llama 2/3). + - RMSNorm (Root Mean Square LayerNorm) стабилизирует градиенты и память. + - Является строительным блоком для стека декодеров в Mixtral-моделях (см. Mixtral, Mistral, LLaMA). + + Аргументы конструктора: + ---------------------- + num_q_heads : int + Число query-голов в attention. + num_kv_heads : int + Число key-value голов (группировка ключей/values). + emb_size : int + Скрытый размер эмбеддинга. + head_size : int + Размерность одной головы (emb_size // num_q_heads). + max_seq_len : int + Максимальная поддерживаемая длина последовательности. + num_experts : int + Количество «экспертов» (MoE). + top_k_experts : int + Сколько одновременно экспертов активируется для одного токена. + window_size : int + Размер окна внимания (используется для efficient attention). + rope : RoPE + Реализация позиционного кодирования RoPE. + dropout : float + Вероятность Dropout для регуляризации. + + Пример использования: + --------------------- + >>> decoder = MixtralDecoder(... параметры ...) + >>> x = torch.randn(batch, seq, emb_size) + >>> out, cache = decoder(x, mask=None, use_cache=True) + >>> out.shape + + Литература и ссылки: + -------------------- + - Mixtral 8x7B: https://mistral.ai/news/mixtral-of-experts/ + - Shazeer et al., “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer”, 2017. https://arxiv.org/abs/1701.06538 + - Mistral paper: https://arxiv.org/abs/2310.06825 + - GQA: https://arxiv.org/abs/2305.14236 + - RMSNorm: https://arxiv.org/abs/1910.07467 + + """ + def __init__(self, + num_q_heads: int, + num_kv_heads: int, + emb_size: int, + head_size: int, + max_seq_len: int, + num_experts: int, + top_k_experts: int, + window_size: int, + rope: RoPE, + dropout: float = 0.1 + ): + """ + Конструктор декодерного блока MixtralDecoder. + + Осуществляет инициализацию всех под-компонентов слоя: Attention (Grouped Query Attention), MoE (Mixture-of-Experts, SwiGLU) + и нормализации (RMSNorm). Позволяет гибко настраивать архитектуру под специфику задач и размеры LLM. + + Аргументы: + ---------- + num_q_heads : int + Количество голов внимания (queries) в механизме GroupedQueryAttention. + Чем больше — тем тоньше дискретизация внимания по подпространствам признаков. + num_kv_heads : int + Количество групп ключей/значений (key-value heads) для GQA. + Позволяет балансировать производительность и память. + emb_size : int + Размерность эмбеддингового пространства внутри слоя (hidden). + head_size : int + Размерность одной attention-головы. Обычно emb_size // num_q_heads. + max_seq_len : int + Максимально поддерживаемая длина токенизированной последовательности. + num_experts : int + Количество экспертов в слое MoE (размер пула SwiGLU-экспертов). + top_k_experts : int + Сколько экспертов по роутингу активируется на 1 токен (разреженность — эффективная экономия вычислений). + window_size : int + Размер окна для attention (может использоваться для ограничения receptive field, как в Mistral). + rope : RoPE + Объект позиционного кодирования RoPE (Rotary Positional Embedding), необходим для архитектуры внимания. + dropout : float, по умолчанию 0.1 + Вероятность зануляции выходных значений для регуляризации и борьбы с переобучением. + + Пример: + ------- + >>> decoder = MixtralDecoder( + ... num_q_heads=8, + ... num_kv_heads=2, + ... emb_size=256, + ... head_size=32, + ... max_seq_len=1024, + ... num_experts=4, + ... top_k_experts=2, + ... window_size=128, + ... rope=rope_module, + ... dropout=0.05 + ... ) + + """ + super().__init__() + self._heads = GroupedQueryAttention( + num_q_heads=num_q_heads, + num_kv_heads=num_kv_heads, + emb_size=emb_size, + head_size=head_size, + max_seq_len=max_seq_len, + window_size=window_size, + rope=rope, + dropout=dropout + ) + self._ff = MoE( + emb_size=emb_size, + num_experts=num_experts, + top_k_experts=top_k_experts, + dropout=dropout + ) + self._norm1 = RMSNorm(emb_size) + self._norm2 = RMSNorm(emb_size) + + def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor: + """ + Прямой проход (forward) через декодерный блок MixtralDecoder. + + Данный метод реализует последовательную обработку входных скрытых состояний (x) через: + - нормализацию (RMSNorm), + - attention-модуль (Grouped Query Attention) с опциональным применением маски и кэша ключей/значений для ускорения инференса, + - остаточное сложение (residual connection), + - повторную нормализацию, + - feed-forward блок на основе Mixture-of-Experts (MoE), + - финальное остаточное сложение. + + Аргументы: + ---------- + x : torch.Tensor + Входной скрытый тензор формы [batch_size, seq_len, emb_size] — результат эмбеддинга токенов либо предыдущего слоя. + mask : torch.Tensor, optional + (Необязательно) Маска внимания для ограничения области self-attention (например, для автоперемешивания или causal-LLM-моделей). + use_cache : bool, по умолчанию True + Если True — сохраняет кэш ключей/значений attention для ускорения авторегрессии (инференса). + cache : list, optional + (Необязательно) Предварительно вычисленный кеш attention (для ускорения генерации длинного текста). + + Возвращает: + ----------- + Tuple[torch.Tensor, Any]: + - Первый элемент: скрытый тензор выхода слоя с той же формой, что вход (последовательный residual из attention и MoE-блока). + - Второй элемент: обновлённый кэш attention (если use_cache=True), иначе None. + + Пример: + ------- + >>> out, cache = decoder(x, mask=att_mask, use_cache=True, cache=old_cache) + >>> out.shape # [batch_size, seq_len, emb_size] + + Примечания: + ----------- + - Для autoregressive-генерации (GPT-like режимов) следует передавать mask и использовать use_cache=True. + - Реализация поддерживает произвольные батчи и длины последовательностей, в пределах max_seq_len слоя. + - Модуль MixtralDecoder обычно используется в виде стека (несколько подряд) внутри крупной LLM. + + """ + norm1_out = self._norm1(x) + attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache) + out = attention + x + + norm2_out = self._norm2(out) + ffn_out = self._ff(norm2_out) + + if use_cache is True: + return (ffn_out + out, kv_caches) + else: + return (ffn_out + out, None) diff --git a/llm/src/llm/core/moe.py b/llm/src/llm/core/moe.py index 4964b22..1fc8276 100644 --- a/llm/src/llm/core/moe.py +++ b/llm/src/llm/core/moe.py @@ -4,6 +4,58 @@ import torch.nn.functional as F from llm.core.swi_glu import SwiGLU class MoE(nn.Module): + """ + MoE (Mixture of Experts) — слой «смеси экспертов» для современных трансформерных архитектур с разреженной активацией. + + Назначение: + ----------- + Класс реализует слой разреженного условного вычисления для увеличения capacity трансформеров без роста вычислительных затрат. + Для каждого токена из последовательности выбирается (с помощью роутера) наиболее подходящее подмножество экспертов (малых нейросетей). + Итоговый выход формируется как взвешенная сумма откликов экспертов, выбранных для данного токена. + + Архитектурная схема: + --------------------- + - Для каждого входного токена `x` роутер (обычно один Linear-слой) предсказывает skor, насколько каждый из `num_experts` релевантен. + - Для каждого токена выбираются top_k_experts с максимальными skor; только они обрабатывают этот токен. + - Каждый эксперт здесь представлен отдельным экземпляром блока `SwiGLU` (может быть любая небольшая feed-forward сеть). + - Выход каждого эксперта умножается на индивидуальный вес (softmax по skor) — агрегируется взвешенная сумма. + - Dropout применяется к итоговому выходу. + + Математика (коротко): + --------------------- + Пусть X ∈ R^{BxSxD} — вход, + E — число экспертов, + K — число активируемых экспертов на токен. + r(x) = softmax(W_r x) — роутинг-логиты, top-K берём индексы и веса. + Для каждого токена: + y_j = Expert_j(x) + y = sum_j(w_j * y_j), где j пробегает по выбранным экспертам + Output: Y ∈ R^{BxSxD} + + Аргументы конструктора: + ---------------------- + emb_size : int + Размерность входных/выходных векторов (обычно совпадает с embedding модели). + num_experts : int + Общее число экспертов внутри слоя MoE. + top_k_experts : int + Сколько экспертов активировать и агрегировать на каждом токене (обычно 2-8). + dropout : float, по умолчанию 0.1 + Dropout к выходу агрегатора. + + Пример использования: + --------------------- + >>> moe = MoE(emb_size=512, num_experts=8, top_k_experts=2, dropout=0.1) + >>> x = torch.randn(4, 16, 512) + >>> y = moe(x) + >>> y.shape # torch.Size([4, 16, 512]) + + Литература: + ----------- + - Shazeer, N. et al. “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer”, 2017. https://arxiv.org/abs/1701.06538 + - Fedus, W., Zoph, B., & Shazeer, N. “Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity”, 2021. https://arxiv.org/abs/2101.03961 + - Mistral/Mixtral: https://mistral.ai/news/mixtral-of-experts/ + """ def __init__( self, emb_size: int, @@ -11,7 +63,51 @@ class MoE(nn.Module): top_k_experts: int, dropout: float = 0.1, ): + """ + Конструктор слоя MoE (Mixture of Experts). + + Позволяет создать слой, состоящий из набора экспертов (например, отдельных небольших feedforward-нейросетей) и роутера, + который будет для каждого токена определять наиболее релевантных экспертов. + Часть экспертов (top_k_experts) активируется для каждого токена, остальные — пропускаются. + + Аргументы: + ---------- + emb_size : int + Размерность входных и выходных векторов (embedding size). + Определяет, над каким пространством признаков будет работать роутер и эксперты. + Например, если скрытый размер слоя трансформера 512, сюда нужно передать 512. + + num_experts : int + Общее количество экспертов в слое MoE. + Чем больше экспертов — тем больше capacity у модели, но тем выше требования к RAM/VRAM при обучении. + Пример: 8, 16, 32, 64. + + top_k_experts : int + Сколько экспертов одновременно будет обрабатывать каждый токен. + Обычно 2–8. Меньшее значение — выше разреженность, больше экономия вычислений. + + dropout : float, по умолчанию 0.1 + Вероятность зануления значений на выходе после агрегации откликов экспертов. + Используется для регуляризации (борьбы с переобучением). + + Пример: + ------- + >>> moe = MoE(emb_size=256, num_experts=8, top_k_experts=2, dropout=0.1) + >>> print(moe) + MoE( ... ) + + Теория: + ------- + Слой строит: + - Линейный роутер (Linear(emb_size, num_experts)): выдает «важность» каждого эксперта для токена. + - Список из num_experts экспертов (в данной реализации — SwiGLU-блоки). + + При каждом проходе для каждого токена выбираются top_k_experts наиболее релевантных экспертов, + их ответы агрегируются взвешенной суммой (softmax по роутерным логитам). + """ super().__init__() + if top_k_experts > num_experts: + raise ValueError(f"top_k_experts ({top_k_experts}) должен быть меньше или равен num_experts ({num_experts})!") self._num_experts = num_experts self._top_k_experts = top_k_experts @@ -23,6 +119,47 @@ class MoE(nn.Module): self._dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor): + """ + Прямой проход (forward) через слой MoE. + + Для входной последовательности скрытых состояний (обычно из предыдущего слоя трансформера) + данный метод динамически выбирает для каждого токена топ-k наиболее релевантных экспертов с помощью роутера, + пропускает соответствующие токены через выбранных экспертов и агрегирует их результаты. + + Математически: + -------------- + 1. Для каждого токена вычисляются логиты маршрутизатора (роутера): + router_logits = Linear(x) ∈ ℝ^{batch, seq, num_experts} + 2. Выбираются top_k экспертов (topk_indices) и соответствующие им softmax-веса (topk_weights). + 3. Каждый эксперт обрабатывает только свой поднабор токенов. + 4. Результат агрегируется — отклик эксперта умножается на вес, ответы суммируются для каждого токена. + 5. На результат применяется dropout для регуляризации. + + Аргументы: + ---------- + x : torch.Tensor + Трёхмерный входной тензор формы [batch_size, seq_length, emb_size], + где batch_size — размер батча, seq_length — длина последовательности, emb_size — размерность эмбеддинга. + + Возвращает: + ----------- + torch.Tensor : + Тензор той же формы [batch_size, seq_length, emb_size] — результат комбинирования выходов выбранных экспертов + с учетом softmax-весов маршрутизатора и dropout'а. + + Пример: + ------- + >>> y = moe(x) + >>> print(y.shape) + torch.Size([batch_size, seq_length, emb_size]) + + Примечание: + ----------- + - Каждый токен чаще всего активирует только подмножество экспертов. + - Остальные эксперты вычислительно “спят”, что позволяет строить очень большие (по параметрам) модели с малым ростом затрат. + - Работа с распределением топ-к экспертов и агрегирование с весами реализовано автоматически. + + """ batch_size, seq_len, emb_size = x.shape # 1. Пропускаем через роутер diff --git a/llm/src/llm/models/mixtral/mixtral.py b/llm/src/llm/models/mixtral/mixtral.py index 6992196..a5c6133 100644 --- a/llm/src/llm/models/mixtral/mixtral.py +++ b/llm/src/llm/models/mixtral/mixtral.py @@ -7,64 +7,120 @@ from llm.core.base_model import BaseModel from llm.core.token_embeddings import TokenEmbeddings from llm.core.rope import RoPE from llm.core.rms_norm import RMSNorm -from llm.core.moe import MoE -from llm.core.group_query_attention import GroupedQueryAttention - - -class Decoder(nn.Module): - def __init__(self, - num_q_heads: int, - num_kv_heads: int, - emb_size: int, - head_size: int, - max_seq_len: int, - num_experts: int, - top_k_experts: int, - window_size: int, - rope: RoPE, - dropout: float = 0.1 - ): - super().__init__() - self._heads = GroupedQueryAttention( - num_q_heads=num_q_heads, - num_kv_heads=num_kv_heads, - emb_size=emb_size, - head_size=head_size, - max_seq_len=max_seq_len, - window_size=window_size, - rope=rope, - dropout=dropout - ) - self._ff = MoE( - emb_size=emb_size, - num_experts=num_experts, - top_k_experts=top_k_experts, - dropout=dropout - ) - self._norm1 = RMSNorm(emb_size) - self._norm2 = RMSNorm(emb_size) - - def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor: - norm1_out = self._norm1(x) - attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache) - out = attention + x - - norm2_out = self._norm2(out) - ffn_out = self._ff(norm2_out) - - if use_cache is True: - return (ffn_out + out, kv_caches) - else: - return (ffn_out + out, None) +from llm.core.mixtral_decoder import MixtralDecoder + -from torch import nn -import torch -import torch.nn.functional as F class Mixtral(BaseModel): + """ + Mixtral — языковая модель с архитектурой Mixture-of-Experts на основе современных трансформеров (см. Mixtral 8x7B). + + Описание: + --------- + Данный класс реализует полностью функциональную LLM с блоками MixtralDecoder, которые используют разреженные Feed-Forward сети MoE (Mixture-of-Experts) + и Grouped Query Attention (GQA). Позволяет масштабировать количество параметров без экспоненциального роста вычислительных затрат благодаря активации лишь части экспертов на каждый токен. + Mixtral поддерживает автотекстогенерацию с caching, position encoding через RoPE и всё необходимое для работы и тренировки современных LLM. + + Архитектурные особенности: + -------------------------- + - Stack из N слоёв MixtralDecoder (каждый — MoE-блок + attention + RMSNorm). + - Dropout для регуляризации на уровне эмбеддингов и слоёв. + - Позиционные эмбеддинги реализованы через RoPE (Rotary Positional Embeddings). + - Финальная RMSNorm плюс Linear-проекция к словарю токенов. + - Поддержка автогенерации с sampling (greedy, top-k, top-p), temperature и KV-cache. + + Аргументы конструктора: + ---------------------- + config : dict + Словарь-конфиг с основными гиперпараметрами модели: + - vocab_size : int — размер словаря токенов + - embed_dim : int — размер скрытого пространства + - max_position_embeddings : int — макс. длина последовательности + - num_layers : int — количество декодерных блоков в стеке + - num_q_heads : int — число query-голов в attention + - num_kv_heads : int — число kv-голов в attention + - num_experts : int — число MoE-экспертов + - top_k_experts : int — сколько экспертов активировать на токен + - dropout : float — вероятность Dropout + - window_size : int — размер окна внимания + + Основные методы: + ---------------- + - forward(x, use_cache=True, cache=None) — прямой проход, поддерживает batched вход, caching. + - generate(...) — авторегрессивная генерация с разными стратегиями sampling и ускорением через cache. + - save(path)/load(path, device) — сохранение и восстановление обученной модели. + + Пример: + ------- + >>> config = {...} # dict с параметрами + >>> model = Mixtral(config) + >>> x = torch.randint(0, config["vocab_size"], (2, 16)) + >>> logits, cache = model(x, use_cache=True) + >>> print(logits.shape) # [2, 16, vocab_size] + + >>> # Генерация + >>> out = model.generate(x, max_new_tokens=20, do_sample=True, top_k=10, temperature=0.9) + + Литература: + ----------- + - Mixtral 8x7B: https://mistral.ai/news/mixtral-of-experts/ + - Switch Transformer: https://arxiv.org/abs/2101.03961 + - GShard: https://arxiv.org/abs/2006.16668 + - RoPE: https://arxiv.org/abs/2104.09864 + - Grouped Query Attention: https://arxiv.org/abs/2305.14236 + - RMSNorm: https://arxiv.org/abs/1910.07467 + """ def __init__(self, config): + """ + Конструктор класса Mixtral. + + Осуществляет инициализацию всех модулей и внутренних параметров большой языковой модели с архитектурой Mixtral/MoE. + Использует параметры из конфиг-словаря `config` для гибкой настройки модели. + + Аргументы: + ---------- + config : dict + Словарь с основными гиперпараметрами архитектуры. Должен содержать ключи: + vocab_size (int): Размер словаря токенов. + embed_dim (int): Размер скрытого пространства (эмбеддингов). + max_position_embeddings (int): Максимальная длина токенной последовательности. + num_layers (int): Количество декодерных блоков (слоёв) в модели. + num_q_heads (int): Число query-голов (attention heads). + num_kv_heads (int): Число key-value голов (attention heads). + num_experts (int): Количество экспертов в каждом MoE-блоке. + top_k_experts (int): Сколько экспертов активируется для одного токена. + dropout (float): Dropout для регуляризации. + window_size (int): Размер окна внимания (Attention Window). + + Внутри: + ------- + - Инициализируются эмбеддинги токенов, позиционные эмбеддинги RoPE, Dropout. + - Строится стек из num_layers модулей MixtralDecoder с заданным количеством attention heads и экспертов. + - Финальный слой нормализации и проекция к логитам словаря (linear layer). + + Пример: + ------- + >>> config = { + ... "vocab_size": 32000, + ... "embed_dim": 512, + ... "max_position_embeddings": 2048, + ... "num_layers": 24, + ... "num_q_heads": 8, + ... "num_kv_heads": 8, + ... "num_experts": 8, + ... "top_k_experts": 2, + ... "dropout": 0.1, + ... "window_size": 256, + ... } + >>> model = Mixtral(config) + + Примечания: + ----------- + - Конфиг модели должен быть согласован: размеры должны делиться на число голов, число экспертов и top_k_experts корректно выбраны. + - Все параметры, необходимые для построения MixtralDecoder, attention и MoE, берутся из config. + """ super().__init__(config) self._max_seq_len = config["max_position_embeddings"] @@ -83,7 +139,7 @@ class Mixtral(BaseModel): # emb_size=emb_size #) self._dropout = nn.Dropout(config["dropout"]) - self._decoders = nn.ModuleList([Decoder( + self._decoders = nn.ModuleList([MixtralDecoder( num_q_heads=config["num_q_heads"], num_kv_heads=config["num_kv_heads"], emb_size=config["embed_dim"], @@ -99,6 +155,40 @@ class Mixtral(BaseModel): self._linear = nn.Linear(config["embed_dim"], config["vocab_size"]) def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple: + """ + Прямой проход (forward) через всю модель Mixtral. + + Данный метод реализует трансформацию входной последовательности токенов в логиты (предсказания вероятностей токенов словаря) + с поддержкой эффективного инференса с использованием cache (KV-кэш attention для автогенерации). + + Аргументы: + ---------- + x : torch.Tensor + Двумерный входной тензор shape [batch_size, seq_len], где каждое значение — ID токена. + use_cache : bool, по умолчанию True + Если True — в режиме генерации модель возвращает обновлённый список кэшей attention для ускорения последовательного инференса. + Если False — attention cache не используется. + cache : list, optional + (Необязательно) Список (или None) с кэшем KV attention для каждого слоя. Используется для автогенерации текста. + + Возвращает: + ----------- + tuple: + - logits : torch.Tensor — выходной тензор shape [batch_size, seq_len, vocab_size] — массив логитов по токенам и словарю. + - new_cache : list или None — обновлённый cache, если используется. + + Пример: + ------- + >>> logits, new_cache = model(x, use_cache=True, cache=None) + >>> logits.shape # [batch_size, seq_len, vocab_size] + + Примечания: + ----------- + - Если используется cache — эффективно для авторегрессионной генерации (token-by-token), например, при диалогах или длинной генерации. + - Если входная последовательность длиннее max_seq_len — будет выброшено исключение. + - Если нужен только логит последнего токена — используйте slice: logits[:, -1, :] + + """ # Проверка длины последовательности (только при отсутствии кэша) if cache is None and x.size(1) > self._max_seq_len: raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}") @@ -259,33 +349,6 @@ class Mixtral(BaseModel): x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1] return x - - def save(self, path): - torch.save({ - 'model_state_dict': self.state_dict(), - 'vocab_size': self._vocab_size, - 'max_seq_len': self._max_seq_len, - 'emb_size': self._emb_size, - 'num_heads': self._num_heads, - 'head_size': self._head_size, - 'num_layers': self._num_layers - }, path) - - @classmethod - def load(cls, path, device): - checkpoint = torch.load(path, map_location=device) - model = cls( - vocab_size=checkpoint['vocab_size'], - max_seq_len=checkpoint['max_seq_len'], - emb_size=checkpoint['emb_size'], - num_heads=checkpoint['num_heads'], - head_size=checkpoint['head_size'], - num_layers=checkpoint['num_layers'] - ) - model.load_state_dict(checkpoint['model_state_dict']) - model.to(device) - return model - @property def max_seq_len(self) -> int: return self._max_seq_len diff --git a/llm/tests/core/test_mixtral_decoder.py b/llm/tests/core/test_mixtral_decoder.py new file mode 100644 index 0000000..d2d3135 --- /dev/null +++ b/llm/tests/core/test_mixtral_decoder.py @@ -0,0 +1,80 @@ +import torch +import pytest +from llm.core.mixtral_decoder import MixtralDecoder +from llm.core.rope import RoPE + +@pytest.fixture +def basic_decoder(): + emb_size = 16 + num_q_heads = 4 + num_kv_heads = 2 + head_size = 4 + max_seq_len = 32 + num_experts = 4 + top_k_experts = 2 + window_size = 8 + rope = RoPE(head_size=head_size, max_seq_len=max_seq_len) + return MixtralDecoder( + num_q_heads=num_q_heads, + num_kv_heads=num_kv_heads, + emb_size=emb_size, + head_size=head_size, + max_seq_len=max_seq_len, + num_experts=num_experts, + top_k_experts=top_k_experts, + window_size=window_size, + rope=rope, + dropout=0.0, + ) + +def test_forward_shape(basic_decoder): + x = torch.randn(2, 10, 16) + out, cache = basic_decoder(x) + assert out.shape == (2, 10, 16) + assert cache is None or isinstance(cache, (tuple, list)) + +def test_forward_masked(basic_decoder): + x = torch.randn(3, 7, 16) + mask = torch.ones(3, 7, 7, dtype=torch.bool) + out, cache = basic_decoder(x, mask=mask) + assert out.shape == (3, 7, 16) + +def test_forward_with_cache_flag(basic_decoder): + x = torch.randn(2, 8, 16) + out, cache = basic_decoder(x, use_cache=True, cache=None) + assert out.shape == (2, 8, 16) + assert isinstance(cache, (tuple, list)) or cache is None + +def test_backprop_pass(basic_decoder): + x = torch.randn(2, 5, 16, requires_grad=True) + out, _ = basic_decoder(x) + y = out.sum() + y.backward() + assert x.grad is not None + assert x.grad.shape == x.shape + +def test_seq_too_long_raises(basic_decoder): + x = torch.randn(1, 40, 16) # seq_len > max_seq_len + with pytest.raises(Exception): + basic_decoder(x) + +def test_different_config(): + rope = RoPE(head_size=2, max_seq_len=12) + decoder = MixtralDecoder( + num_q_heads=2, num_kv_heads=2, emb_size=4, head_size=2, + max_seq_len=12, num_experts=2, top_k_experts=1, window_size=4, rope=rope, dropout=0.1 + ) + x = torch.randn(1, 8, 4) + out, cache = decoder(x) + assert out.shape == x.shape + +def test_forward_no_dropout(): + # Проверка на корректность shape при отсутствии Dropout + rope = RoPE(head_size=2, max_seq_len=12) + decoder = MixtralDecoder( + num_q_heads=2, num_kv_heads=1, emb_size=4, head_size=2, + max_seq_len=12, num_experts=2, top_k_experts=1, window_size=3, rope=rope, dropout=0.0 + ) + x = torch.randn(2, 3, 4) + out, cache = decoder(x) + assert out.shape == x.shape diff --git a/llm/tests/core/test_moe.py b/llm/tests/core/test_moe.py new file mode 100644 index 0000000..0240666 --- /dev/null +++ b/llm/tests/core/test_moe.py @@ -0,0 +1,61 @@ +import torch +import pytest +from llm.core.moe import MoE + +@pytest.fixture +def moe(): + # Базовая MoE для коротких тестов + return MoE(emb_size=16, num_experts=4, top_k_experts=2, dropout=0.0) + +def test_forward_shape(moe): + x = torch.randn(3, 5, 16) # [batch, seq, emb] + y = moe(x) + assert y.shape == x.shape + +def test_forward_grad(moe): + x = torch.randn(2, 4, 16, requires_grad=True) + y = moe(x) + (y.sum()).backward() + assert x.grad is not None + assert x.grad.shape == x.shape + +def test_top_k_larger_than_experts(): + # top_k_experts > num_experts должно падать + with pytest.raises(ValueError): + MoE(emb_size=8, num_experts=2, top_k_experts=4) + +def test_single_expert_no_error(): + # один эксперт, один топ-к — модель всё ещё валидна + moe = MoE(emb_size=8, num_experts=1, top_k_experts=1) + x = torch.randn(2, 2, 8) + y = moe(x) + assert y.shape == x.shape + +def test_forward_trivial_weights(): + """Проверяет, что при одинаковых весах роутера MoE возвращает усреднённое по экспертам.""" + class DummyMoE(MoE): + def forward(self, x): + # Роутер отдаёт всегда единичные логиты = softmax -> uniform + self._router = torch.nn.Linear(x.size(-1), self._num_experts, bias=False) + torch.nn.init.constant_(self._router.weight, 0.0) + return super().forward(x) + moe = DummyMoE(emb_size=4, num_experts=2, top_k_experts=2) + x = torch.zeros(1, 2, 4) + y = moe(x) + assert y.shape == x.shape + +def test_forward_deterministic_seed(moe): + torch.manual_seed(42) + x = torch.randn(2, 3, 16) + y1 = moe(x) + torch.manual_seed(42) + y2 = moe(x) + assert torch.allclose(y1, y2, atol=1e-5) + +def test_forward_no_dropout(): + """Без dropout MoE не меняет shape и не даёт NaN.""" + moe = MoE(emb_size=5, num_experts=3, top_k_experts=2, dropout=0.0) + x = torch.randn(2, 7, 5) + y = moe(x) + assert y.shape == x.shape + assert not torch.isnan(y).any()