mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 13:00:54 +00:00
- Add Mixtral architecture implementation with MoE support (llm/src/llm/models/mixtral/mixtral.py) - Introduce generic Mixture-of-Experts (MoE) block (llm/src/llm/core/moe.py) - Create dedicated configuration files for Mixtral training and generation experiments - Register and test Mixtral support in experiment runner (run_llm_experiment.py) - Add unit tests for Mixtral API including forward, caching, and generation modes - Include Jupyter notebook mixstral.ipynb for architectural exploration and research - Ensure correct handling of torch bool masks in sampling (top-k, top-p) during generation BREAKING CHANGE: Adds new model code and test coverage, modifying experiment runner logic to register Mixtral.
1511 lines
70 KiB
Plaintext
1511 lines
70 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "c3276440",
|
||
"metadata": {},
|
||
"source": [
|
||
"# **Mixtral**\n",
|
||
"\n",
|
||
"<p align=\"center\">\n",
|
||
" <img src=\"https://ucarecdn.com/b3eb5219-4c5a-4cc1-9572-b072d262c2c1/\" alt=\"Mixtral\" width=\"1000\" height=\"165\">\n",
|
||
"</p>\n",
|
||
"\n",
|
||
"Mixtral 8x7B вышел в декабре 2023.\n",
|
||
"\n",
|
||
"В нем сохранились все архитектурные особенности первого Mistral: RoPE, SwiGLU, GQA и SWA. И добавилась одна суперзначимая новинка — Mixture-of-Experts (MoE). Основное предназначение MoE — ускорение инференса для очень крупных моделей (с большим багажом знаний).\n",
|
||
"\n",
|
||
"Технически MoE меняет облик FeedForward слоя, добавляя в него экспертов — по сути, кучу тех же FeedForward сетей. Но эти эксперты работают не все сразу. Специальная подпрограмма — роутер — решает, кого из них и когда вызывать. Тем самым мы экономим на инференсе.\n",
|
||
"\n",
|
||
"В Mixtral 8x7B было 8 экспертов, но в единицу времени работали только 2 из них. Общий вес модели — 47B параметров, но из-за того, что задействовано только 2 эксперта, по скорости инференс получается как у примерно 13B моделей. А качество сравнимо с 70B моделями (например, Llama 2 70B).\n",
|
||
"\n",
|
||
"В дальнейшем эта новинка станет очень популярна в больших языковых моделях.\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "e8dd2f09",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Mixture-of-Experts\n",
|
||
"\n",
|
||
"<p align=\"center\">\n",
|
||
" <img src=\"https://ucarecdn.com/ce1ed721-579e-4b24-b9dc-1877051d246b/\" alt=\"MoE\" width=\"1000\" height=\"576\">\n",
|
||
"</p>\n",
|
||
"\n",
|
||
"Mixture-of-Experts (MoE) — это архитектура нейронных сетей, которая позволяет масштабировать модели до огромных размеров (триллионы параметров) без пропорционального увеличения вычислительных затрат. Достигается это за счет замены FeedForward слоя на слой MoE.\n",
|
||
"\n",
|
||
"Стандартная полносвязная нейросеть (Feed-Forward Network, FFN) в трансформере обрабатывает каждый входной токен одним и тем же набором весов. В MoE эта единственная сеть заменяется набором «экспертов» — небольших независимых нейросетей (обычно тех же FNN). И при проходе для каждого токена система динамически выбирает, каким экспертом или экспертами его обработать. В результате мы получаем два основных выигрыша:\n",
|
||
"\n",
|
||
"* Снижение нагрузки на систему и, как следствие, ускорение инференса.\n",
|
||
"* Повышение когнитивных способностей системы: потому что в каждый момент времени над токеном работают наиболее подходящие для этого эксперты.\n",
|
||
"\n",
|
||
"> Идея с MoE рано или поздно должна была появиться. На FNN приходится примерно 2/3 всех параметров современных LLM — это очень существенно. И причина этого известна: в FNN хранится основная часть знаний модели. И попытка сэкономить на таком большом количестве параметров является очевидным шагом.\n",
|
||
"\n",
|
||
"### Алгоритм\n",
|
||
"\n",
|
||
"MoE состоит из трех основных компонентов: маршрутизатора, экспертов и процедуры взвешивания. Работают они следующим образом:\n",
|
||
"\n",
|
||
"* На вход поступает тензор `x`.\n",
|
||
"* Маршрутизатор (Router или Gating Network) — это небольшая нейросеть. Часто просто один линейный слой. Функция маршрутизатора — решать, каким экспертам обработать `x`. Для этого по каждому токену маршрутизатор выдает: топ-k экспертов и их веса.\n",
|
||
"* Эксперты (Experts) — это несколько идентичных по структуре, но разных по весам нейронных сетей. Обычно в их роли выступают стандартные FFN-слои. Эксперты параллельны и не зависят друг от друга. Каждый токен из `x` пропускается только через назначенных ему экспертов — эксперты возвращают результат.\n",
|
||
"* Выходной тензор — это взвешенная сумма выходов всех экспертов. Веса для взвешивания мы получаем от маршрутизатора (по каждому токену и каждому эксперту).\n",
|
||
"\n",
|
||
"> Тут нужно сделать лирическое отступление: хотя сами вычисления происходят заметно быстрее (из-за того что в единицу времени отрабатывает только часть экспертов), но вам все равно потребуется разместить в памяти GPU абсолютно всех экспертов одновременно.\n",
|
||
"\n",
|
||
"### Экспертность\n",
|
||
"\n",
|
||
"Технически при обучении LLM эксперты в MoE никак принудительно не «специализируются». Т.е. все эксперты независимы и равны между собой. Но исследования показывают, что эксперты часто самоорганизуются и сами специализируются на определенных типах данных или задачах:\n",
|
||
"\n",
|
||
"* Эксперт 1 — хорошо работает с техническим текстом.\n",
|
||
"* Эксперт 2 — лучше справляется с художественным текстом.\n",
|
||
"* Эксперт 3 — специализируется на грамматике.\n",
|
||
"* Эксперт 4 — хорошо обрабатывает имена собственные.\n",
|
||
"* Эксперт 5 — лучше понимает азиатские языки и т.д.\n",
|
||
"\n",
|
||
"И встречая очередной токен, роутер, на основе его содержимого, решает, к какому эксперту его отправить.\n",
|
||
"\n",
|
||
"> При обучении роутер может неравномерно распределять внимание между экспертами. Может даже отправлять все данные к одному эксперту. Лечится это специальными балансными лоссами.\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "8b5851df",
|
||
"metadata": {},
|
||
"source": [
|
||
"**MoE класс (разработка)**"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"id": "982c3595",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import torch\n",
|
||
"from torch import nn\n",
|
||
"\n",
|
||
"class SiLU(nn.Module):\n",
|
||
" def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]\n",
|
||
" return torch.sigmoid(x) * x\n",
|
||
" \n",
|
||
"class SwiGLU(nn.Module):\n",
|
||
" def __init__(self, emb_size: int, dropout: float = 0.1):\n",
|
||
" super().__init__()\n",
|
||
"\n",
|
||
" self._gate = nn.Linear(emb_size, 4 * emb_size)\n",
|
||
" self._up = nn.Linear(emb_size, 4 * emb_size)\n",
|
||
" self._down = nn.Linear(4 * emb_size, emb_size)\n",
|
||
" self._activation = SiLU()\n",
|
||
" self._dropout = nn.Dropout(dropout)\n",
|
||
"\n",
|
||
" def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size].\n",
|
||
" gate_out = self._gate(x) # [batch, seq, 4*emb]\n",
|
||
" activation_out = self._activation(gate_out) # [batch, seq, 4*emb]\n",
|
||
" up_out = self._up(x) # [batch, seq, 4*emb]\n",
|
||
" out = up_out * activation_out # поэлементное!\n",
|
||
" out = self._down(out) # [batch, seq, emb]\n",
|
||
" return self._dropout(out)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"id": "34fac9f7",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import torch\n",
|
||
"from torch import nn\n",
|
||
"import torch.nn.functional as F\n",
|
||
"\n",
|
||
"class MoE(nn.Module):\n",
|
||
" def __init__(\n",
|
||
" self,\n",
|
||
" emb_size: int,\n",
|
||
" num_experts: int,\n",
|
||
" top_k_experts: int,\n",
|
||
" dropout: float = 0.1,\n",
|
||
" ):\n",
|
||
" super().__init__()\n",
|
||
" self._num_experts = num_experts\n",
|
||
" self._top_k_experts = top_k_experts\n",
|
||
"\n",
|
||
" self._router = nn.Linear(emb_size, num_experts)\n",
|
||
" self._experts = nn.ModuleList([SwiGLU(\n",
|
||
" emb_size=emb_size,\n",
|
||
" dropout=dropout,\n",
|
||
" ) for _ in range(num_experts)])\n",
|
||
" self._dropout = nn.Dropout(dropout)\n",
|
||
"\n",
|
||
" def forward(self, x: torch.Tensor):\n",
|
||
" batch_size, seq_len, emb_size = x.shape\n",
|
||
" \n",
|
||
" # 1. Пропускаем через роутер\n",
|
||
" router_logits = self._router(x) # [batch_size, seq_len, num_experts]\n",
|
||
" \n",
|
||
" # 2. Отбираем топ-k экспертов для каждого токена\n",
|
||
" topk_logits, topk_indices = torch.topk(\n",
|
||
" router_logits, \n",
|
||
" k=self._top_k_experts, \n",
|
||
" dim=-1\n",
|
||
" ) # topk_logits: [batch_size, seq_len, top_k]\n",
|
||
" # topk_indices: [batch_size, seq_len, top_k]\n",
|
||
" \n",
|
||
" # 3. Получаем веса через softmax и нормируем\n",
|
||
" topk_weights = F.softmax(topk_logits, dim=-1) # [batch_size, seq_len, top_k]\n",
|
||
" \n",
|
||
" # 4. Создаём нулевой тензор для результата\n",
|
||
" output = torch.zeros_like(x) # [batch_size, seq_len, emb_size] \n",
|
||
"\n",
|
||
" # 5. Проходим по всем экспертам\n",
|
||
" for expert_id in range(self._num_experts):\n",
|
||
" # Шаг 1: Создаём маску - где находится текущий эксперт в топ-k\n",
|
||
" expert_mask = (topk_indices == expert_id) # [batch_size, seq_len, top_k]\n",
|
||
" # Шаг 2: Проверяем, выбран ли эксперт хотя бы одним токеном\n",
|
||
" if not expert_mask.any():\n",
|
||
" continue # Эксперт никем не выбран, переходим к следующему\n",
|
||
"\n",
|
||
" # Шаг 3: Находим токены, которые выбрали этого эксперта\n",
|
||
" # (хотя бы в одной из top_k позиций)\n",
|
||
" token_mask = expert_mask.any(dim=-1) # [batch_size, seq_len]\n",
|
||
"\n",
|
||
" # Шаг 4: Отбираем токены из x\n",
|
||
" # Отбираем токены для этого эксперта\n",
|
||
" expert_input = x[token_mask]\n",
|
||
"\n",
|
||
" # Пропускаем через эксперта\n",
|
||
" # Добавляем batch dimension для SwiGLU и затем убираем\n",
|
||
" expert_output = self._experts[expert_id](\n",
|
||
" expert_input.unsqueeze(0)\n",
|
||
" ).squeeze(0)\n",
|
||
"\n",
|
||
" # Получаем веса для этого эксперта\n",
|
||
" # Для каждого токена может быть несколько весов (если эксперт в топ-k несколько раз)\n",
|
||
" # Но на практике каждый эксперт появляется максимум 1 раз в топ-k\n",
|
||
" # Находим веса: где expert_mask == True, берём соответствующий вес\n",
|
||
" weights_for_expert = torch.zeros(\n",
|
||
" batch_size, seq_len, device=x.device\n",
|
||
" )\n",
|
||
"\n",
|
||
" # Для каждой позиции в топ-k\n",
|
||
" for k in range(self._top_k_experts):\n",
|
||
" mask_k = topk_indices[:, :, k] == expert_id\n",
|
||
" weights_for_expert[mask_k] = topk_weights[:, :, k][mask_k]\n",
|
||
"\n",
|
||
" # Отбираем только веса для выбранных токенов\n",
|
||
" selected_weights = weights_for_expert[token_mask] # [num_selected_tokens]\n",
|
||
"\n",
|
||
"\n",
|
||
" # Перемножьте выход эксперта на веса текущего эксперта.\n",
|
||
" weighted_output = selected_weights.unsqueeze(-1) * expert_output\n",
|
||
"\n",
|
||
" # Помещаем результат на своё место в выходном тензоре\n",
|
||
" output[token_mask] += weighted_output\n",
|
||
" \n",
|
||
" out = self._dropout(output)\n",
|
||
"\n",
|
||
" return out"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "d9135164",
|
||
"metadata": {},
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "bef0c904",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Full Model"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"id": "35e52050",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import torch\n",
|
||
"from torch import nn\n",
|
||
"from torch import Tensor\n",
|
||
"import torch.nn.functional as F\n",
|
||
"from math import sqrt\n",
|
||
"\n",
|
||
"\n",
|
||
"\n",
|
||
"class SiLU(nn.Module):\n",
|
||
" def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]\n",
|
||
" return torch.sigmoid(x) * x\n",
|
||
" \n",
|
||
"class RMSNorm(nn.Module):\n",
|
||
" def __init__(self, dim: int, eps: float = 1e-6):\n",
|
||
" super().__init__()\n",
|
||
" self._eps = eps\n",
|
||
" self._w = nn.Parameter(torch.ones(dim))\n",
|
||
" \n",
|
||
" def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]\n",
|
||
" rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5\n",
|
||
" norm_x = x / rms\n",
|
||
" return self._w * norm_x\n",
|
||
"\n",
|
||
"class SwiGLU(nn.Module):\n",
|
||
" def __init__(self, emb_size: int, dropout: float = 0.1):\n",
|
||
" super().__init__()\n",
|
||
"\n",
|
||
" self._gate = nn.Linear(emb_size, 4 * emb_size)\n",
|
||
" self._up = nn.Linear(emb_size, 4 * emb_size)\n",
|
||
" self._down = nn.Linear(4 * emb_size, emb_size)\n",
|
||
" self._activation = SiLU()\n",
|
||
" self._dropout = nn.Dropout(dropout)\n",
|
||
"\n",
|
||
" def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size].\n",
|
||
" gate_out = self._gate(x) # [batch, seq, 4*emb]\n",
|
||
" activation_out = self._activation(gate_out) # [batch, seq, 4*emb]\n",
|
||
" up_out = self._up(x) # [batch, seq, 4*emb]\n",
|
||
" out = up_out * activation_out # поэлементное!\n",
|
||
" out = self._down(out) # [batch, seq, emb]\n",
|
||
" return self._dropout(out)\n",
|
||
"\n",
|
||
"\n",
|
||
"class TokenEmbeddings(nn.Module):\n",
|
||
" def __init__(self, vocab_size: int, emb_size: int):\n",
|
||
" super().__init__()\n",
|
||
" self._embedding = nn.Embedding(\n",
|
||
" num_embeddings=vocab_size,\n",
|
||
" embedding_dim=emb_size\n",
|
||
" )\n",
|
||
"\n",
|
||
" def forward(self, x: Tensor) -> Tensor:\n",
|
||
" return self._embedding(x)\n",
|
||
"\n",
|
||
" @property\n",
|
||
" def num_embeddings(self) -> int:\n",
|
||
" return self._embedding.num_embeddings\n",
|
||
"\n",
|
||
" @property\n",
|
||
" def embedding_dim(self) -> int:\n",
|
||
" return self._embedding.embedding_dim\n",
|
||
"\n",
|
||
"\n",
|
||
"class GELU(nn.Module):\n",
|
||
" def __init__(self):\n",
|
||
" super().__init__()\n",
|
||
" self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)\n",
|
||
" \n",
|
||
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
|
||
" return 0.5 * x * (1 + torch.tanh(\n",
|
||
" self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))\n",
|
||
" ))\n",
|
||
"\n",
|
||
" \n",
|
||
" \n",
|
||
"import torch\n",
|
||
"from torch import nn\n",
|
||
"from typing import Optional\n",
|
||
"\n",
|
||
"\n",
|
||
"class RoPE(nn.Module):\n",
|
||
"\n",
|
||
" def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):\n",
|
||
" super().__init__()\n",
|
||
" assert head_size % 2 == 0, \"head_size должен быть четным\"\n",
|
||
"\n",
|
||
" # Вычисление частот: θ_i = base^(-2i/d) для i ∈ [0, d/2-1]\n",
|
||
" freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size))\n",
|
||
"\n",
|
||
" # Позиции от 0 до max_seq_len-1\n",
|
||
" positions = torch.arange(max_seq_len).float()\n",
|
||
"\n",
|
||
" # Внешнее произведение: m * θ_i для всех позиций и частот\n",
|
||
" freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0)\n",
|
||
"\n",
|
||
" # Предвычисление матриц косинусов и синусов\n",
|
||
" self.register_buffer(\"cos_matrix\", torch.cos(freq_matrix))\n",
|
||
" self.register_buffer(\"sin_matrix\", torch.sin(freq_matrix))\n",
|
||
"\n",
|
||
" def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor: # [batch_size × seq_len × head_size] [batch_size × num_heads × seq_len × head_size]\n",
|
||
" batch_size, num_heads, seq_len, head_size = x.shape\n",
|
||
"\n",
|
||
" # Берем нужную часть матриц и приводим к типу x\n",
|
||
" cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]\n",
|
||
" sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]\n",
|
||
"\n",
|
||
" # Явное изменение формы для broadcasting\n",
|
||
" cos = cos.reshape(1, 1, seq_len, head_size // 2)\n",
|
||
" sin = sin.reshape(1, 1, seq_len, head_size // 2)\n",
|
||
"\n",
|
||
" # Разделяем на четные и нечетные компоненты по ПОСЛЕДНЕМУ измерению\n",
|
||
" x_even = x[..., 0::2] # [batch_size, num_heads, seq_len, head_size//2]\n",
|
||
" x_odd = x[..., 1::2] # [batch_size, num_heads, seq_len, head_size//2]\n",
|
||
"\n",
|
||
" # Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ)\n",
|
||
" x_rotated_even = x_even * cos - x_odd * sin\n",
|
||
" x_rotated_odd = x_even * sin + x_odd * cos\n",
|
||
"\n",
|
||
" # Объединяем обратно в исходную размерность\n",
|
||
" x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)\n",
|
||
" x_rotated = x_rotated.flatten(-2) # [batch_size, seq_len, head_size]\n",
|
||
"\n",
|
||
" return x_rotated\n",
|
||
"\n",
|
||
"\n",
|
||
"import torch\n",
|
||
"from torch import nn\n",
|
||
"import torch.nn.functional as F\n",
|
||
"from typing import Optional, Tuple\n",
|
||
"\n",
|
||
"\n",
|
||
" \n",
|
||
"class GroupedQueryAttention(nn.Module):\n",
|
||
"\n",
|
||
" def __init__(\n",
|
||
" self,\n",
|
||
" num_q_heads: int,\n",
|
||
" num_kv_heads: int,\n",
|
||
" emb_size: int,\n",
|
||
" head_size: int,\n",
|
||
" max_seq_len: int,\n",
|
||
" window_size: int,\n",
|
||
" rope: RoPE = None,\n",
|
||
" dropout: float = 0.1,\n",
|
||
" ):\n",
|
||
" super().__init__()\n",
|
||
" self._num_heads = num_q_heads\n",
|
||
" self._num_kv_heads = num_kv_heads\n",
|
||
" self._head_size = head_size\n",
|
||
" self._max_seq_len = max_seq_len\n",
|
||
" self._rope = rope\n",
|
||
" self._window_size = window_size\n",
|
||
"\n",
|
||
" self._q = nn.Linear(emb_size, self._num_heads * head_size)\n",
|
||
" self._k = nn.Linear(emb_size, num_kv_heads * head_size)\n",
|
||
" self._v = nn.Linear(emb_size, num_kv_heads * head_size)\n",
|
||
"\n",
|
||
" # Создание causal маски\n",
|
||
" mask = self._create_sliding_window_mask(max_seq_len, self._window_size)\n",
|
||
" self.register_buffer(\n",
|
||
" \"_tril_mask\", mask.bool() if hasattr(torch, \"bool\") else mask.byte()\n",
|
||
" )\n",
|
||
" \n",
|
||
" self._layer = nn.Linear(head_size * self._num_heads, emb_size)\n",
|
||
" self._dropout = nn.Dropout(dropout)\n",
|
||
"\n",
|
||
" def forward(\n",
|
||
" self,\n",
|
||
" x: torch.Tensor,\n",
|
||
" mask: torch.Tensor = None,\n",
|
||
" use_cache: bool = True,\n",
|
||
" cache: list = None,\n",
|
||
" ):\n",
|
||
" batch_size, seq_len, emb_size = x.shape\n",
|
||
"\n",
|
||
" if seq_len > self._max_seq_len:\n",
|
||
" raise ValueError(\n",
|
||
" f\"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}\"\n",
|
||
" )\n",
|
||
"\n",
|
||
" # Пропустите тензор x через матрицы Wq, Wk , Wv, чтобы получить матрицы запроса, ключа и значения.\n",
|
||
" k = self._k(x) # [B, T, hs]\n",
|
||
" q = self._q(x) # [B, T, hs]\n",
|
||
" v = self._v(x) # [B, T, hs]\n",
|
||
"\n",
|
||
" # Шаг 2: Изменение формы для multi-head\n",
|
||
" # [batch_size, seq_len, num_heads * head_size] \n",
|
||
" # -> [batch_size, seq_len, num_heads, head_size]\n",
|
||
" # Измените форму запроса (query) на batch_size × num_q_heads × seq_len × head_size.\n",
|
||
" q = q.reshape(batch_size, seq_len, self._num_heads, self._head_size)\n",
|
||
"\n",
|
||
" # Измените форму ключа (key) и значения (value) на batch_size × num_kv_heads × seq_len × head_size.\n",
|
||
" k = k.reshape(batch_size, seq_len, self._num_kv_heads, self._head_size)\n",
|
||
" v = v.reshape(batch_size, seq_len, self._num_kv_heads, self._head_size)\n",
|
||
" \n",
|
||
"\n",
|
||
" # 3. Transpose: [B, T, H, hs] -> [B, H, T, hs]\n",
|
||
" q = q.transpose(1, 2)\n",
|
||
" k = k.transpose(1, 2)\n",
|
||
" v = v.transpose(1, 2)\n",
|
||
"\n",
|
||
" start_pos = 0\n",
|
||
" if cache is not None:\n",
|
||
" k_cache, v_cache = cache\n",
|
||
" cache_len = k_cache.shape[2]\n",
|
||
" start_pos = cache_len\n",
|
||
" \n",
|
||
" # Пропустите матрицы запроса и ключа через экземпляр rope, чтобы выполнить поворот.\n",
|
||
" if self._rope is not None:\n",
|
||
" # Применяем RoPE к Q и K (НЕ к V!)\n",
|
||
" q = self._rope(q, start_pos=start_pos) # [B, T, hs]\n",
|
||
" k = self._rope(k, start_pos=start_pos) # [B, T, hs]\n",
|
||
"\n",
|
||
" # Если cache пришел, то объединяем кэш и одну строку из ключа и значения. Это будут новые key и value для последующих вычислений.\n",
|
||
" # 5. Кэширование (для autoregressive generation)\n",
|
||
" if cache is not None:\n",
|
||
" k_cache, v_cache = cache\n",
|
||
" k = torch.cat([k_cache, k], dim=2) # Concat по seq_len (dim=2)\n",
|
||
" v = torch.cat([v_cache, v], dim=2)\n",
|
||
"\n",
|
||
" # Если use_cache == True, то сохраните матрицы ключа и значения для кэша (это нужно сделать до дублирования голов).\n",
|
||
" #if use_cache == True:\n",
|
||
" # # Обрезаем до последних window_size токенов\n",
|
||
" # k_to_cache = k[:, :, -self._window_size:, :]\n",
|
||
" # v_to_cache = v[:, :, -self._window_size:, :]\n",
|
||
" # kv_cache = (k_to_cache, v_to_cache)\n",
|
||
"\n",
|
||
" # Продублируйте головы в тензорах ключа (key) и значения (value), чтобы получился тензор размера на batch_size × num_q_heads × seq_len × head_size.\n",
|
||
" #k = self._repeat_kv_heads(k, self._num_heads, self._num_kv_heads)\n",
|
||
" #v = self._repeat_kv_heads(v, self._num_heads, self._num_kv_heads)\n",
|
||
" k_expanded = self._repeat_kv_heads(k, self._num_heads, self._num_kv_heads)\n",
|
||
" v_expanded = self._repeat_kv_heads(v, self._num_heads, self._num_kv_heads)\n",
|
||
" \n",
|
||
" # Перемножим матрицы запроса и ключа (транспонированную), чтобы вычислить матрицу внимания.\n",
|
||
" # И разделить все значения в матрице внимания на корень из head_size.\n",
|
||
" scores = q @ k_expanded.transpose(-2, -1) / (self._head_size ** 0.5)\n",
|
||
"\n",
|
||
" # 8. Применение маски\n",
|
||
" k_seq_len = k_expanded.size(2) # Длина K после concat с кэшем\n",
|
||
" \n",
|
||
" if cache is None:\n",
|
||
" # Случай 1: Без кэша - полная квадратная маска\n",
|
||
" # scores: [B, H, seq_len, seq_len]\n",
|
||
" # Применяем маску [:seq_len, :seq_len]\n",
|
||
" scores = scores.masked_fill(\n",
|
||
" ~self._tril_mask[:seq_len, :seq_len], \n",
|
||
" float(\"-inf\")\n",
|
||
" )\n",
|
||
"\n",
|
||
" # Применить к матрице внимания (построчно) функцию Softmax.\n",
|
||
" weights = F.softmax(scores, dim=-1)\n",
|
||
"\n",
|
||
" # Перемножим матрицу внимания и матрицу значения.\n",
|
||
" x_out = weights @ v_expanded # [B, T, hs]\n",
|
||
"\n",
|
||
" # Измените форму тензора на batch_size × seq_len × num_heads*head_size.\n",
|
||
" # Transpose обратно и concatenate heads\n",
|
||
" x_out = x_out.transpose(1, 2) # [B, T_q, H, hs]\n",
|
||
" x_out = x_out.contiguous() # Важно для reshape!\n",
|
||
" concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size)\n",
|
||
"\n",
|
||
" #concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size)\n",
|
||
"\n",
|
||
" # Пропустите получившийся тензор через последний линейный слой.\n",
|
||
" # 3. Проецируем в пространство эмбеддингов\n",
|
||
" projected_output = self._layer(concatenated_attention)\n",
|
||
"\n",
|
||
" # 4. Применяем dropout для регуляризации\n",
|
||
" output = self._dropout(projected_output)\n",
|
||
"\n",
|
||
" if use_cache:\n",
|
||
" # Обрезаем оригинальный K и V (до дублирования)\n",
|
||
" k_to_cache = k[:, :, -self._window_size:, :]\n",
|
||
" v_to_cache = v[:, :, -self._window_size:, :]\n",
|
||
" kv_cache = (k_to_cache, v_to_cache)\n",
|
||
" return output, kv_cache\n",
|
||
" else:\n",
|
||
" return output, None\n",
|
||
"\n",
|
||
" def _repeat_kv_heads(\n",
|
||
" self,\n",
|
||
" kv: torch.Tensor,\n",
|
||
" num_q_heads: int,\n",
|
||
" num_kv_heads: int\n",
|
||
" ) -> torch.Tensor:\n",
|
||
" \"\"\"\n",
|
||
" Дублирует головы K/V для соответствия количеству голов Q.\n",
|
||
"\n",
|
||
" Args:\n",
|
||
" kv: [batch_size, num_kv_heads, seq_len, head_size]\n",
|
||
" num_q_heads: Количество голов Query (например, 8)\n",
|
||
" num_kv_heads: Количество голов Key/Value (например, 2)\n",
|
||
"\n",
|
||
" Returns:\n",
|
||
" [batch_size, num_q_heads, seq_len, head_size]\n",
|
||
"\n",
|
||
" Example:\n",
|
||
" num_q_heads=8, num_kv_heads=2\n",
|
||
" Каждая голова KV дублируется 4 раза:\n",
|
||
" [KV0, KV1] -> [KV0, KV0, KV0, KV0, KV1, KV1, KV1, KV1]\n",
|
||
" \"\"\"\n",
|
||
" batch_size, num_kv_heads, seq_len, head_size = kv.shape\n",
|
||
"\n",
|
||
" if num_q_heads == num_kv_heads:\n",
|
||
" # Нет необходимости дублировать\n",
|
||
" return kv\n",
|
||
"\n",
|
||
" # Вычисляем сколько раз нужно повторить каждую голову\n",
|
||
" num_repeats = num_q_heads // num_kv_heads\n",
|
||
"\n",
|
||
" # repeat_interleave дублирует каждую голову num_repeats раз\n",
|
||
" # [B, num_kv_heads, S, hs] -> [B, num_q_heads, S, hs]\n",
|
||
" # [B, num_kv_heads, S, hs] -> [B, num_kv_heads, 1, S, hs]\n",
|
||
" kv = kv.unsqueeze(2)\n",
|
||
" \n",
|
||
" # [B, num_kv_heads, 1, S, hs] -> [B, num_kv_heads, num_repeats, S, hs]\n",
|
||
" kv = kv.repeat(1, 1, num_repeats, 1, 1)\n",
|
||
" \n",
|
||
" # [B, num_kv_heads, num_repeats, S, hs] -> [B, num_q_heads, S, hs]\n",
|
||
" kv = kv.reshape(batch_size, num_q_heads, seq_len, head_size)\n",
|
||
" \n",
|
||
"\n",
|
||
" return kv\n",
|
||
"\n",
|
||
" def _create_sliding_window_mask(\n",
|
||
" self,\n",
|
||
" max_seq_len: int,\n",
|
||
" window_size: int,\n",
|
||
" device: torch.device = None\n",
|
||
" ) -> torch.Tensor:\n",
|
||
" \"\"\"\n",
|
||
" Создает маску для Sliding Window Attention.\n",
|
||
"\n",
|
||
" Args:\n",
|
||
" max_seq_len: Максимальная длина последовательности\n",
|
||
" window_size: Размер окна внимания\n",
|
||
" device: Устройство для размещения тензора\n",
|
||
"\n",
|
||
" Returns:\n",
|
||
" Маска формы [max_seq_len, max_seq_len], где True = разрешено\n",
|
||
"\n",
|
||
" Example:\n",
|
||
" >>> mask = create_sliding_window_mask(8, 3)\n",
|
||
" >>> print(mask.int())\n",
|
||
" tensor([[1, 0, 0, 0, 0, 0, 0, 0],\n",
|
||
" [1, 1, 0, 0, 0, 0, 0, 0],\n",
|
||
" [1, 1, 1, 0, 0, 0, 0, 0],\n",
|
||
" [0, 1, 1, 1, 0, 0, 0, 0],\n",
|
||
" [0, 0, 1, 1, 1, 0, 0, 0],\n",
|
||
" [0, 0, 0, 1, 1, 1, 0, 0],\n",
|
||
" [0, 0, 0, 0, 1, 1, 1, 0],\n",
|
||
" [0, 0, 0, 0, 0, 1, 1, 1]])\n",
|
||
" \"\"\"\n",
|
||
" row_indices = torch.arange(max_seq_len, device=device).unsqueeze(1) # [max_seq_len, 1]\n",
|
||
" col_indices = torch.arange(max_seq_len, device=device).unsqueeze(0) # [1, max_seq_len]\n",
|
||
"\n",
|
||
" causal_mask = col_indices <= row_indices\n",
|
||
"\n",
|
||
" window_mask = (row_indices - col_indices) <= window_size\n",
|
||
"\n",
|
||
" mask = causal_mask & window_mask\n",
|
||
" \n",
|
||
" return mask\n",
|
||
"\n",
|
||
"class Decoder(nn.Module):\n",
|
||
" def __init__(self, \n",
|
||
" num_q_heads: int,\n",
|
||
" num_kv_heads: int,\n",
|
||
" emb_size: int,\n",
|
||
" head_size: int,\n",
|
||
" max_seq_len: int,\n",
|
||
" num_experts: int,\n",
|
||
" top_k_experts: int,\n",
|
||
" window_size: int,\n",
|
||
" rope: RoPE,\n",
|
||
" dropout: float = 0.1\n",
|
||
" ):\n",
|
||
" super().__init__()\n",
|
||
" self._heads = GroupedQueryAttention(\n",
|
||
" num_q_heads=num_q_heads, \n",
|
||
" num_kv_heads=num_kv_heads,\n",
|
||
" emb_size=emb_size, \n",
|
||
" head_size=head_size, \n",
|
||
" max_seq_len=max_seq_len,\n",
|
||
" window_size=window_size,\n",
|
||
" rope=rope,\n",
|
||
" dropout=dropout\n",
|
||
" )\n",
|
||
" self._ff = MoE(\n",
|
||
" emb_size=emb_size, \n",
|
||
" num_experts=num_experts,\n",
|
||
" top_k_experts=top_k_experts,\n",
|
||
" dropout=dropout\n",
|
||
" )\n",
|
||
" self._norm1 = RMSNorm(emb_size)\n",
|
||
" self._norm2 = RMSNorm(emb_size)\n",
|
||
"\n",
|
||
" def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:\n",
|
||
" norm1_out = self._norm1(x)\n",
|
||
" attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)\n",
|
||
" out = attention + x\n",
|
||
" \n",
|
||
" norm2_out = self._norm2(out)\n",
|
||
" ffn_out = self._ff(norm2_out)\n",
|
||
"\n",
|
||
" if use_cache is True:\n",
|
||
" return (ffn_out + out, kv_caches)\n",
|
||
" else:\n",
|
||
" return (ffn_out + out, None)\n",
|
||
"\n",
|
||
"\n",
|
||
"\n",
|
||
"from torch import nn\n",
|
||
"import torch\n",
|
||
"import torch.nn.functional as F\n",
|
||
"\n",
|
||
"class Mixtral(nn.Module):\n",
|
||
" def __init__(self,\n",
|
||
" vocab_size: int,\n",
|
||
" max_seq_len: int,\n",
|
||
" emb_size: int,\n",
|
||
" num_q_heads: int,\n",
|
||
" num_kv_heads: int,\n",
|
||
" head_size: int,\n",
|
||
" num_layers: int,\n",
|
||
" num_experts: int,\n",
|
||
" top_k_experts: int,\n",
|
||
" window_size: int,\n",
|
||
" dropout: float = 0.1,\n",
|
||
" device: str = 'cpu'\n",
|
||
" ):\n",
|
||
" super().__init__()\n",
|
||
" self._vocab_size = vocab_size\n",
|
||
" self._max_seq_len = max_seq_len\n",
|
||
" self._emb_size = emb_size\n",
|
||
" self._num_q_heads = num_q_heads\n",
|
||
" self._num_kv_heads = num_kv_heads\n",
|
||
" self._head_size = head_size\n",
|
||
" self._num_layers = num_layers\n",
|
||
" self._dropout = dropout\n",
|
||
" self._device = device\n",
|
||
" \n",
|
||
" self.validation_loss = None\n",
|
||
"\n",
|
||
" # Инициализация слоев\n",
|
||
" self._token_embeddings = TokenEmbeddings(\n",
|
||
" vocab_size=vocab_size, \n",
|
||
" emb_size=emb_size\n",
|
||
" )\n",
|
||
" self._position_embeddings = RoPE(\n",
|
||
" head_size=head_size,\n",
|
||
" max_seq_len=max_seq_len\n",
|
||
" )\n",
|
||
" #self._position_embeddings = PositionalEmbeddings(\n",
|
||
" # max_seq_len=max_seq_len, \n",
|
||
" # emb_size=emb_size\n",
|
||
" #)\n",
|
||
" self._dropout = nn.Dropout(dropout)\n",
|
||
" self._decoders = nn.ModuleList([Decoder(\n",
|
||
" num_q_heads=num_q_heads,\n",
|
||
" num_kv_heads=num_kv_heads,\n",
|
||
" emb_size=emb_size,\n",
|
||
" head_size=head_size,\n",
|
||
" max_seq_len=max_seq_len,\n",
|
||
" num_experts=num_experts,\n",
|
||
" top_k_experts=top_k_experts,\n",
|
||
" window_size=window_size,\n",
|
||
" rope=self._position_embeddings,\n",
|
||
" dropout=dropout \n",
|
||
" ) for _ in range(num_layers)])\n",
|
||
" self._norm = RMSNorm(emb_size)\n",
|
||
" self._linear = nn.Linear(emb_size, vocab_size)\n",
|
||
"\n",
|
||
" def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:\n",
|
||
" # Проверка длины последовательности (только при отсутствии кэша)\n",
|
||
" if cache is None and x.size(1) > self._max_seq_len:\n",
|
||
" raise ValueError(f\"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}\")\n",
|
||
" \n",
|
||
" # Эмбеддинги токенов и позиций\n",
|
||
" tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]\n",
|
||
" #pos_out = self._position_embeddings(x) # [batch, seq_len, emb_size]\n",
|
||
" \n",
|
||
" # Комбинирование\n",
|
||
" out = self._dropout(tok_out) # [batch, seq_len, emb_size]\n",
|
||
" \n",
|
||
" # Стек декодеров с передачей кэша\n",
|
||
" new_cache = []\n",
|
||
" for i, decoder in enumerate(self._decoders):\n",
|
||
" decoder_cache = cache[i] if cache is not None else None\n",
|
||
" decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)\n",
|
||
"\n",
|
||
" # Извлекаем результат из кортежа\n",
|
||
" if use_cache:\n",
|
||
" out, decoder_new_cache = decoder_result\n",
|
||
" new_cache.append(decoder_new_cache)\n",
|
||
" else:\n",
|
||
" out = decoder_result[0]\n",
|
||
"\n",
|
||
" out = self._norm(out)\n",
|
||
" logits = self._linear(out)\n",
|
||
" \n",
|
||
" # Возвращаем результат с учетом use_cache\n",
|
||
" if use_cache:\n",
|
||
" return (logits, new_cache)\n",
|
||
" else:\n",
|
||
" return (logits, None)\n",
|
||
"\n",
|
||
" def generate(self,\n",
|
||
" x: torch.Tensor, \n",
|
||
" max_new_tokens: int, \n",
|
||
" do_sample: bool,\n",
|
||
" temperature: float = 1.0,\n",
|
||
" top_k: int = None,\n",
|
||
" top_p: float = None,\n",
|
||
" use_cache: bool = True\n",
|
||
" ) -> torch.Tensor:\n",
|
||
" cache = None\n",
|
||
"\n",
|
||
" for _ in range(max_new_tokens):\n",
|
||
" if use_cache and cache is not None:\n",
|
||
" # Используем кэш - передаем только последний токен\n",
|
||
" x_input = x[:, -1:] # [batch_size, 1]\n",
|
||
" else:\n",
|
||
" # Первая итерация или кэш отключен - передаем всю последовательность\n",
|
||
" x_input = x\n",
|
||
" \n",
|
||
" # Прямой проход с кэшем\n",
|
||
" logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)\n",
|
||
" \n",
|
||
" # Обновляем кэш для следующей итерации\n",
|
||
" if use_cache:\n",
|
||
" cache = new_cache\n",
|
||
"\n",
|
||
" last_logits = logits[:, -1, :] # [batch_size, vocab_size]\n",
|
||
"\n",
|
||
" # Масштабируем логиты температурой\n",
|
||
" if temperature > 0:\n",
|
||
" logits_scaled = last_logits / temperature\n",
|
||
" else:\n",
|
||
" logits_scaled = last_logits\n",
|
||
"\n",
|
||
" if do_sample == True and top_k != None:\n",
|
||
" _, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)\n",
|
||
"\n",
|
||
" # # Заменим все НЕ top-k логиты на -inf\n",
|
||
" masked_logits = logits_scaled.clone()\n",
|
||
" vocab_size = logits_scaled.size(-1)\n",
|
||
"\n",
|
||
" # создаём маску: 1, если токен НЕ в topk_indices\n",
|
||
" mask = torch.ones_like(logits_scaled, dtype=torch.uint8)\n",
|
||
" mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы\n",
|
||
" masked_logits[mask.byte()] = float('-inf')\n",
|
||
"\n",
|
||
" logits_scaled = masked_logits\n",
|
||
"\n",
|
||
" if do_sample == True and top_p != None:\n",
|
||
" # 1. Применим softmax, чтобы получить вероятности:\n",
|
||
" probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]\n",
|
||
" # 2. Отсортируем токены по убыванию вероятностей:\n",
|
||
" sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)\n",
|
||
" # 3. Посчитаем кумулятивную сумму вероятностей:\n",
|
||
" cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]\n",
|
||
" # 4. Определим маску: оставить токены, пока сумма < top_p\n",
|
||
" sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size]\n",
|
||
" # Гарантируем, что хотя бы первый токен останется\n",
|
||
" sorted_mask[:, 0] = 1\n",
|
||
" # 5. Преобразуем маску обратно в оригинальный порядок:\n",
|
||
" # Создаём полную маску из 0\n",
|
||
" mask = torch.zeros_like(probs, dtype=torch.uint8)\n",
|
||
" # Устанавливаем 1 в местах нужных токенов\n",
|
||
" mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)\n",
|
||
" # 6. Зануляем логиты токенов вне топ-p:\n",
|
||
" logits_scaled[~mask] = float('-inf')\n",
|
||
"\n",
|
||
" # 4. Применяем Softmax\n",
|
||
" probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]\n",
|
||
"\n",
|
||
"\n",
|
||
" if do_sample == True:\n",
|
||
" # 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial\n",
|
||
" next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]\n",
|
||
" else:\n",
|
||
" # 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью\n",
|
||
" next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]\n",
|
||
" \n",
|
||
" # 6. Добавляем его к последовательности\n",
|
||
" x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]\n",
|
||
" return x\n",
|
||
"\n",
|
||
" def save(self, path):\n",
|
||
" torch.save({\n",
|
||
" 'model_state_dict': self.state_dict(),\n",
|
||
" 'vocab_size': self._vocab_size,\n",
|
||
" 'max_seq_len': self._max_seq_len,\n",
|
||
" 'emb_size': self._emb_size,\n",
|
||
" 'num_heads': self._num_heads,\n",
|
||
" 'head_size': self._head_size,\n",
|
||
" 'num_layers': self._num_layers\n",
|
||
" }, path)\n",
|
||
"\n",
|
||
" @classmethod\n",
|
||
" def load(cls, path, device):\n",
|
||
" checkpoint = torch.load(path, map_location=device)\n",
|
||
" model = cls(\n",
|
||
" vocab_size=checkpoint['vocab_size'],\n",
|
||
" max_seq_len=checkpoint['max_seq_len'],\n",
|
||
" emb_size=checkpoint['emb_size'],\n",
|
||
" num_heads=checkpoint['num_heads'],\n",
|
||
" head_size=checkpoint['head_size'],\n",
|
||
" num_layers=checkpoint['num_layers']\n",
|
||
" )\n",
|
||
" model.load_state_dict(checkpoint['model_state_dict'])\n",
|
||
" model.to(device)\n",
|
||
" return model\n",
|
||
"\n",
|
||
" @property\n",
|
||
" def max_seq_len(self) -> int:\n",
|
||
" return self._max_seq_len"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "7f4b3b1e",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 2. Обучение Mixtral\n",
|
||
"\n",
|
||
"Mixtral обучается в два этапа:\n",
|
||
"\n",
|
||
"- 1️⃣ **Предобучение (Unsupervised Pretraining)** \n",
|
||
"- 2️⃣ **Дообучение (Supervised Fine-Tuning)**"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "cb0a9172",
|
||
"metadata": {},
|
||
"source": [
|
||
"\n",
|
||
"\n",
|
||
"### 5.1 Предобучение\n",
|
||
"\n",
|
||
"На первом этапе модель обучается без разметки: она получает большой корпус текстов и учится **предсказывать следующий токен** по предыдущим.\n",
|
||
"\n",
|
||
"Функция потерь:\n",
|
||
"$$\n",
|
||
"L = - \\sum_{t=1}^{T} \\log P(x_t | x_1, x_2, ..., x_{t-1})\n",
|
||
"$$\n",
|
||
"\n",
|
||
"Таким образом, модель учится строить вероятностную модель языка, \"угадывая\" продолжение текста.\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "b4bdbb31",
|
||
"metadata": {},
|
||
"source": [
|
||
"Во время **предобучения** Mistral учится **предсказывать следующий токен** (language modeling task). \n",
|
||
"Формально: \n",
|
||
"$$ \n",
|
||
"P(x_t ,|, x_1, x_2, \\dots, x_{t-1}) \n",
|
||
"$$ \n",
|
||
"То есть, если на вход подаётся предложение `\"I love deep\"`, модель должна предсказать `\"learning\"`.\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "379a5c21",
|
||
"metadata": {},
|
||
"source": [
|
||
"### ✅ 5.1.1 Подготовка данных\n",
|
||
"\n",
|
||
"Создадим **датасет** на основе BPE-токенизатора:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "8141ac57",
|
||
"metadata": {},
|
||
"source": [
|
||
"**BPE Tokenizator**"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"id": "3fc41e15",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class BPE:\n",
|
||
" def __init__(self, vocab_size: int):\n",
|
||
" self.vocab_size = vocab_size\n",
|
||
" self.id2token = {}\n",
|
||
" self.token2id = {}\n",
|
||
"\n",
|
||
" def fit(self, text: str):\n",
|
||
" # 1. Получаем уникальные токены (символы)\n",
|
||
" unique_tokens = sorted(set(text))\n",
|
||
" tokens = unique_tokens.copy()\n",
|
||
"\n",
|
||
" # 2. Разбиваем текст на токены-символы\n",
|
||
" sequence = list(text)\n",
|
||
"\n",
|
||
" # 3. Объединяем токены до достижения нужного размера словаря\n",
|
||
" while len(tokens) < self.vocab_size:\n",
|
||
" #print(f'len={len(tokens)} < {self.vocab_size}')\n",
|
||
" # Считаем частоты пар\n",
|
||
" pair_freq = {}\n",
|
||
" for i in range(len(sequence) - 1):\n",
|
||
" pair = (sequence[i], sequence[i + 1])\n",
|
||
" #print(f'pair = {pair}')\n",
|
||
" if pair not in pair_freq:\n",
|
||
" pair_freq[pair] = 0\n",
|
||
" pair_freq[pair] += 1\n",
|
||
"\n",
|
||
"\n",
|
||
" #print(f'pair_freq = {pair_freq}') \n",
|
||
" if not pair_freq:\n",
|
||
" break # нет пар — выходим\n",
|
||
"\n",
|
||
" #for x in pair_freq.items():\n",
|
||
" # self.debug(x, sequence)\n",
|
||
"\n",
|
||
" # Находим самую частую пару (в случае равенства — та, что встретилась первой)\n",
|
||
" most_frequent_pair = max(pair_freq.items(), key=lambda x: (x[1], -self._pair_first_index(sequence, x[0])))[0]\n",
|
||
" #print(most_frequent_pair)\n",
|
||
" # Создаем новый токен\n",
|
||
" new_token = most_frequent_pair[0] + most_frequent_pair[1]\n",
|
||
" #print(f\"new token={new_token}\")\n",
|
||
" tokens.append(new_token)\n",
|
||
" #print(f\"tokens={tokens}\")\n",
|
||
"\n",
|
||
" i = 0\n",
|
||
" new_sequence = []\n",
|
||
"\n",
|
||
" while i < len(sequence):\n",
|
||
" if i < len(sequence) - 1 and (sequence[i], sequence[i + 1]) == most_frequent_pair:\n",
|
||
" new_sequence.append(new_token)\n",
|
||
" i += 2 # пропускаем два символа — заменённую пару\n",
|
||
" else:\n",
|
||
" new_sequence.append(sequence[i])\n",
|
||
" i += 1\n",
|
||
" sequence = new_sequence\n",
|
||
" #break\n",
|
||
" \n",
|
||
" # 4. Создаем словари\n",
|
||
" self.vocab = tokens.copy()\n",
|
||
" self.token2id = dict(zip(tokens, range(self.vocab_size)))\n",
|
||
" self.id2token = dict(zip(range(self.vocab_size), tokens))\n",
|
||
"\n",
|
||
" def _pair_first_index(self, sequence, pair):\n",
|
||
" for i in range(len(sequence) - 1):\n",
|
||
" if (sequence[i], sequence[i + 1]) == pair:\n",
|
||
" return i\n",
|
||
" return float('inf') # если пара не найдена (в теории не должно случиться)\n",
|
||
"\n",
|
||
"\n",
|
||
" def encode(self, text: str):\n",
|
||
" # 1. Разбиваем текст на токены-символы\n",
|
||
" sequence = list(text)\n",
|
||
" # 2. Инициализация пустого списка токенов\n",
|
||
" tokens = []\n",
|
||
" # 3. Установить i = 0\n",
|
||
" i = 0\n",
|
||
" while i < len(text):\n",
|
||
" # 3.1 Найти все токены в словаре, начинающиеся с text[i]\n",
|
||
" start_char = text[i]\n",
|
||
" result = [token for token in self.vocab if token.startswith(start_char)]\n",
|
||
" # 3.2 Выбрать самый длинный подходящий токен\n",
|
||
" find_token = self._find_max_matching_token(text[i:], result)\n",
|
||
" if find_token is None:\n",
|
||
" # Обработка неизвестного символа\n",
|
||
" tokens.append(text[i]) # Добавляем сам символ как токен\n",
|
||
" i += 1\n",
|
||
" else:\n",
|
||
" # 3.3 Добавить токен в результат\n",
|
||
" tokens.append(find_token)\n",
|
||
" # 3.4 Увеличить i на длину токена\n",
|
||
" i += len(find_token)\n",
|
||
"\n",
|
||
" # 4. Заменить токены на их ID\n",
|
||
" return self._tokens_to_ids(tokens)\n",
|
||
"\n",
|
||
" def _find_max_matching_token(self, text: str, tokens: list):\n",
|
||
" \"\"\"Находит самый длинный токен из списка, с которого начинается текст\"\"\"\n",
|
||
" matching = [token for token in tokens if text.startswith(token)]\n",
|
||
" return max(matching, key=len) if matching else None\n",
|
||
"\n",
|
||
" def _tokens_to_ids(self, tokens):\n",
|
||
" \"\"\"Конвертирует список токенов в их ID с обработкой неизвестных токенов\"\"\"\n",
|
||
" ids = []\n",
|
||
" for token in tokens:\n",
|
||
" if token in self.token2id:\n",
|
||
" ids.append(self.token2id[token])\n",
|
||
" else:\n",
|
||
" ids.append(0) # Специальное значение\n",
|
||
" return ids\n",
|
||
"\n",
|
||
"\n",
|
||
" def decode(self, ids: list) -> str:\n",
|
||
" return ''.join(self._ids_to_tokens(ids))\n",
|
||
"\n",
|
||
" def _ids_to_tokens(self, ids: list) -> list:\n",
|
||
" \"\"\"Конвертирует список Ids в их tokens\"\"\"\n",
|
||
" tokens = []\n",
|
||
" for id in ids:\n",
|
||
" if id in self.id2token:\n",
|
||
" tokens.append(self.id2token[id])\n",
|
||
" else:\n",
|
||
" tokens.append('') # Специальное значение\n",
|
||
" return tokens\n",
|
||
"\n",
|
||
"\n",
|
||
" def save(self, filename):\n",
|
||
" with open(filename, 'wb') as f:\n",
|
||
" dill.dump(self, f)\n",
|
||
" print(f\"Объект сохранён в {filename}\")\n",
|
||
"\n",
|
||
"\n",
|
||
" @classmethod\n",
|
||
" def load(cls, filename):\n",
|
||
" with open(filename, 'rb') as f:\n",
|
||
" obj = dill.load(f)\n",
|
||
" \n",
|
||
" print(f\"Объект загружен из {filename}\")\n",
|
||
" return obj"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"id": "dc7a3dbf",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import torch\n",
|
||
"from torch.utils.data import Dataset, DataLoader\n",
|
||
"\n",
|
||
"class GPTDataset(Dataset):\n",
|
||
" def __init__(self, text: str, bpe: BPE, block_size: int):\n",
|
||
" self.bpe = bpe\n",
|
||
" self.block_size = block_size\n",
|
||
" self.data = bpe.encode(text)\n",
|
||
" \n",
|
||
" def __len__(self):\n",
|
||
" return len(self.data) - self.block_size\n",
|
||
"\n",
|
||
" def __getitem__(self, idx):\n",
|
||
" x = torch.tensor(self.data[idx:idx+self.block_size], dtype=torch.long)\n",
|
||
" y = torch.tensor(self.data[idx+1:idx+self.block_size+1], dtype=torch.long)\n",
|
||
" return x, y"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "e8981229",
|
||
"metadata": {},
|
||
"source": [
|
||
"### ✅ 5.1.2 Цикл обучения\n",
|
||
"\n",
|
||
"Для обучения создадим функцию:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 15,
|
||
"id": "4e097434",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import torch.nn.functional as F\n",
|
||
"from torch import optim\n",
|
||
"\n",
|
||
"def train_mixtral(model, dataset, epochs=5, batch_size=32, lr=3e-4, device='cpu'):\n",
|
||
" dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
|
||
" optimizer = optim.AdamW(model.parameters(), lr=lr)\n",
|
||
"\n",
|
||
" model.to(device)\n",
|
||
" model.train()\n",
|
||
"\n",
|
||
" for epoch in range(epochs):\n",
|
||
" total_loss = 0\n",
|
||
" for x, y in dataloader:\n",
|
||
" x, y = x.to(device), y.to(device)\n",
|
||
"\n",
|
||
" # Прямой проход\n",
|
||
" logits, _ = model(x, use_cache=False) # [B, T, vocab_size]\n",
|
||
"\n",
|
||
" # Перестроим выход под CrossEntropy\n",
|
||
" loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))\n",
|
||
"\n",
|
||
" # Обратное распространение\n",
|
||
" optimizer.zero_grad()\n",
|
||
" loss.backward()\n",
|
||
" optimizer.step()\n",
|
||
"\n",
|
||
" total_loss += loss.item()\n",
|
||
"\n",
|
||
" avg_loss = total_loss / len(dataloader)\n",
|
||
" print(f\"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}\")\n",
|
||
"\n",
|
||
" return model"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "1f536de2",
|
||
"metadata": {},
|
||
"source": [
|
||
"### ✅ 5.1.3 Пример запуска\n",
|
||
"\n",
|
||
"\n",
|
||
"**🧠 Конфигурация Mistral Mini**\n",
|
||
"\n",
|
||
"\n",
|
||
"| Параметр | Значение | Описание |\n",
|
||
"| --------------- | -------- | --------------------------------------------- |\n",
|
||
"| **vocab_size** | `50257` | Размер словаря (BPE токенизатор OpenAI) |\n",
|
||
"| **max_seq_len** | `512` | Максимальная длина входной последовательности |\n",
|
||
"| **emb_size** | `256` | Размер эмбеддингов (векторное пространство) |\n",
|
||
"| **num_heads** | `4` | Количество голов в multi-head attention |\n",
|
||
"| **head_size** | `64` | Размерность одной головы внимания (768 / 12) |\n",
|
||
"| **num_layers** | `4` | Количество блоков (декодеров) |\n",
|
||
"| **dropout** | `0.1` | Вероятность дропаута |\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"id": "83ed4be7",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Dataset length: 20\n",
|
||
"Epoch 1/100, Loss: 3.5861\n",
|
||
"Epoch 2/100, Loss: 1.1913\n",
|
||
"Epoch 3/100, Loss: 0.4864\n",
|
||
"Epoch 4/100, Loss: 0.2490\n",
|
||
"Epoch 5/100, Loss: 0.1993\n",
|
||
"Epoch 6/100, Loss: 0.1350\n",
|
||
"Epoch 7/100, Loss: 0.1039\n",
|
||
"Epoch 8/100, Loss: 0.1051\n",
|
||
"Epoch 9/100, Loss: 0.0730\n",
|
||
"Epoch 10/100, Loss: 0.0754\n",
|
||
"Epoch 11/100, Loss: 0.0819\n",
|
||
"Epoch 12/100, Loss: 0.0664\n",
|
||
"Epoch 13/100, Loss: 0.0793\n",
|
||
"Epoch 14/100, Loss: 0.0668\n",
|
||
"Epoch 15/100, Loss: 0.0818\n",
|
||
"Epoch 16/100, Loss: 0.0734\n",
|
||
"Epoch 17/100, Loss: 0.0637\n",
|
||
"Epoch 18/100, Loss: 0.0584\n",
|
||
"Epoch 19/100, Loss: 0.0762\n",
|
||
"Epoch 20/100, Loss: 0.0683\n",
|
||
"Epoch 21/100, Loss: 0.0624\n",
|
||
"Epoch 22/100, Loss: 0.0557\n",
|
||
"Epoch 23/100, Loss: 0.0579\n",
|
||
"Epoch 24/100, Loss: 0.0558\n",
|
||
"Epoch 25/100, Loss: 0.0578\n",
|
||
"Epoch 26/100, Loss: 0.0520\n",
|
||
"Epoch 27/100, Loss: 0.0642\n",
|
||
"Epoch 28/100, Loss: 0.0660\n",
|
||
"Epoch 29/100, Loss: 0.0508\n",
|
||
"Epoch 30/100, Loss: 0.0542\n",
|
||
"Epoch 31/100, Loss: 0.0478\n",
|
||
"Epoch 32/100, Loss: 0.0597\n",
|
||
"Epoch 33/100, Loss: 0.0578\n",
|
||
"Epoch 34/100, Loss: 0.0557\n",
|
||
"Epoch 35/100, Loss: 0.0528\n",
|
||
"Epoch 36/100, Loss: 0.0521\n",
|
||
"Epoch 37/100, Loss: 0.0520\n",
|
||
"Epoch 38/100, Loss: 0.0512\n",
|
||
"Epoch 39/100, Loss: 0.0501\n",
|
||
"Epoch 40/100, Loss: 0.0497\n",
|
||
"Epoch 41/100, Loss: 0.0454\n",
|
||
"Epoch 42/100, Loss: 0.0519\n",
|
||
"Epoch 43/100, Loss: 0.0535\n",
|
||
"Epoch 44/100, Loss: 0.0464\n",
|
||
"Epoch 45/100, Loss: 0.0459\n",
|
||
"Epoch 46/100, Loss: 0.0437\n",
|
||
"Epoch 47/100, Loss: 0.0585\n",
|
||
"Epoch 48/100, Loss: 0.0469\n",
|
||
"Epoch 49/100, Loss: 0.0538\n",
|
||
"Epoch 50/100, Loss: 0.0592\n",
|
||
"Epoch 51/100, Loss: 0.0520\n",
|
||
"Epoch 52/100, Loss: 0.0582\n",
|
||
"Epoch 53/100, Loss: 0.0504\n",
|
||
"Epoch 54/100, Loss: 0.0471\n",
|
||
"Epoch 55/100, Loss: 0.0478\n",
|
||
"Epoch 56/100, Loss: 0.0487\n",
|
||
"Epoch 57/100, Loss: 0.0507\n",
|
||
"Epoch 58/100, Loss: 0.0500\n",
|
||
"Epoch 59/100, Loss: 0.0457\n",
|
||
"Epoch 60/100, Loss: 0.0493\n",
|
||
"Epoch 61/100, Loss: 0.0431\n",
|
||
"Epoch 62/100, Loss: 0.0503\n",
|
||
"Epoch 63/100, Loss: 0.0436\n",
|
||
"Epoch 64/100, Loss: 0.0512\n",
|
||
"Epoch 65/100, Loss: 0.0488\n",
|
||
"Epoch 66/100, Loss: 0.0436\n",
|
||
"Epoch 67/100, Loss: 0.0505\n",
|
||
"Epoch 68/100, Loss: 0.0389\n",
|
||
"Epoch 69/100, Loss: 0.0447\n",
|
||
"Epoch 70/100, Loss: 0.0442\n",
|
||
"Epoch 71/100, Loss: 0.0443\n",
|
||
"Epoch 72/100, Loss: 0.0460\n",
|
||
"Epoch 73/100, Loss: 0.0483\n",
|
||
"Epoch 74/100, Loss: 0.0538\n",
|
||
"Epoch 75/100, Loss: 0.0437\n",
|
||
"Epoch 76/100, Loss: 0.0500\n",
|
||
"Epoch 77/100, Loss: 0.0443\n",
|
||
"Epoch 78/100, Loss: 0.0461\n",
|
||
"Epoch 79/100, Loss: 0.0467\n",
|
||
"Epoch 80/100, Loss: 0.0472\n",
|
||
"Epoch 81/100, Loss: 0.0507\n",
|
||
"Epoch 82/100, Loss: 0.0454\n",
|
||
"Epoch 83/100, Loss: 0.0414\n",
|
||
"Epoch 84/100, Loss: 0.0468\n",
|
||
"Epoch 85/100, Loss: 0.0525\n",
|
||
"Epoch 86/100, Loss: 0.0456\n",
|
||
"Epoch 87/100, Loss: 0.0512\n",
|
||
"Epoch 88/100, Loss: 0.0487\n",
|
||
"Epoch 89/100, Loss: 0.0452\n",
|
||
"Epoch 90/100, Loss: 0.0436\n",
|
||
"Epoch 91/100, Loss: 0.0488\n",
|
||
"Epoch 92/100, Loss: 0.0471\n",
|
||
"Epoch 93/100, Loss: 0.0442\n",
|
||
"Epoch 94/100, Loss: 0.0435\n",
|
||
"Epoch 95/100, Loss: 0.0502\n",
|
||
"Epoch 96/100, Loss: 0.0430\n",
|
||
"Epoch 97/100, Loss: 0.0430\n",
|
||
"Epoch 98/100, Loss: 0.0424\n",
|
||
"Epoch 99/100, Loss: 0.0421\n",
|
||
"Epoch 100/100, Loss: 0.0410\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"Mixtral(\n",
|
||
" (_token_embeddings): TokenEmbeddings(\n",
|
||
" (_embedding): Embedding(100, 256)\n",
|
||
" )\n",
|
||
" (_position_embeddings): RoPE()\n",
|
||
" (_dropout): Dropout(p=0.1, inplace=False)\n",
|
||
" (_decoders): ModuleList(\n",
|
||
" (0-3): 4 x Decoder(\n",
|
||
" (_heads): GroupedQueryAttention(\n",
|
||
" (_rope): RoPE()\n",
|
||
" (_q): Linear(in_features=256, out_features=256, bias=True)\n",
|
||
" (_k): Linear(in_features=256, out_features=128, bias=True)\n",
|
||
" (_v): Linear(in_features=256, out_features=128, bias=True)\n",
|
||
" (_layer): Linear(in_features=256, out_features=256, bias=True)\n",
|
||
" (_dropout): Dropout(p=0.1, inplace=False)\n",
|
||
" )\n",
|
||
" (_ff): MoE(\n",
|
||
" (_router): Linear(in_features=256, out_features=8, bias=True)\n",
|
||
" (_experts): ModuleList(\n",
|
||
" (0-7): 8 x SwiGLU(\n",
|
||
" (_gate): Linear(in_features=256, out_features=1024, bias=True)\n",
|
||
" (_up): Linear(in_features=256, out_features=1024, bias=True)\n",
|
||
" (_down): Linear(in_features=1024, out_features=256, bias=True)\n",
|
||
" (_activation): SiLU()\n",
|
||
" (_dropout): Dropout(p=0.1, inplace=False)\n",
|
||
" )\n",
|
||
" )\n",
|
||
" (_dropout): Dropout(p=0.1, inplace=False)\n",
|
||
" )\n",
|
||
" (_norm1): RMSNorm()\n",
|
||
" (_norm2): RMSNorm()\n",
|
||
" )\n",
|
||
" )\n",
|
||
" (_norm): RMSNorm()\n",
|
||
" (_linear): Linear(in_features=256, out_features=100, bias=True)\n",
|
||
")"
|
||
]
|
||
},
|
||
"execution_count": 14,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# 1. Исходный текст\n",
|
||
"text = \"Deep learning is amazing. Transformers changed the world. Attention is all you need. GPT models revolutionized NLP.\"\n",
|
||
"\n",
|
||
"# 2. Обучаем токенизатор\n",
|
||
"bpe = BPE(vocab_size=100)\n",
|
||
"bpe.fit(text)\n",
|
||
"\n",
|
||
"# 3. Создаем датасет\n",
|
||
"dataset = GPTDataset(text, bpe, block_size=8)\n",
|
||
"print(f\"Dataset length: {len(dataset)}\")\n",
|
||
"\n",
|
||
"# 4. Инициализируем модель\n",
|
||
"model = Mixtral(\n",
|
||
" vocab_size=len(bpe.vocab), # размер словаря BPE\n",
|
||
" max_seq_len=512, # GPT-2 использует контекст в 512 токена\n",
|
||
" emb_size=256, # размер эмбеддингов\n",
|
||
" num_q_heads=4, # количество голов внимания\n",
|
||
" num_kv_heads=2, # количество голов внимания\n",
|
||
" head_size=64, # размер каждой головы (256 / 4)\n",
|
||
" num_layers=4, # количество блоков Transformer\n",
|
||
" num_experts=8,\n",
|
||
" top_k_experts=2,\n",
|
||
" window_size=8,\n",
|
||
" dropout=0.1 # стандартный dropout GPT-2\n",
|
||
")\n",
|
||
"\n",
|
||
"# 5. Обучаем\n",
|
||
"train_mixtral(model, dataset, epochs=100, batch_size=4)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "72e91ceb",
|
||
"metadata": {},
|
||
"source": [
|
||
"\n",
|
||
"---\n",
|
||
"\n",
|
||
"### 5.2 Дообучение\n",
|
||
"\n",
|
||
"После предобучения LLAMA уже знает структуру и грамматику языка. \n",
|
||
"На втором этапе она дообучается на конкретных задачах (например, классификация, QA) с помощью размеченных данных.\n",
|
||
"\n",
|
||
"Технически это почти то же обучение, только:\n",
|
||
"\n",
|
||
"- Загружаем модель с уже обученными весами.\n",
|
||
"- Используем новые данные.\n",
|
||
"- Можно уменьшить скорость обучения.\n",
|
||
"- Иногда замораживают часть слоёв (например, эмбеддинги).\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 16,
|
||
"id": "0a6f9730",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def fine_tune_mixtral(model, dataset, epochs=3, batch_size=16, lr=1e-5, device='cpu', freeze_embeddings=True):\n",
|
||
" if freeze_embeddings:\n",
|
||
" for param in model._token_embeddings.parameters():\n",
|
||
" param.requires_grad = False\n",
|
||
" for param in model._position_embeddings.parameters():\n",
|
||
" param.requires_grad = False\n",
|
||
"\n",
|
||
" dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
|
||
" optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)\n",
|
||
"\n",
|
||
" model.to(device)\n",
|
||
" model.train()\n",
|
||
"\n",
|
||
" for epoch in range(epochs):\n",
|
||
" total_loss = 0\n",
|
||
" for x, y in dataloader:\n",
|
||
" x, y = x.to(device), y.to(device)\n",
|
||
" logits, _ = model(x, use_cache=False)\n",
|
||
" loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))\n",
|
||
" optimizer.zero_grad()\n",
|
||
" loss.backward()\n",
|
||
" optimizer.step()\n",
|
||
" total_loss += loss.item()\n",
|
||
" print(f\"Fine-tune Epoch {epoch+1}/{epochs}, Loss: {total_loss / len(dataloader):.4f}\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 17,
|
||
"id": "c9e28f47",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Fine-tune Epoch 1/10, Loss: 4.8596\n",
|
||
"Fine-tune Epoch 2/10, Loss: 2.6883\n",
|
||
"Fine-tune Epoch 3/10, Loss: 1.5315\n",
|
||
"Fine-tune Epoch 4/10, Loss: 1.1258\n",
|
||
"Fine-tune Epoch 5/10, Loss: 0.9248\n",
|
||
"Fine-tune Epoch 6/10, Loss: 0.7725\n",
|
||
"Fine-tune Epoch 7/10, Loss: 0.6405\n",
|
||
"Fine-tune Epoch 8/10, Loss: 0.5367\n",
|
||
"Fine-tune Epoch 9/10, Loss: 0.4469\n",
|
||
"Fine-tune Epoch 10/10, Loss: 0.4168\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Например, мы хотим дообучить модель на стиле коротких технических фраз\n",
|
||
"fine_tune_text = \"\"\"\n",
|
||
"Transformers revolutionize NLP.\n",
|
||
"Deep learning enables self-attention.\n",
|
||
"GPT generates text autoregressively.\n",
|
||
"\"\"\"\n",
|
||
"\n",
|
||
"dataset = GPTDataset(fine_tune_text, bpe, block_size=8)\n",
|
||
"\n",
|
||
"\n",
|
||
"# Запуск дообучения\n",
|
||
"fine_tune_mixtral(model, dataset, epochs=10, batch_size=4, lr=1e-4)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "fd470fdd",
|
||
"metadata": {},
|
||
"source": [
|
||
"## 📝 6. Генерация текста после обучения"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 18,
|
||
"id": "245dd064",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def generate_text(model, bpe, prompt: str, max_new_tokens=20, device='cpu'):\n",
|
||
" model.eval()\n",
|
||
" ids = torch.tensor([bpe.encode(prompt)], dtype=torch.long).to(device)\n",
|
||
" out = model.generate(ids, max_new_tokens=max_new_tokens, do_sample=True)\n",
|
||
" text = bpe.decode(out[0].tolist())\n",
|
||
" return text"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 19,
|
||
"id": "1f5db85f",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Deep learning ena les self atelf a\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"print(generate_text(model, bpe, \"Deep learning\", max_new_tokens=20))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "094ab394",
|
||
"metadata": {},
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": ".venv",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.10.9"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|