llm-arch-research/notebooks/mixstral.ipynb
Sergey Penkovsky b1737bbce2 feat(mixtral): initial implementation of Mixtral MoE model, configs, and tests
- Add Mixtral architecture implementation with MoE support (llm/src/llm/models/mixtral/mixtral.py)
- Introduce generic Mixture-of-Experts (MoE) block (llm/src/llm/core/moe.py)
- Create dedicated configuration files for Mixtral training and generation experiments
- Register and test Mixtral support in experiment runner (run_llm_experiment.py)
- Add unit tests for Mixtral API including forward, caching, and generation modes
- Include Jupyter notebook mixstral.ipynb for architectural exploration and research
- Ensure correct handling of torch bool masks in sampling (top-k, top-p) during generation

BREAKING CHANGE: Adds new model code and test coverage, modifying experiment runner logic to register Mixtral.
2025-10-20 08:12:11 +03:00


{
"cells": [
{
"cell_type": "markdown",
"id": "c3276440",
"metadata": {},
"source": [
"# **Mixtral**\n",
"\n",
"<p align=\"center\">\n",
" <img src=\"https://ucarecdn.com/b3eb5219-4c5a-4cc1-9572-b072d262c2c1/\" alt=\"Mixtral\" width=\"1000\" height=\"165\">\n",
"</p>\n",
"\n",
"Mixtral 8x7B вышел в декабре 2023.\n",
"\n",
"В нем сохранились все архитектурные особенности первого Mistral: RoPE, SwiGLU, GQA и SWA. И добавилась одна суперзначимая новинка — Mixture-of-Experts (MoE). Основное предназначение MoE — ускорение инференса для очень крупных моделей (с большим багажом знаний).\n",
"\n",
"Технически MoE меняет облик FeedForward слоя, добавляя в него экспертов — по сути, кучу тех же FeedForward сетей. Но эти эксперты работают не все сразу. Специальная подпрограмма — роутер — решает, кого из них и когда вызывать. Тем самым мы экономим на инференсе.\n",
"\n",
"В Mixtral 8x7B было 8 экспертов, но в единицу времени работали только 2 из них. Общий вес модели — 47B параметров, но из-за того, что задействовано только 2 эксперта, по скорости инференс получается как у примерно 13B моделей. А качество сравнимо с 70B моделями (например, Llama 2 70B).\n",
"\n",
"В дальнейшем эта новинка станет очень популярна в больших языковых моделях.\n"
]
},
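{
"cell_type": "markdown",
"id": "ad0c0001",
"metadata": {},
"source": [
"A rough back-of-the-envelope check of the \"47B total / ~13B active\" figures (assuming the published Mixtral 8x7B dimensions: 32 layers, hidden size 4096, expert FFN size 14336, 8 experts, top-2 routing):\n",
"\n",
"$$\n",
"\\text{params per expert} \\approx 3 \\cdot 4096 \\cdot 14336 \\approx 176\\text{M}\n",
"$$\n",
"\n",
"* All experts: $8 \\cdot 32 \\cdot 176\\text{M} \\approx 45\\text{B}$, plus attention and embeddings, which gives roughly 47B parameters stored in memory.\n",
"* Active per token: $2 \\cdot 32 \\cdot 176\\text{M} \\approx 11\\text{B}$, plus the shared non-expert weights, which is roughly the 13B figure quoted above.\n"
]
},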
{
"cell_type": "markdown",
"id": "e8dd2f09",
"metadata": {},
"source": [
"# Mixture-of-Experts\n",
"\n",
"<p align=\"center\">\n",
" <img src=\"https://ucarecdn.com/ce1ed721-579e-4b24-b9dc-1877051d246b/\" alt=\"MoE\" width=\"1000\" height=\"576\">\n",
"</p>\n",
"\n",
"Mixture-of-Experts (MoE) — это архитектура нейронных сетей, которая позволяет масштабировать модели до огромных размеров (триллионы параметров) без пропорционального увеличения вычислительных затрат. Достигается это за счет замены FeedForward слоя на слой MoE.\n",
"\n",
"Стандартная полносвязная нейросеть (Feed-Forward Network, FFN) в трансформере обрабатывает каждый входной токен одним и тем же набором весов. В MoE эта единственная сеть заменяется набором «‎экспертов» — небольших независимых нейросетей (обычно тех же FNN). И при проходе для каждого токена система динамически выбирает, каким экспертом или экспертами его обработать. В результате мы получаем два основных выигрыша:\n",
"\n",
"* Снижение нагрузки на систему и, как следствие, ускорение инференса.\n",
"* Повышение когнитивных способностей системы: потому что в каждый момент времени над токеном работают наиболее подходящие для этого эксперты.\n",
"\n",
"> Идея с MoE рано или поздно должна была появиться. На FNN приходится примерно 2/3 всех параметров современных LLM — это очень существенно. И причина этого известна: в FNN хранится основная часть знаний модели. И попытка сэкономить на таком большом количестве параметров является очевидным шагом.\n",
"\n",
"### Алгоритм\n",
"\n",
"MoE состоит из трех основных компонентов: маршрутизатора, экспертов и процедуры взвешивания. Работают они следующим образом:\n",
"\n",
"* На вход поступает тензор `x`.\n",
"* Маршрутизатор (Router или Gating Network) — это небольшая нейросеть. Часто просто один линейный слой. Функция маршрутизатора — решать, каким экспертам обработать `x`. Для этого по каждому токену маршрутизатор выдает: топ-k экспертов и их веса.\n",
"* Эксперты (Experts) — это несколько идентичных по структуре, но разных по весам нейронных сетей. Обычно в их роли выступают стандартные FFN-слои. Эксперты параллельны и не зависят друг от друга. Каждый токен из `x` пропускается только через назначенных ему экспертов — эксперты возвращают результат.\n",
"* Выходной тензор — это взвешенная сумма выходов всех экспертов. Веса для взвешивания мы получаем от маршрутизатора (по каждому токену и каждому эксперту).\n",
"\n",
"> Тут нужно сделать лирическое отступление: хотя сами вычисления происходят заметно быстрее (из-за того что в единицу времени отрабатывает только часть экспертов), но вам все равно потребуется разместить в памяти GPU абсолютно всех экспертов одновременно.\n",
"\n",
"### Экспертность\n",
"\n",
"Технически при обучении LLM эксперты в MoE никак принудительно не «специализируются». Т.е. все эксперты независимы и равны между собой. Но исследования показывают, что эксперты часто самоорганизуются и сами специализируются на определенных типах данных или задачах:\n",
"\n",
"* Эксперт 1 — хорошо работает с техническим текстом.\n",
"* Эксперт 2 — лучше справляется с художественным текстом.\n",
"* Эксперт 3 — специализируется на грамматике.\n",
"* Эксперт 4 — хорошо обрабатывает имена собственные.\n",
"* Эксперт 5 — лучше понимает азиатские языки и т.д.\n",
"\n",
"И встречая очередной токен, роутер, на основе его содержимого, решает, к какому эксперту его отправить.\n",
"\n",
"> При обучении роутер может неравномерно распределять внимание между экспертами. Может даже отправлять все данные к одному эксперту. Лечится это специальными балансными лоссами.\n"
]
},
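{
"cell_type": "markdown",
"id": "ad0c0002",
"metadata": {},
"source": [
"The balancing losses mentioned above are not used anywhere in this notebook. Below is a minimal sketch of a Switch-Transformer-style auxiliary loss (the function name `load_balancing_loss` and its exact form are an illustration, not the loss used to train Mixtral): it penalizes the product of \"fraction of tokens routed to an expert\" and \"mean router probability of that expert\", which is minimal when routing is uniform.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad0c0003",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"def load_balancing_loss(router_logits: torch.Tensor, top_k: int) -> torch.Tensor:\n",
"    # router_logits: [batch_size, seq_len, num_experts]\n",
"    num_experts = router_logits.size(-1)\n",
"\n",
"    # Mean routing probability assigned to each expert (averaged over all tokens)\n",
"    probs = F.softmax(router_logits, dim=-1)\n",
"    mean_prob = probs.mean(dim=(0, 1)) # [num_experts]\n",
"\n",
"    # Fraction of top-k assignments that actually go to each expert\n",
"    topk_idx = probs.topk(top_k, dim=-1).indices # [batch_size, seq_len, top_k]\n",
"    one_hot = F.one_hot(topk_idx, num_classes=num_experts).float()\n",
"    tokens_per_expert = one_hot.sum(dim=(0, 1, 2)) / one_hot.sum() # [num_experts]\n",
"\n",
"    # Equals 1.0 for perfectly uniform routing and grows as routing collapses onto few experts\n",
"    return num_experts * torch.sum(tokens_per_expert * mean_prob)\n",
"\n",
"# Example: random router logits for 2 sequences of 5 tokens and 8 experts\n",
"logits = torch.randn(2, 5, 8)\n",
"print(load_balancing_loss(logits, top_k=2))"
]
},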
{
"cell_type": "markdown",
"id": "8b5851df",
"metadata": {},
"source": [
"**MoE класс (разработка)**"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "982c3595",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"class SiLU(nn.Module):\n",
" def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]\n",
" return torch.sigmoid(x) * x\n",
" \n",
"class SwiGLU(nn.Module):\n",
" def __init__(self, emb_size: int, dropout: float = 0.1):\n",
" super().__init__()\n",
"\n",
" self._gate = nn.Linear(emb_size, 4 * emb_size)\n",
" self._up = nn.Linear(emb_size, 4 * emb_size)\n",
" self._down = nn.Linear(4 * emb_size, emb_size)\n",
" self._activation = SiLU()\n",
" self._dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size].\n",
" gate_out = self._gate(x) # [batch, seq, 4*emb]\n",
" activation_out = self._activation(gate_out) # [batch, seq, 4*emb]\n",
" up_out = self._up(x) # [batch, seq, 4*emb]\n",
" out = up_out * activation_out # поэлементное!\n",
" out = self._down(out) # [batch, seq, emb]\n",
" return self._dropout(out)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "34fac9f7",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"\n",
"class MoE(nn.Module):\n",
" def __init__(\n",
" self,\n",
" emb_size: int,\n",
" num_experts: int,\n",
" top_k_experts: int,\n",
" dropout: float = 0.1,\n",
" ):\n",
" super().__init__()\n",
" self._num_experts = num_experts\n",
" self._top_k_experts = top_k_experts\n",
"\n",
" self._router = nn.Linear(emb_size, num_experts)\n",
" self._experts = nn.ModuleList([SwiGLU(\n",
" emb_size=emb_size,\n",
" dropout=dropout,\n",
" ) for _ in range(num_experts)])\n",
" self._dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(self, x: torch.Tensor):\n",
" batch_size, seq_len, emb_size = x.shape\n",
" \n",
" # 1. Пропускаем через роутер\n",
" router_logits = self._router(x) # [batch_size, seq_len, num_experts]\n",
" \n",
" # 2. Отбираем топ-k экспертов для каждого токена\n",
" topk_logits, topk_indices = torch.topk(\n",
" router_logits, \n",
" k=self._top_k_experts, \n",
" dim=-1\n",
" ) # topk_logits: [batch_size, seq_len, top_k]\n",
" # topk_indices: [batch_size, seq_len, top_k]\n",
" \n",
" # 3. Получаем веса через softmax и нормируем\n",
" topk_weights = F.softmax(topk_logits, dim=-1) # [batch_size, seq_len, top_k]\n",
" \n",
" # 4. Создаём нулевой тензор для результата\n",
" output = torch.zeros_like(x) # [batch_size, seq_len, emb_size] \n",
"\n",
" # 5. Проходим по всем экспертам\n",
" for expert_id in range(self._num_experts):\n",
" # Шаг 1: Создаём маску - где находится текущий эксперт в топ-k\n",
" expert_mask = (topk_indices == expert_id) # [batch_size, seq_len, top_k]\n",
" # Шаг 2: Проверяем, выбран ли эксперт хотя бы одним токеном\n",
" if not expert_mask.any():\n",
" continue # Эксперт никем не выбран, переходим к следующему\n",
"\n",
" # Шаг 3: Находим токены, которые выбрали этого эксперта\n",
" # (хотя бы в одной из top_k позиций)\n",
" token_mask = expert_mask.any(dim=-1) # [batch_size, seq_len]\n",
"\n",
" # Шаг 4: Отбираем токены из x\n",
" # Отбираем токены для этого эксперта\n",
" expert_input = x[token_mask]\n",
"\n",
" # Пропускаем через эксперта\n",
" # Добавляем batch dimension для SwiGLU и затем убираем\n",
" expert_output = self._experts[expert_id](\n",
" expert_input.unsqueeze(0)\n",
" ).squeeze(0)\n",
"\n",
" # Получаем веса для этого эксперта\n",
" # Для каждого токена может быть несколько весов (если эксперт в топ-k несколько раз)\n",
" # Но на практике каждый эксперт появляется максимум 1 раз в топ-k\n",
" # Находим веса: где expert_mask == True, берём соответствующий вес\n",
" weights_for_expert = torch.zeros(\n",
" batch_size, seq_len, device=x.device\n",
" )\n",
"\n",
" # Для каждой позиции в топ-k\n",
" for k in range(self._top_k_experts):\n",
" mask_k = topk_indices[:, :, k] == expert_id\n",
" weights_for_expert[mask_k] = topk_weights[:, :, k][mask_k]\n",
"\n",
" # Отбираем только веса для выбранных токенов\n",
" selected_weights = weights_for_expert[token_mask] # [num_selected_tokens]\n",
"\n",
"\n",
" # Перемножьте выход эксперта на веса текущего эксперта.\n",
" weighted_output = selected_weights.unsqueeze(-1) * expert_output\n",
"\n",
" # Помещаем результат на своё место в выходном тензоре\n",
" output[token_mask] += weighted_output\n",
" \n",
" out = self._dropout(output)\n",
"\n",
" return out"
]
},
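{
"cell_type": "markdown",
"id": "ad0c0004",
"metadata": {},
"source": [
"A quick smoke test of the block above (the sizes are arbitrary toy values): the MoE layer must preserve the input shape, and the router must produce one logit per expert for every token.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad0c0005",
"metadata": {},
"outputs": [],
"source": [
"# Arbitrary toy sizes, for illustration only\n",
"moe = MoE(emb_size=32, num_experts=4, top_k_experts=2)\n",
"\n",
"x = torch.randn(2, 5, 32) # [batch_size, seq_len, emb_size]\n",
"y = moe(x)\n",
"\n",
"print(x.shape, \"->\", y.shape) # the output shape matches the input shape\n",
"print(moe._router(x).shape) # [2, 5, 4]: one routing logit per expert"
]
},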
{
"cell_type": "markdown",
"id": "d9135164",
"metadata": {},
"source": []
},
{
"cell_type": "markdown",
"id": "bef0c904",
"metadata": {},
"source": [
"## Full Model"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "35e52050",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch import Tensor\n",
"import torch.nn.functional as F\n",
"from math import sqrt\n",
"\n",
"\n",
"\n",
"class SiLU(nn.Module):\n",
" def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]\n",
" return torch.sigmoid(x) * x\n",
" \n",
"class RMSNorm(nn.Module):\n",
" def __init__(self, dim: int, eps: float = 1e-6):\n",
" super().__init__()\n",
" self._eps = eps\n",
" self._w = nn.Parameter(torch.ones(dim))\n",
" \n",
" def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]\n",
" rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5\n",
" norm_x = x / rms\n",
" return self._w * norm_x\n",
"\n",
"class SwiGLU(nn.Module):\n",
" def __init__(self, emb_size: int, dropout: float = 0.1):\n",
" super().__init__()\n",
"\n",
" self._gate = nn.Linear(emb_size, 4 * emb_size)\n",
" self._up = nn.Linear(emb_size, 4 * emb_size)\n",
" self._down = nn.Linear(4 * emb_size, emb_size)\n",
" self._activation = SiLU()\n",
" self._dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size].\n",
" gate_out = self._gate(x) # [batch, seq, 4*emb]\n",
" activation_out = self._activation(gate_out) # [batch, seq, 4*emb]\n",
" up_out = self._up(x) # [batch, seq, 4*emb]\n",
" out = up_out * activation_out # поэлементное!\n",
" out = self._down(out) # [batch, seq, emb]\n",
" return self._dropout(out)\n",
"\n",
"\n",
"class TokenEmbeddings(nn.Module):\n",
" def __init__(self, vocab_size: int, emb_size: int):\n",
" super().__init__()\n",
" self._embedding = nn.Embedding(\n",
" num_embeddings=vocab_size,\n",
" embedding_dim=emb_size\n",
" )\n",
"\n",
" def forward(self, x: Tensor) -> Tensor:\n",
" return self._embedding(x)\n",
"\n",
" @property\n",
" def num_embeddings(self) -> int:\n",
" return self._embedding.num_embeddings\n",
"\n",
" @property\n",
" def embedding_dim(self) -> int:\n",
" return self._embedding.embedding_dim\n",
"\n",
"\n",
"class GELU(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)\n",
" \n",
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
" return 0.5 * x * (1 + torch.tanh(\n",
" self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))\n",
" ))\n",
"\n",
" \n",
" \n",
"import torch\n",
"from torch import nn\n",
"from typing import Optional\n",
"\n",
"\n",
"class RoPE(nn.Module):\n",
"\n",
" def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):\n",
" super().__init__()\n",
" assert head_size % 2 == 0, \"head_size должен быть четным\"\n",
"\n",
" # Вычисление частот: θ_i = base^(-2i/d) для i ∈ [0, d/2-1]\n",
" freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size))\n",
"\n",
" # Позиции от 0 до max_seq_len-1\n",
" positions = torch.arange(max_seq_len).float()\n",
"\n",
" # Внешнее произведение: m * θ_i для всех позиций и частот\n",
" freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0)\n",
"\n",
" # Предвычисление матриц косинусов и синусов\n",
" self.register_buffer(\"cos_matrix\", torch.cos(freq_matrix))\n",
" self.register_buffer(\"sin_matrix\", torch.sin(freq_matrix))\n",
"\n",
" def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor: # [batch_size × seq_len × head_size] [batch_size × num_heads × seq_len × head_size]\n",
" batch_size, num_heads, seq_len, head_size = x.shape\n",
"\n",
" # Берем нужную часть матриц и приводим к типу x\n",
" cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]\n",
" sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]\n",
"\n",
" # Явное изменение формы для broadcasting\n",
" cos = cos.reshape(1, 1, seq_len, head_size // 2)\n",
" sin = sin.reshape(1, 1, seq_len, head_size // 2)\n",
"\n",
" # Разделяем на четные и нечетные компоненты по ПОСЛЕДНЕМУ измерению\n",
" x_even = x[..., 0::2] # [batch_size, num_heads, seq_len, head_size//2]\n",
" x_odd = x[..., 1::2] # [batch_size, num_heads, seq_len, head_size//2]\n",
"\n",
" # Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ)\n",
" x_rotated_even = x_even * cos - x_odd * sin\n",
" x_rotated_odd = x_even * sin + x_odd * cos\n",
"\n",
" # Объединяем обратно в исходную размерность\n",
" x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)\n",
" x_rotated = x_rotated.flatten(-2) # [batch_size, seq_len, head_size]\n",
"\n",
" return x_rotated\n",
"\n",
"\n",
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"from typing import Optional, Tuple\n",
"\n",
"\n",
" \n",
"class GroupedQueryAttention(nn.Module):\n",
"\n",
" def __init__(\n",
" self,\n",
" num_q_heads: int,\n",
" num_kv_heads: int,\n",
" emb_size: int,\n",
" head_size: int,\n",
" max_seq_len: int,\n",
" window_size: int,\n",
" rope: RoPE = None,\n",
" dropout: float = 0.1,\n",
" ):\n",
" super().__init__()\n",
" self._num_heads = num_q_heads\n",
" self._num_kv_heads = num_kv_heads\n",
" self._head_size = head_size\n",
" self._max_seq_len = max_seq_len\n",
" self._rope = rope\n",
" self._window_size = window_size\n",
"\n",
" self._q = nn.Linear(emb_size, self._num_heads * head_size)\n",
" self._k = nn.Linear(emb_size, num_kv_heads * head_size)\n",
" self._v = nn.Linear(emb_size, num_kv_heads * head_size)\n",
"\n",
" # Создание causal маски\n",
" mask = self._create_sliding_window_mask(max_seq_len, self._window_size)\n",
" self.register_buffer(\n",
" \"_tril_mask\", mask.bool() if hasattr(torch, \"bool\") else mask.byte()\n",
" )\n",
" \n",
" self._layer = nn.Linear(head_size * self._num_heads, emb_size)\n",
" self._dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(\n",
" self,\n",
" x: torch.Tensor,\n",
" mask: torch.Tensor = None,\n",
" use_cache: bool = True,\n",
" cache: list = None,\n",
" ):\n",
" batch_size, seq_len, emb_size = x.shape\n",
"\n",
" if seq_len > self._max_seq_len:\n",
" raise ValueError(\n",
" f\"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}\"\n",
" )\n",
"\n",
" # Пропустите тензор x через матрицы Wq, Wk , Wv, чтобы получить матрицы запроса, ключа и значения.\n",
" k = self._k(x) # [B, T, hs]\n",
" q = self._q(x) # [B, T, hs]\n",
" v = self._v(x) # [B, T, hs]\n",
"\n",
" # Шаг 2: Изменение формы для multi-head\n",
" # [batch_size, seq_len, num_heads * head_size] \n",
" # -> [batch_size, seq_len, num_heads, head_size]\n",
" # Измените форму запроса (query) на batch_size × num_q_heads × seq_len × head_size.\n",
" q = q.reshape(batch_size, seq_len, self._num_heads, self._head_size)\n",
"\n",
" # Измените форму ключа (key) и значения (value) на batch_size × num_kv_heads × seq_len × head_size.\n",
" k = k.reshape(batch_size, seq_len, self._num_kv_heads, self._head_size)\n",
" v = v.reshape(batch_size, seq_len, self._num_kv_heads, self._head_size)\n",
" \n",
"\n",
" # 3. Transpose: [B, T, H, hs] -> [B, H, T, hs]\n",
" q = q.transpose(1, 2)\n",
" k = k.transpose(1, 2)\n",
" v = v.transpose(1, 2)\n",
"\n",
" start_pos = 0\n",
" if cache is not None:\n",
" k_cache, v_cache = cache\n",
" cache_len = k_cache.shape[2]\n",
" start_pos = cache_len\n",
" \n",
" # Пропустите матрицы запроса и ключа через экземпляр rope, чтобы выполнить поворот.\n",
" if self._rope is not None:\n",
" # Применяем RoPE к Q и K (НЕ к V!)\n",
" q = self._rope(q, start_pos=start_pos) # [B, T, hs]\n",
" k = self._rope(k, start_pos=start_pos) # [B, T, hs]\n",
"\n",
" # Если cache пришел, то объединяем кэш и одну строку из ключа и значения. Это будут новые key и value для последующих вычислений.\n",
" # 5. Кэширование (для autoregressive generation)\n",
" if cache is not None:\n",
" k_cache, v_cache = cache\n",
" k = torch.cat([k_cache, k], dim=2) # Concat по seq_len (dim=2)\n",
" v = torch.cat([v_cache, v], dim=2)\n",
"\n",
" # Если use_cache == True, то сохраните матрицы ключа и значения для кэша (это нужно сделать до дублирования голов).\n",
" #if use_cache == True:\n",
" # # Обрезаем до последних window_size токенов\n",
" # k_to_cache = k[:, :, -self._window_size:, :]\n",
" # v_to_cache = v[:, :, -self._window_size:, :]\n",
" # kv_cache = (k_to_cache, v_to_cache)\n",
"\n",
" # Продублируйте головы в тензорах ключа (key) и значения (value), чтобы получился тензор размера на batch_size × num_q_heads × seq_len × head_size.\n",
" #k = self._repeat_kv_heads(k, self._num_heads, self._num_kv_heads)\n",
" #v = self._repeat_kv_heads(v, self._num_heads, self._num_kv_heads)\n",
" k_expanded = self._repeat_kv_heads(k, self._num_heads, self._num_kv_heads)\n",
" v_expanded = self._repeat_kv_heads(v, self._num_heads, self._num_kv_heads)\n",
" \n",
" # Перемножим матрицы запроса и ключа (транспонированную), чтобы вычислить матрицу внимания.\n",
" # И разделить все значения в матрице внимания на корень из head_size.\n",
" scores = q @ k_expanded.transpose(-2, -1) / (self._head_size ** 0.5)\n",
"\n",
" # 8. Применение маски\n",
" k_seq_len = k_expanded.size(2) # Длина K после concat с кэшем\n",
" \n",
" if cache is None:\n",
" # Случай 1: Без кэша - полная квадратная маска\n",
" # scores: [B, H, seq_len, seq_len]\n",
" # Применяем маску [:seq_len, :seq_len]\n",
" scores = scores.masked_fill(\n",
" ~self._tril_mask[:seq_len, :seq_len], \n",
" float(\"-inf\")\n",
" )\n",
"\n",
" # Применить к матрице внимания (построчно) функцию Softmax.\n",
" weights = F.softmax(scores, dim=-1)\n",
"\n",
" # Перемножим матрицу внимания и матрицу значения.\n",
" x_out = weights @ v_expanded # [B, T, hs]\n",
"\n",
" # Измените форму тензора на batch_size × seq_len × num_heads*head_size.\n",
" # Transpose обратно и concatenate heads\n",
" x_out = x_out.transpose(1, 2) # [B, T_q, H, hs]\n",
" x_out = x_out.contiguous() # Важно для reshape!\n",
" concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size)\n",
"\n",
" #concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size)\n",
"\n",
" # Пропустите получившийся тензор через последний линейный слой.\n",
" # 3. Проецируем в пространство эмбеддингов\n",
" projected_output = self._layer(concatenated_attention)\n",
"\n",
" # 4. Применяем dropout для регуляризации\n",
" output = self._dropout(projected_output)\n",
"\n",
" if use_cache:\n",
" # Обрезаем оригинальный K и V (до дублирования)\n",
" k_to_cache = k[:, :, -self._window_size:, :]\n",
" v_to_cache = v[:, :, -self._window_size:, :]\n",
" kv_cache = (k_to_cache, v_to_cache)\n",
" return output, kv_cache\n",
" else:\n",
" return output, None\n",
"\n",
" def _repeat_kv_heads(\n",
" self,\n",
" kv: torch.Tensor,\n",
" num_q_heads: int,\n",
" num_kv_heads: int\n",
" ) -> torch.Tensor:\n",
" \"\"\"\n",
" Дублирует головы K/V для соответствия количеству голов Q.\n",
"\n",
" Args:\n",
" kv: [batch_size, num_kv_heads, seq_len, head_size]\n",
" num_q_heads: Количество голов Query (например, 8)\n",
" num_kv_heads: Количество голов Key/Value (например, 2)\n",
"\n",
" Returns:\n",
" [batch_size, num_q_heads, seq_len, head_size]\n",
"\n",
" Example:\n",
" num_q_heads=8, num_kv_heads=2\n",
" Каждая голова KV дублируется 4 раза:\n",
" [KV0, KV1] -> [KV0, KV0, KV0, KV0, KV1, KV1, KV1, KV1]\n",
" \"\"\"\n",
" batch_size, num_kv_heads, seq_len, head_size = kv.shape\n",
"\n",
" if num_q_heads == num_kv_heads:\n",
" # Нет необходимости дублировать\n",
" return kv\n",
"\n",
" # Вычисляем сколько раз нужно повторить каждую голову\n",
" num_repeats = num_q_heads // num_kv_heads\n",
"\n",
" # repeat_interleave дублирует каждую голову num_repeats раз\n",
" # [B, num_kv_heads, S, hs] -> [B, num_q_heads, S, hs]\n",
" # [B, num_kv_heads, S, hs] -> [B, num_kv_heads, 1, S, hs]\n",
" kv = kv.unsqueeze(2)\n",
" \n",
" # [B, num_kv_heads, 1, S, hs] -> [B, num_kv_heads, num_repeats, S, hs]\n",
" kv = kv.repeat(1, 1, num_repeats, 1, 1)\n",
" \n",
" # [B, num_kv_heads, num_repeats, S, hs] -> [B, num_q_heads, S, hs]\n",
" kv = kv.reshape(batch_size, num_q_heads, seq_len, head_size)\n",
" \n",
"\n",
" return kv\n",
"\n",
" def _create_sliding_window_mask(\n",
" self,\n",
" max_seq_len: int,\n",
" window_size: int,\n",
" device: torch.device = None\n",
" ) -> torch.Tensor:\n",
" \"\"\"\n",
" Создает маску для Sliding Window Attention.\n",
"\n",
" Args:\n",
" max_seq_len: Максимальная длина последовательности\n",
" window_size: Размер окна внимания\n",
" device: Устройство для размещения тензора\n",
"\n",
" Returns:\n",
" Маска формы [max_seq_len, max_seq_len], где True = разрешено\n",
"\n",
" Example:\n",
" >>> mask = create_sliding_window_mask(8, 3)\n",
" >>> print(mask.int())\n",
" tensor([[1, 0, 0, 0, 0, 0, 0, 0],\n",
" [1, 1, 0, 0, 0, 0, 0, 0],\n",
" [1, 1, 1, 0, 0, 0, 0, 0],\n",
" [0, 1, 1, 1, 0, 0, 0, 0],\n",
" [0, 0, 1, 1, 1, 0, 0, 0],\n",
" [0, 0, 0, 1, 1, 1, 0, 0],\n",
" [0, 0, 0, 0, 1, 1, 1, 0],\n",
" [0, 0, 0, 0, 0, 1, 1, 1]])\n",
" \"\"\"\n",
" row_indices = torch.arange(max_seq_len, device=device).unsqueeze(1) # [max_seq_len, 1]\n",
" col_indices = torch.arange(max_seq_len, device=device).unsqueeze(0) # [1, max_seq_len]\n",
"\n",
" causal_mask = col_indices <= row_indices\n",
"\n",
" window_mask = (row_indices - col_indices) <= window_size\n",
"\n",
" mask = causal_mask & window_mask\n",
" \n",
" return mask\n",
"\n",
"class Decoder(nn.Module):\n",
" def __init__(self, \n",
" num_q_heads: int,\n",
" num_kv_heads: int,\n",
" emb_size: int,\n",
" head_size: int,\n",
" max_seq_len: int,\n",
" num_experts: int,\n",
" top_k_experts: int,\n",
" window_size: int,\n",
" rope: RoPE,\n",
" dropout: float = 0.1\n",
" ):\n",
" super().__init__()\n",
" self._heads = GroupedQueryAttention(\n",
" num_q_heads=num_q_heads, \n",
" num_kv_heads=num_kv_heads,\n",
" emb_size=emb_size, \n",
" head_size=head_size, \n",
" max_seq_len=max_seq_len,\n",
" window_size=window_size,\n",
" rope=rope,\n",
" dropout=dropout\n",
" )\n",
" self._ff = MoE(\n",
" emb_size=emb_size, \n",
" num_experts=num_experts,\n",
" top_k_experts=top_k_experts,\n",
" dropout=dropout\n",
" )\n",
" self._norm1 = RMSNorm(emb_size)\n",
" self._norm2 = RMSNorm(emb_size)\n",
"\n",
" def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:\n",
" norm1_out = self._norm1(x)\n",
" attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)\n",
" out = attention + x\n",
" \n",
" norm2_out = self._norm2(out)\n",
" ffn_out = self._ff(norm2_out)\n",
"\n",
" if use_cache is True:\n",
" return (ffn_out + out, kv_caches)\n",
" else:\n",
" return (ffn_out + out, None)\n",
"\n",
"\n",
"\n",
"from torch import nn\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"class Mixtral(nn.Module):\n",
" def __init__(self,\n",
" vocab_size: int,\n",
" max_seq_len: int,\n",
" emb_size: int,\n",
" num_q_heads: int,\n",
" num_kv_heads: int,\n",
" head_size: int,\n",
" num_layers: int,\n",
" num_experts: int,\n",
" top_k_experts: int,\n",
" window_size: int,\n",
" dropout: float = 0.1,\n",
" device: str = 'cpu'\n",
" ):\n",
" super().__init__()\n",
" self._vocab_size = vocab_size\n",
" self._max_seq_len = max_seq_len\n",
" self._emb_size = emb_size\n",
" self._num_q_heads = num_q_heads\n",
" self._num_kv_heads = num_kv_heads\n",
" self._head_size = head_size\n",
" self._num_layers = num_layers\n",
" self._dropout = dropout\n",
" self._device = device\n",
" \n",
" self.validation_loss = None\n",
"\n",
" # Инициализация слоев\n",
" self._token_embeddings = TokenEmbeddings(\n",
" vocab_size=vocab_size, \n",
" emb_size=emb_size\n",
" )\n",
" self._position_embeddings = RoPE(\n",
" head_size=head_size,\n",
" max_seq_len=max_seq_len\n",
" )\n",
" #self._position_embeddings = PositionalEmbeddings(\n",
" # max_seq_len=max_seq_len, \n",
" # emb_size=emb_size\n",
" #)\n",
" self._dropout = nn.Dropout(dropout)\n",
" self._decoders = nn.ModuleList([Decoder(\n",
" num_q_heads=num_q_heads,\n",
" num_kv_heads=num_kv_heads,\n",
" emb_size=emb_size,\n",
" head_size=head_size,\n",
" max_seq_len=max_seq_len,\n",
" num_experts=num_experts,\n",
" top_k_experts=top_k_experts,\n",
" window_size=window_size,\n",
" rope=self._position_embeddings,\n",
" dropout=dropout \n",
" ) for _ in range(num_layers)])\n",
" self._norm = RMSNorm(emb_size)\n",
" self._linear = nn.Linear(emb_size, vocab_size)\n",
"\n",
" def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:\n",
" # Проверка длины последовательности (только при отсутствии кэша)\n",
" if cache is None and x.size(1) > self._max_seq_len:\n",
" raise ValueError(f\"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}\")\n",
" \n",
" # Эмбеддинги токенов и позиций\n",
" tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]\n",
" #pos_out = self._position_embeddings(x) # [batch, seq_len, emb_size]\n",
" \n",
" # Комбинирование\n",
" out = self._dropout(tok_out) # [batch, seq_len, emb_size]\n",
" \n",
" # Стек декодеров с передачей кэша\n",
" new_cache = []\n",
" for i, decoder in enumerate(self._decoders):\n",
" decoder_cache = cache[i] if cache is not None else None\n",
" decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)\n",
"\n",
" # Извлекаем результат из кортежа\n",
" if use_cache:\n",
" out, decoder_new_cache = decoder_result\n",
" new_cache.append(decoder_new_cache)\n",
" else:\n",
" out = decoder_result[0]\n",
"\n",
" out = self._norm(out)\n",
" logits = self._linear(out)\n",
" \n",
" # Возвращаем результат с учетом use_cache\n",
" if use_cache:\n",
" return (logits, new_cache)\n",
" else:\n",
" return (logits, None)\n",
"\n",
" def generate(self,\n",
" x: torch.Tensor, \n",
" max_new_tokens: int, \n",
" do_sample: bool,\n",
" temperature: float = 1.0,\n",
" top_k: int = None,\n",
" top_p: float = None,\n",
" use_cache: bool = True\n",
" ) -> torch.Tensor:\n",
" cache = None\n",
"\n",
" for _ in range(max_new_tokens):\n",
" if use_cache and cache is not None:\n",
" # Используем кэш - передаем только последний токен\n",
" x_input = x[:, -1:] # [batch_size, 1]\n",
" else:\n",
" # Первая итерация или кэш отключен - передаем всю последовательность\n",
" x_input = x\n",
" \n",
" # Прямой проход с кэшем\n",
" logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)\n",
" \n",
" # Обновляем кэш для следующей итерации\n",
" if use_cache:\n",
" cache = new_cache\n",
"\n",
" last_logits = logits[:, -1, :] # [batch_size, vocab_size]\n",
"\n",
" # Масштабируем логиты температурой\n",
" if temperature > 0:\n",
" logits_scaled = last_logits / temperature\n",
" else:\n",
" logits_scaled = last_logits\n",
"\n",
" if do_sample == True and top_k != None:\n",
" _, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)\n",
"\n",
" # # Заменим все НЕ top-k логиты на -inf\n",
" masked_logits = logits_scaled.clone()\n",
" vocab_size = logits_scaled.size(-1)\n",
"\n",
" # создаём маску: 1, если токен НЕ в topk_indices\n",
" mask = torch.ones_like(logits_scaled, dtype=torch.uint8)\n",
" mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы\n",
" masked_logits[mask.byte()] = float('-inf')\n",
"\n",
" logits_scaled = masked_logits\n",
"\n",
" if do_sample == True and top_p != None:\n",
" # 1. Применим softmax, чтобы получить вероятности:\n",
" probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]\n",
" # 2. Отсортируем токены по убыванию вероятностей:\n",
" sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)\n",
" # 3. Посчитаем кумулятивную сумму вероятностей:\n",
" cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]\n",
" # 4. Определим маску: оставить токены, пока сумма < top_p\n",
" sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size]\n",
" # Гарантируем, что хотя бы первый токен останется\n",
" sorted_mask[:, 0] = 1\n",
" # 5. Преобразуем маску обратно в оригинальный порядок:\n",
" # Создаём полную маску из 0\n",
" mask = torch.zeros_like(probs, dtype=torch.uint8)\n",
" # Устанавливаем 1 в местах нужных токенов\n",
" mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)\n",
" # 6. Зануляем логиты токенов вне топ-p:\n",
" logits_scaled[~mask] = float('-inf')\n",
"\n",
" # 4. Применяем Softmax\n",
" probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]\n",
"\n",
"\n",
" if do_sample == True:\n",
" # 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial\n",
" next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]\n",
" else:\n",
" # 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью\n",
" next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]\n",
" \n",
" # 6. Добавляем его к последовательности\n",
" x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]\n",
" return x\n",
"\n",
" def save(self, path):\n",
" torch.save({\n",
" 'model_state_dict': self.state_dict(),\n",
" 'vocab_size': self._vocab_size,\n",
" 'max_seq_len': self._max_seq_len,\n",
" 'emb_size': self._emb_size,\n",
" 'num_heads': self._num_heads,\n",
" 'head_size': self._head_size,\n",
" 'num_layers': self._num_layers\n",
" }, path)\n",
"\n",
" @classmethod\n",
" def load(cls, path, device):\n",
" checkpoint = torch.load(path, map_location=device)\n",
" model = cls(\n",
" vocab_size=checkpoint['vocab_size'],\n",
" max_seq_len=checkpoint['max_seq_len'],\n",
" emb_size=checkpoint['emb_size'],\n",
" num_heads=checkpoint['num_heads'],\n",
" head_size=checkpoint['head_size'],\n",
" num_layers=checkpoint['num_layers']\n",
" )\n",
" model.load_state_dict(checkpoint['model_state_dict'])\n",
" model.to(device)\n",
" return model\n",
"\n",
" @property\n",
" def max_seq_len(self) -> int:\n",
" return self._max_seq_len"
]
},
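{
"cell_type": "markdown",
"id": "ad0c0006",
"metadata": {},
"source": [
"A small sanity check of the MoE efficiency argument on a toy version of the model above (a sketch that relies on the private attribute names defined in this notebook; the sizes are arbitrary): total parameters that must sit in memory versus the parameters actually used for a single token when only `top_k_experts` experts run per MoE layer.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad0c0007",
"metadata": {},
"outputs": [],
"source": [
"# Count the parameters of a small toy Mixtral (arbitrary sizes, for illustration only)\n",
"toy = Mixtral(\n",
"    vocab_size=100, max_seq_len=64, emb_size=64,\n",
"    num_q_heads=4, num_kv_heads=2, head_size=16,\n",
"    num_layers=2, num_experts=8, top_k_experts=2, window_size=16\n",
")\n",
"\n",
"total = sum(p.numel() for p in toy.parameters())\n",
"\n",
"# All experts have the same shape, so it is enough to count one of them\n",
"expert_params = sum(p.numel() for p in toy._decoders[0]._ff._experts[0].parameters())\n",
"\n",
"num_layers = len(toy._decoders)\n",
"num_experts = len(toy._decoders[0]._ff._experts)\n",
"top_k = toy._decoders[0]._ff._top_k_experts\n",
"\n",
"# Per token only top_k experts run in each MoE layer; the rest only occupy memory\n",
"inactive = (num_experts - top_k) * expert_params * num_layers\n",
"active = total - inactive\n",
"\n",
"print(f\"total parameters: {total:,}\")\n",
"print(f\"active per token: {active:,}\")"
]
},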
{
"cell_type": "markdown",
"id": "7f4b3b1e",
"metadata": {},
"source": [
"## 2. Обучение Mixtral\n",
"\n",
"Mixtral обучается в два этапа:\n",
"\n",
"- 1⃣ **Предобучение (Unsupervised Pretraining)** \n",
"- 2⃣ **Дообучение (Supervised Fine-Tuning)**"
]
},
{
"cell_type": "markdown",
"id": "cb0a9172",
"metadata": {},
"source": [
"\n",
"\n",
"### 5.1 Предобучение\n",
"\n",
"На первом этапе модель обучается без разметки: она получает большой корпус текстов и учится **предсказывать следующий токен** по предыдущим.\n",
"\n",
"Функция потерь:\n",
"$$\n",
"L = - \\sum_{t=1}^{T} \\log P(x_t | x_1, x_2, ..., x_{t-1})\n",
"$$\n",
"\n",
"Таким образом, модель учится строить вероятностную модель языка, \"угадывая\" продолжение текста.\n"
]
},
{
"cell_type": "markdown",
"id": "b4bdbb31",
"metadata": {},
"source": [
"Во время **предобучения** Mistral учится **предсказывать следующий токен** (language modeling task). \n",
"Формально: \n",
"$$ \n",
"P(x_t ,|, x_1, x_2, \\dots, x_{t-1}) \n",
"$$ \n",
"То есть, если на вход подаётся предложение `\"I love deep\"`, модель должна предсказать `\"learning\"`.\n"
]
},
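{
"cell_type": "markdown",
"id": "ad0c0008",
"metadata": {},
"source": [
"A tiny numerical illustration of the loss above (the probabilities are made up for the example): the loss is the average negative log-probability that the model assigned to each correct next token.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad0c0009",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"# Illustrative probabilities assigned by some model to the CORRECT next token at steps 1..3\n",
"p_correct = torch.tensor([0.5, 0.25, 0.125])\n",
"\n",
"# L = -sum_t log P(x_t | x_<t); cross_entropy reports the mean over the steps\n",
"print(-p_correct.log().mean()) # ~1.386\n",
"\n",
"# The same value via F.cross_entropy: each row of exp(logits) already sums to 1,\n",
"# so log_softmax(logits) == logits and the targets pick out 0.5, 0.25 and 0.125\n",
"logits = torch.log(torch.tensor([\n",
"    [0.5, 0.25, 0.125, 0.125],\n",
"    [0.25, 0.5, 0.125, 0.125],\n",
"    [0.125, 0.125, 0.625, 0.125],\n",
"]))\n",
"targets = torch.tensor([0, 0, 3])\n",
"print(F.cross_entropy(logits, targets)) # ~1.386"
]
},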
{
"cell_type": "markdown",
"id": "379a5c21",
"metadata": {},
"source": [
"### ✅ 5.1.1 Подготовка данных\n",
"\n",
"Создадим **датасет** на основе BPE-токенизатора:"
]
},
{
"cell_type": "markdown",
"id": "8141ac57",
"metadata": {},
"source": [
"**BPE Tokenizator**"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "3fc41e15",
"metadata": {},
"outputs": [],
"source": [
"class BPE:\n",
" def __init__(self, vocab_size: int):\n",
" self.vocab_size = vocab_size\n",
" self.id2token = {}\n",
" self.token2id = {}\n",
"\n",
" def fit(self, text: str):\n",
" # 1. Получаем уникальные токены (символы)\n",
" unique_tokens = sorted(set(text))\n",
" tokens = unique_tokens.copy()\n",
"\n",
" # 2. Разбиваем текст на токены-символы\n",
" sequence = list(text)\n",
"\n",
" # 3. Объединяем токены до достижения нужного размера словаря\n",
" while len(tokens) < self.vocab_size:\n",
" #print(f'len={len(tokens)} < {self.vocab_size}')\n",
" # Считаем частоты пар\n",
" pair_freq = {}\n",
" for i in range(len(sequence) - 1):\n",
" pair = (sequence[i], sequence[i + 1])\n",
" #print(f'pair = {pair}')\n",
" if pair not in pair_freq:\n",
" pair_freq[pair] = 0\n",
" pair_freq[pair] += 1\n",
"\n",
"\n",
" #print(f'pair_freq = {pair_freq}') \n",
" if not pair_freq:\n",
" break # нет пар — выходим\n",
"\n",
" #for x in pair_freq.items():\n",
" # self.debug(x, sequence)\n",
"\n",
" # Находим самую частую пару (в случае равенства — та, что встретилась первой)\n",
" most_frequent_pair = max(pair_freq.items(), key=lambda x: (x[1], -self._pair_first_index(sequence, x[0])))[0]\n",
" #print(most_frequent_pair)\n",
" # Создаем новый токен\n",
" new_token = most_frequent_pair[0] + most_frequent_pair[1]\n",
" #print(f\"new token={new_token}\")\n",
" tokens.append(new_token)\n",
" #print(f\"tokens={tokens}\")\n",
"\n",
" i = 0\n",
" new_sequence = []\n",
"\n",
" while i < len(sequence):\n",
" if i < len(sequence) - 1 and (sequence[i], sequence[i + 1]) == most_frequent_pair:\n",
" new_sequence.append(new_token)\n",
" i += 2 # пропускаем два символа — заменённую пару\n",
" else:\n",
" new_sequence.append(sequence[i])\n",
" i += 1\n",
" sequence = new_sequence\n",
" #break\n",
" \n",
" # 4. Создаем словари\n",
" self.vocab = tokens.copy()\n",
" self.token2id = dict(zip(tokens, range(self.vocab_size)))\n",
" self.id2token = dict(zip(range(self.vocab_size), tokens))\n",
"\n",
" def _pair_first_index(self, sequence, pair):\n",
" for i in range(len(sequence) - 1):\n",
" if (sequence[i], sequence[i + 1]) == pair:\n",
" return i\n",
" return float('inf') # если пара не найдена (в теории не должно случиться)\n",
"\n",
"\n",
" def encode(self, text: str):\n",
" # 1. Разбиваем текст на токены-символы\n",
" sequence = list(text)\n",
" # 2. Инициализация пустого списка токенов\n",
" tokens = []\n",
" # 3. Установить i = 0\n",
" i = 0\n",
" while i < len(text):\n",
" # 3.1 Найти все токены в словаре, начинающиеся с text[i]\n",
" start_char = text[i]\n",
" result = [token for token in self.vocab if token.startswith(start_char)]\n",
" # 3.2 Выбрать самый длинный подходящий токен\n",
" find_token = self._find_max_matching_token(text[i:], result)\n",
" if find_token is None:\n",
" # Обработка неизвестного символа\n",
" tokens.append(text[i]) # Добавляем сам символ как токен\n",
" i += 1\n",
" else:\n",
" # 3.3 Добавить токен в результат\n",
" tokens.append(find_token)\n",
" # 3.4 Увеличить i на длину токена\n",
" i += len(find_token)\n",
"\n",
" # 4. Заменить токены на их ID\n",
" return self._tokens_to_ids(tokens)\n",
"\n",
" def _find_max_matching_token(self, text: str, tokens: list):\n",
" \"\"\"Находит самый длинный токен из списка, с которого начинается текст\"\"\"\n",
" matching = [token for token in tokens if text.startswith(token)]\n",
" return max(matching, key=len) if matching else None\n",
"\n",
" def _tokens_to_ids(self, tokens):\n",
" \"\"\"Конвертирует список токенов в их ID с обработкой неизвестных токенов\"\"\"\n",
" ids = []\n",
" for token in tokens:\n",
" if token in self.token2id:\n",
" ids.append(self.token2id[token])\n",
" else:\n",
" ids.append(0) # Специальное значение\n",
" return ids\n",
"\n",
"\n",
" def decode(self, ids: list) -> str:\n",
" return ''.join(self._ids_to_tokens(ids))\n",
"\n",
" def _ids_to_tokens(self, ids: list) -> list:\n",
" \"\"\"Конвертирует список Ids в их tokens\"\"\"\n",
" tokens = []\n",
" for id in ids:\n",
" if id in self.id2token:\n",
" tokens.append(self.id2token[id])\n",
" else:\n",
" tokens.append('') # Специальное значение\n",
" return tokens\n",
"\n",
"\n",
" def save(self, filename):\n",
" with open(filename, 'wb') as f:\n",
" dill.dump(self, f)\n",
" print(f\"Объект сохранён в {filename}\")\n",
"\n",
"\n",
" @classmethod\n",
" def load(cls, filename):\n",
" with open(filename, 'rb') as f:\n",
" obj = dill.load(f)\n",
" \n",
" print(f\"Объект загружен из {filename}\")\n",
" return obj"
]
},
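{
"cell_type": "markdown",
"id": "ad0c0010",
"metadata": {},
"source": [
"A quick round-trip check of the tokenizer above (a toy example; the sentence and vocabulary size are arbitrary): after `fit`, `decode(encode(text))` should reproduce the original string as long as all of its characters were seen during fitting.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad0c0011",
"metadata": {},
"outputs": [],
"source": [
"bpe_demo = BPE(vocab_size=40)\n",
"bpe_demo.fit(\"attention is all you need\")\n",
"\n",
"ids = bpe_demo.encode(\"attention is all\")\n",
"print(ids)\n",
"print(bpe_demo.decode(ids)) # -> \"attention is all\""
]
},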
{
"cell_type": "code",
"execution_count": 12,
"id": "dc7a3dbf",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"class GPTDataset(Dataset):\n",
" def __init__(self, text: str, bpe: BPE, block_size: int):\n",
" self.bpe = bpe\n",
" self.block_size = block_size\n",
" self.data = bpe.encode(text)\n",
" \n",
" def __len__(self):\n",
" return len(self.data) - self.block_size\n",
"\n",
" def __getitem__(self, idx):\n",
" x = torch.tensor(self.data[idx:idx+self.block_size], dtype=torch.long)\n",
" y = torch.tensor(self.data[idx+1:idx+self.block_size+1], dtype=torch.long)\n",
" return x, y"
]
},
{
"cell_type": "markdown",
"id": "e8981229",
"metadata": {},
"source": [
"### ✅ 5.1.2 Цикл обучения\n",
"\n",
"Для обучения создадим функцию:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "4e097434",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"from torch import optim\n",
"\n",
"def train_mixtral(model, dataset, epochs=5, batch_size=32, lr=3e-4, device='cpu'):\n",
" dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
" optimizer = optim.AdamW(model.parameters(), lr=lr)\n",
"\n",
" model.to(device)\n",
" model.train()\n",
"\n",
" for epoch in range(epochs):\n",
" total_loss = 0\n",
" for x, y in dataloader:\n",
" x, y = x.to(device), y.to(device)\n",
"\n",
" # Прямой проход\n",
" logits, _ = model(x, use_cache=False) # [B, T, vocab_size]\n",
"\n",
" # Перестроим выход под CrossEntropy\n",
" loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))\n",
"\n",
" # Обратное распространение\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" total_loss += loss.item()\n",
"\n",
" avg_loss = total_loss / len(dataloader)\n",
" print(f\"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}\")\n",
"\n",
" return model"
]
},
{
"cell_type": "markdown",
"id": "1f536de2",
"metadata": {},
"source": [
"### ✅ 5.1.3 Пример запуска\n",
"\n",
"\n",
"**🧠 Конфигурация Mistral Mini**\n",
"\n",
"\n",
"| Параметр | Значение | Описание |\n",
"| --------------- | -------- | --------------------------------------------- |\n",
"| **vocab_size** | `50257` | Размер словаря (BPE токенизатор OpenAI) |\n",
"| **max_seq_len** | `512` | Максимальная длина входной последовательности |\n",
"| **emb_size** | `256` | Размер эмбеддингов (векторное пространство) |\n",
"| **num_heads** | `4` | Количество голов в multi-head attention |\n",
"| **head_size** | `64` | Размерность одной головы внимания (768 / 12) |\n",
"| **num_layers** | `4` | Количество блоков (декодеров) |\n",
"| **dropout** | `0.1` | Вероятность дропаута |\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "83ed4be7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset length: 20\n",
"Epoch 1/100, Loss: 3.5861\n",
"Epoch 2/100, Loss: 1.1913\n",
"Epoch 3/100, Loss: 0.4864\n",
"Epoch 4/100, Loss: 0.2490\n",
"Epoch 5/100, Loss: 0.1993\n",
"Epoch 6/100, Loss: 0.1350\n",
"Epoch 7/100, Loss: 0.1039\n",
"Epoch 8/100, Loss: 0.1051\n",
"Epoch 9/100, Loss: 0.0730\n",
"Epoch 10/100, Loss: 0.0754\n",
"Epoch 11/100, Loss: 0.0819\n",
"Epoch 12/100, Loss: 0.0664\n",
"Epoch 13/100, Loss: 0.0793\n",
"Epoch 14/100, Loss: 0.0668\n",
"Epoch 15/100, Loss: 0.0818\n",
"Epoch 16/100, Loss: 0.0734\n",
"Epoch 17/100, Loss: 0.0637\n",
"Epoch 18/100, Loss: 0.0584\n",
"Epoch 19/100, Loss: 0.0762\n",
"Epoch 20/100, Loss: 0.0683\n",
"Epoch 21/100, Loss: 0.0624\n",
"Epoch 22/100, Loss: 0.0557\n",
"Epoch 23/100, Loss: 0.0579\n",
"Epoch 24/100, Loss: 0.0558\n",
"Epoch 25/100, Loss: 0.0578\n",
"Epoch 26/100, Loss: 0.0520\n",
"Epoch 27/100, Loss: 0.0642\n",
"Epoch 28/100, Loss: 0.0660\n",
"Epoch 29/100, Loss: 0.0508\n",
"Epoch 30/100, Loss: 0.0542\n",
"Epoch 31/100, Loss: 0.0478\n",
"Epoch 32/100, Loss: 0.0597\n",
"Epoch 33/100, Loss: 0.0578\n",
"Epoch 34/100, Loss: 0.0557\n",
"Epoch 35/100, Loss: 0.0528\n",
"Epoch 36/100, Loss: 0.0521\n",
"Epoch 37/100, Loss: 0.0520\n",
"Epoch 38/100, Loss: 0.0512\n",
"Epoch 39/100, Loss: 0.0501\n",
"Epoch 40/100, Loss: 0.0497\n",
"Epoch 41/100, Loss: 0.0454\n",
"Epoch 42/100, Loss: 0.0519\n",
"Epoch 43/100, Loss: 0.0535\n",
"Epoch 44/100, Loss: 0.0464\n",
"Epoch 45/100, Loss: 0.0459\n",
"Epoch 46/100, Loss: 0.0437\n",
"Epoch 47/100, Loss: 0.0585\n",
"Epoch 48/100, Loss: 0.0469\n",
"Epoch 49/100, Loss: 0.0538\n",
"Epoch 50/100, Loss: 0.0592\n",
"Epoch 51/100, Loss: 0.0520\n",
"Epoch 52/100, Loss: 0.0582\n",
"Epoch 53/100, Loss: 0.0504\n",
"Epoch 54/100, Loss: 0.0471\n",
"Epoch 55/100, Loss: 0.0478\n",
"Epoch 56/100, Loss: 0.0487\n",
"Epoch 57/100, Loss: 0.0507\n",
"Epoch 58/100, Loss: 0.0500\n",
"Epoch 59/100, Loss: 0.0457\n",
"Epoch 60/100, Loss: 0.0493\n",
"Epoch 61/100, Loss: 0.0431\n",
"Epoch 62/100, Loss: 0.0503\n",
"Epoch 63/100, Loss: 0.0436\n",
"Epoch 64/100, Loss: 0.0512\n",
"Epoch 65/100, Loss: 0.0488\n",
"Epoch 66/100, Loss: 0.0436\n",
"Epoch 67/100, Loss: 0.0505\n",
"Epoch 68/100, Loss: 0.0389\n",
"Epoch 69/100, Loss: 0.0447\n",
"Epoch 70/100, Loss: 0.0442\n",
"Epoch 71/100, Loss: 0.0443\n",
"Epoch 72/100, Loss: 0.0460\n",
"Epoch 73/100, Loss: 0.0483\n",
"Epoch 74/100, Loss: 0.0538\n",
"Epoch 75/100, Loss: 0.0437\n",
"Epoch 76/100, Loss: 0.0500\n",
"Epoch 77/100, Loss: 0.0443\n",
"Epoch 78/100, Loss: 0.0461\n",
"Epoch 79/100, Loss: 0.0467\n",
"Epoch 80/100, Loss: 0.0472\n",
"Epoch 81/100, Loss: 0.0507\n",
"Epoch 82/100, Loss: 0.0454\n",
"Epoch 83/100, Loss: 0.0414\n",
"Epoch 84/100, Loss: 0.0468\n",
"Epoch 85/100, Loss: 0.0525\n",
"Epoch 86/100, Loss: 0.0456\n",
"Epoch 87/100, Loss: 0.0512\n",
"Epoch 88/100, Loss: 0.0487\n",
"Epoch 89/100, Loss: 0.0452\n",
"Epoch 90/100, Loss: 0.0436\n",
"Epoch 91/100, Loss: 0.0488\n",
"Epoch 92/100, Loss: 0.0471\n",
"Epoch 93/100, Loss: 0.0442\n",
"Epoch 94/100, Loss: 0.0435\n",
"Epoch 95/100, Loss: 0.0502\n",
"Epoch 96/100, Loss: 0.0430\n",
"Epoch 97/100, Loss: 0.0430\n",
"Epoch 98/100, Loss: 0.0424\n",
"Epoch 99/100, Loss: 0.0421\n",
"Epoch 100/100, Loss: 0.0410\n"
]
},
{
"data": {
"text/plain": [
"Mixtral(\n",
" (_token_embeddings): TokenEmbeddings(\n",
" (_embedding): Embedding(100, 256)\n",
" )\n",
" (_position_embeddings): RoPE()\n",
" (_dropout): Dropout(p=0.1, inplace=False)\n",
" (_decoders): ModuleList(\n",
" (0-3): 4 x Decoder(\n",
" (_heads): GroupedQueryAttention(\n",
" (_rope): RoPE()\n",
" (_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (_k): Linear(in_features=256, out_features=128, bias=True)\n",
" (_v): Linear(in_features=256, out_features=128, bias=True)\n",
" (_layer): Linear(in_features=256, out_features=256, bias=True)\n",
" (_dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (_ff): MoE(\n",
" (_router): Linear(in_features=256, out_features=8, bias=True)\n",
" (_experts): ModuleList(\n",
" (0-7): 8 x SwiGLU(\n",
" (_gate): Linear(in_features=256, out_features=1024, bias=True)\n",
" (_up): Linear(in_features=256, out_features=1024, bias=True)\n",
" (_down): Linear(in_features=1024, out_features=256, bias=True)\n",
" (_activation): SiLU()\n",
" (_dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (_dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (_norm1): RMSNorm()\n",
" (_norm2): RMSNorm()\n",
" )\n",
" )\n",
" (_norm): RMSNorm()\n",
" (_linear): Linear(in_features=256, out_features=100, bias=True)\n",
")"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 1. Исходный текст\n",
"text = \"Deep learning is amazing. Transformers changed the world. Attention is all you need. GPT models revolutionized NLP.\"\n",
"\n",
"# 2. Обучаем токенизатор\n",
"bpe = BPE(vocab_size=100)\n",
"bpe.fit(text)\n",
"\n",
"# 3. Создаем датасет\n",
"dataset = GPTDataset(text, bpe, block_size=8)\n",
"print(f\"Dataset length: {len(dataset)}\")\n",
"\n",
"# 4. Инициализируем модель\n",
"model = Mixtral(\n",
" vocab_size=len(bpe.vocab), # размер словаря BPE\n",
" max_seq_len=512, # GPT-2 использует контекст в 512 токена\n",
" emb_size=256, # размер эмбеддингов\n",
" num_q_heads=4, # количество голов внимания\n",
" num_kv_heads=2, # количество голов внимания\n",
" head_size=64, # размер каждой головы (256 / 4)\n",
" num_layers=4, # количество блоков Transformer\n",
" num_experts=8,\n",
" top_k_experts=2,\n",
" window_size=8,\n",
" dropout=0.1 # стандартный dropout GPT-2\n",
")\n",
"\n",
"# 5. Обучаем\n",
"train_mixtral(model, dataset, epochs=100, batch_size=4)"
]
},
{
"cell_type": "markdown",
"id": "72e91ceb",
"metadata": {},
"source": [
"\n",
"---\n",
"\n",
"### 5.2 Дообучение\n",
"\n",
"После предобучения LLAMA уже знает структуру и грамматику языка. \n",
"На втором этапе она дообучается на конкретных задачах (например, классификация, QA) с помощью размеченных данных.\n",
"\n",
"Технически это почти то же обучение, только:\n",
"\n",
"- Загружаем модель с уже обученными весами.\n",
"- Используем новые данные.\n",
"- Можно уменьшить скорость обучения.\n",
"- Иногда замораживают часть слоёв (например, эмбеддинги).\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "0a6f9730",
"metadata": {},
"outputs": [],
"source": [
"def fine_tune_mixtral(model, dataset, epochs=3, batch_size=16, lr=1e-5, device='cpu', freeze_embeddings=True):\n",
" if freeze_embeddings:\n",
" for param in model._token_embeddings.parameters():\n",
" param.requires_grad = False\n",
" for param in model._position_embeddings.parameters():\n",
" param.requires_grad = False\n",
"\n",
" dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
" optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)\n",
"\n",
" model.to(device)\n",
" model.train()\n",
"\n",
" for epoch in range(epochs):\n",
" total_loss = 0\n",
" for x, y in dataloader:\n",
" x, y = x.to(device), y.to(device)\n",
" logits, _ = model(x, use_cache=False)\n",
" loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" total_loss += loss.item()\n",
" print(f\"Fine-tune Epoch {epoch+1}/{epochs}, Loss: {total_loss / len(dataloader):.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "c9e28f47",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fine-tune Epoch 1/10, Loss: 4.8596\n",
"Fine-tune Epoch 2/10, Loss: 2.6883\n",
"Fine-tune Epoch 3/10, Loss: 1.5315\n",
"Fine-tune Epoch 4/10, Loss: 1.1258\n",
"Fine-tune Epoch 5/10, Loss: 0.9248\n",
"Fine-tune Epoch 6/10, Loss: 0.7725\n",
"Fine-tune Epoch 7/10, Loss: 0.6405\n",
"Fine-tune Epoch 8/10, Loss: 0.5367\n",
"Fine-tune Epoch 9/10, Loss: 0.4469\n",
"Fine-tune Epoch 10/10, Loss: 0.4168\n"
]
}
],
"source": [
"# Например, мы хотим дообучить модель на стиле коротких технических фраз\n",
"fine_tune_text = \"\"\"\n",
"Transformers revolutionize NLP.\n",
"Deep learning enables self-attention.\n",
"GPT generates text autoregressively.\n",
"\"\"\"\n",
"\n",
"dataset = GPTDataset(fine_tune_text, bpe, block_size=8)\n",
"\n",
"\n",
"# Запуск дообучения\n",
"fine_tune_mixtral(model, dataset, epochs=10, batch_size=4, lr=1e-4)"
]
},
{
"cell_type": "markdown",
"id": "fd470fdd",
"metadata": {},
"source": [
"## 📝 6. Генерация текста после обучения"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "245dd064",
"metadata": {},
"outputs": [],
"source": [
"def generate_text(model, bpe, prompt: str, max_new_tokens=20, device='cpu'):\n",
" model.eval()\n",
" ids = torch.tensor([bpe.encode(prompt)], dtype=torch.long).to(device)\n",
" out = model.generate(ids, max_new_tokens=max_new_tokens, do_sample=True)\n",
" text = bpe.decode(out[0].tolist())\n",
" return text"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "1f5db85f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Deep learning ena les self atelf a\n"
]
}
],
"source": [
"print(generate_text(model, bpe, \"Deep learning\", max_new_tokens=20))"
]
},
{
"cell_type": "markdown",
"id": "094ab394",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}