llm-arch-research/notebooks/gemma.ipynb
Sergey Penkovsky cfb4b6dfb1 feat(gemma): initial implementation of Gemma model and configs
- Add core Gemma model (architecture, attention, GeGLU, RoPE, RMSNorm, etc)
- Add configs for training and generation: gemma_train.json, gemma_generate.json
- Add Gemma notebook for exploratory analysis and demonstration
- Add __init__.py for Gemma submodule
- Update run_llm_experiment.py to support Gemma experiment configs

test(gemma): add comprehensive unit tests for Gemma

- Test forward pass (with/without cache)
- Test autoregressive generation (greedy, top-k, top-p)
- Test shape correctness and max sequence length errors
- Test multi-layer stack and token embeddings

docs: add documentation notebook for Gemma usage and analysis

Closes: #issue (if applicable)
2025-10-21 01:02:15 +03:00

{
"cells": [
{
"cell_type": "markdown",
"id": "1636810a",
"metadata": {},
"source": [
"# Gemma\n",
"\n",
"<p align=\"center\">\n",
" <img src=\"https://ucarecdn.com/7b36315f-fc63-4f57-8344-12f1a993376f/\" alt=\"arch\" width=\"1000\" height=\"165\">\n",
"</p>\n",
"\n",
"Gemma 1 вышла в феврале 2024 года.\n",
"\n",
"По архитектуре модель больше всего похожа на Llama'у. Содержит уже знакомые нам RoPE и RMSNorm. Но есть и новинки:\n",
"\n",
"* **Multi-Query Attention (MQA)** — крайне экономный вариант механизма внимания.\n",
"* **GeGLU** — гибридная функция активации. Почти клон SwiGLU :)\n",
"\n",
"Обе довольно легкие для внедрения, по сравнению с прошлыми новинками :)\n"
]
},
{
"cell_type": "markdown",
"id": "cea30169",
"metadata": {},
"source": [
"# Multi-Query Attention\n",
"\n",
"По своей сути, Multi-Query Attention (MQA) — это частный случай Grouped Query Attention (GQA), который мы реализовали в уроке про Mistral.\n",
"\n",
"<p align=\"center\">\n",
" <img src=\"https://ucarecdn.com/d1b0b294-9c02-4de1-9074-ae8ebfe96f6e/\" alt=\"mqa\" width=\"1000\" height=\"217\">\n",
"</p>\n",
"\n",
"В GQA на каждую голову приходится один вектор запроса (query). При этом каждый вектор ключа (key) и значения (value) обслуживает ( n )-голов.\n",
"Так вот, в MQA на все головы (в одном блоке декодера) приходится всего по одному вектору ключа (key) и одному вектору значения (value). Это такая радикальная форма экономии :)\n"
]
},
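{
"cell_type": "markdown",
"id": "a1f3c9d2",
"metadata": {},
"source": [
"Before implementing it, here is a small back-of-the-envelope sketch (an illustrative addition, using the toy sizes from this notebook rather than real Gemma dimensions) of how much MQA saves on the K/V projection weights and on the KV cache compared to classic multi-head attention, where every query head has its own key and value."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2e4d0c3",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative comparison of MQA vs classic multi-head attention (MHA).\n",
"# Toy sizes, matching the ones used later in this notebook.\n",
"num_q_heads = 8   # query heads per block\n",
"head_size = 64\n",
"emb_size = 512\n",
"seq_len = 512\n",
"batch_size = 1\n",
"\n",
"# K/V projection weights: emb_size -> num_heads*head_size in MHA, emb_size -> head_size in MQA\n",
"mha_kv_params = 2 * emb_size * num_q_heads * head_size\n",
"mqa_kv_params = 2 * emb_size * head_size\n",
"\n",
"# KV cache per decoder block: K and V of shape [batch, kv_heads, seq_len, head_size]\n",
"mha_kv_cache = 2 * batch_size * num_q_heads * seq_len * head_size\n",
"mqa_kv_cache = 2 * batch_size * 1 * seq_len * head_size\n",
"\n",
"print(f\"KV projection params: MHA={mha_kv_params:,} MQA={mqa_kv_params:,} ({mha_kv_params // mqa_kv_params}x smaller)\")\n",
"print(f\"KV cache elements:    MHA={mha_kv_cache:,} MQA={mqa_kv_cache:,} ({mha_kv_cache // mqa_kv_cache}x smaller)\")"
]
},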
{
"cell_type": "markdown",
"id": "539550fe",
"metadata": {},
"source": [
"**Multi-Query Attention (разработка)**"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "6be61c63",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from typing import Optional\n",
"\n",
"\n",
"class RoPE(nn.Module):\n",
"\n",
" def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):\n",
" super().__init__()\n",
" assert head_size % 2 == 0, \"head_size должен быть четным\"\n",
"\n",
" # Вычисление частот: θ_i = base^(-2i/d) для i ∈ [0, d/2-1]\n",
" freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size))\n",
"\n",
" # Позиции от 0 до max_seq_len-1\n",
" positions = torch.arange(max_seq_len).float()\n",
"\n",
" # Внешнее произведение: m * θ_i для всех позиций и частот\n",
" freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0)\n",
"\n",
" # Предвычисление матриц косинусов и синусов\n",
" self.register_buffer(\"cos_matrix\", torch.cos(freq_matrix))\n",
" self.register_buffer(\"sin_matrix\", torch.sin(freq_matrix))\n",
"\n",
" def forward(self, x: torch.Tensor) -> torch.Tensor: # [batch_size × seq_len × head_size] [batch_size × num_heads × seq_len × head_size]\n",
" batch_size, num_heads, seq_len, head_size = x.shape\n",
"\n",
" # Берем нужную часть матриц и приводим к типу x\n",
" cos = self.cos_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]\n",
" sin = self.sin_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]\n",
"\n",
" # Явное изменение формы для broadcasting\n",
" cos = cos.reshape(1, 1, seq_len, head_size // 2)\n",
" sin = sin.reshape(1, 1, seq_len, head_size // 2)\n",
"\n",
" # Разделяем на четные и нечетные компоненты по ПОСЛЕДНЕМУ измерению\n",
" x_even = x[..., 0::2] # [batch_size, num_heads, seq_len, head_size//2]\n",
" x_odd = x[..., 1::2] # [batch_size, num_heads, seq_len, head_size//2]\n",
"\n",
" # Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ)\n",
" x_rotated_even = x_even * cos - x_odd * sin\n",
" x_rotated_odd = x_even * sin + x_odd * cos\n",
"\n",
" # Объединяем обратно в исходную размерность\n",
" x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)\n",
" x_rotated = x_rotated.flatten(-2) # [batch_size, seq_len, head_size]\n",
"\n",
" return x_rotated"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "811921b1",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"\n",
"class MultiQueryAttention(nn.Module):\n",
" def __init__(\n",
" self,\n",
" num_q_heads: int,\n",
" emb_size: int,\n",
" head_size: int,\n",
" max_seq_len: int,\n",
" rope: RoPE = None,\n",
" dropout: float = 0.1,\n",
" ):\n",
" super().__init__()\n",
" self._num_q_heads = num_q_heads\n",
" self._head_size = head_size\n",
" self._max_seq_len = max_seq_len\n",
" self._rope = rope\n",
" \n",
" self._q = nn.Linear(emb_size, num_q_heads * head_size)\n",
" self._k = nn.Linear(emb_size, head_size)\n",
" self._v = nn.Linear(emb_size, head_size)\n",
"\n",
" # Создание causal маски\n",
" mask = torch.tril(torch.ones(max_seq_len, max_seq_len))\n",
" self.register_buffer(\n",
" \"_tril_mask\", mask.bool() if hasattr(torch, \"bool\") else mask.byte()\n",
" )\n",
" \n",
" self._layer = nn.Linear(num_q_heads * head_size, emb_size)\n",
" self._dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(\n",
" self,\n",
" x: torch.Tensor,\n",
" mask: torch.Tensor = None,\n",
" use_cache: bool = True,\n",
" cache: list = None,\n",
" ):\n",
" batch_size, seq_len, emb_size = x.shape\n",
" if seq_len > self._max_seq_len:\n",
" raise ValueError(\n",
" f\"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}\"\n",
" )\n",
"\n",
" # Пропустите тензор x через матрицы Wq, Wk , Wv, чтобы получить матрицы запроса, ключа и значения.\n",
" k = self._k(x) # [B, T, hs]\n",
" q = self._q(x) # [B, T, hs]\n",
" v = self._v(x) # [B, T, hs]\n",
"\n",
" # Шаг 2: Изменение формы для multi-head\n",
" # [batch_size, seq_len, num_heads * head_size] \n",
" # -> [batch_size, seq_len, num_heads, head_size]\n",
" q = q.reshape(batch_size, seq_len, self._num_q_heads, self._head_size)\n",
" k = k.reshape(batch_size, seq_len, 1, self._head_size)\n",
" v = v.reshape(batch_size, seq_len, 1, self._head_size)\n",
"\n",
" # 3. Transpose: [B, T, H, hs] -> [B, H, T, hs]\n",
" q = q.transpose(1, 2)\n",
" k = k.transpose(1, 2)\n",
" v = v.transpose(1, 2)\n",
"\n",
" # Пропустите матрицы запроса и ключа через экземпляр rope, чтобы выполнить поворот.\n",
" if self._rope is not None:\n",
" # Применяем RoPE к Q и K (НЕ к V!)\n",
" q = self._rope(q) # [B, T, hs]\n",
" k = self._rope(k) # [B, T, hs]\n",
"\n",
"\n",
" # Если cache пришел, то объединяем кэш и одну строку из ключа и значения. Это будут новые key и value для последующих вычислений.\n",
" # 5. Кэширование (для autoregressive generation)\n",
" if cache is not None:\n",
" k_cache, v_cache = cache\n",
" k = torch.cat([k_cache, k], dim=2) # Concat по seq_len (dim=2)\n",
" v = torch.cat([v_cache, v], dim=2)\n",
"\n",
"\n",
" # Перемножим матрицы запроса и ключа (транспонированную), чтобы вычислить матрицу внимания.\n",
" # И разделить все значения в матрице внимания на корень из head_size.\n",
" scores = q @ k.transpose(-2, -1) / (self._head_size ** 0.5)\n",
"\n",
" # Если cache пришел, то маску не накладываем. Иначе наложите на матрицу внимания треугольную маску, созданную при инициализации. Все скрытые значения должны быть приведены к минус бесконечности: float('-inf').\n",
" if cache is None:\n",
" scores = scores.masked_fill(\n",
" ~self._tril_mask[:seq_len, :seq_len], float(\"-inf\")\n",
" )\n",
"\n",
" # Применить к матрице внимания (построчно) функцию Softmax.\n",
" weights = F.softmax(scores, dim=-1)\n",
"\n",
" # Перемножим матрицу внимания и матрицу значения.\n",
" x_out = weights @ v # [B, T, hs]\n",
"\n",
"\n",
" # Измените форму тензора на batch_size × seq_len × num_heads*head_size.\n",
" # Transpose обратно и concatenate heads\n",
" x_out = x_out.transpose(1, 2) # [B, T_q, H, hs]\n",
" x_out = x_out.contiguous() # Важно для reshape!\n",
" concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_q_heads * self._head_size)\n",
"\n",
"\n",
" # Пропустите получившийся тензор через последний линейный слой.\n",
" # 3. Проецируем в пространство эмбеддингов\n",
" projected_output = self._layer(concatenated_attention)\n",
"\n",
"\n",
" # 4. Применяем dropout для регуляризации\n",
" final_output = self._dropout(projected_output)\n",
"\n",
" if use_cache is True:\n",
" return (final_output, (k, v))\n",
" else:\n",
" return (final_output, None)\n"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "97771d9a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ Test 1 - Output shape: torch.Size([2, 10, 512])\n",
"✅ Test 2 - First output shape: torch.Size([2, 5, 512])\n",
"✅ Test 2 - Second output shape: torch.Size([2, 1, 512])\n",
"\n",
"✅ Все тесты пройдены!\n"
]
}
],
"source": [
"\n",
"# Параметры\n",
"batch_size = 2\n",
"seq_len = 10\n",
"emb_size = 512\n",
"num_q_heads = 8\n",
"head_size = 64\n",
"max_seq_len = 512\n",
"\n",
" # Создание модели\n",
"rope = RoPE(head_size=head_size, max_seq_len=max_seq_len)\n",
"mha = MultiQueryAttention(\n",
" num_q_heads=num_q_heads,\n",
" emb_size=emb_size,\n",
" head_size=head_size,\n",
" max_seq_len=max_seq_len,\n",
" rope=rope,\n",
" dropout=0.1,\n",
")\n",
"\n",
" # Тест 1: Обычный forward pass\n",
"x = torch.randn(batch_size, seq_len, emb_size)\n",
"output, cache = mha(x, use_cache=False)\n",
"print(f\"✅ Test 1 - Output shape: {output.shape}\") # [2, 10, 512]\n",
"assert output.shape == (batch_size, seq_len, emb_size)\n",
"\n",
" # Тест 2: С кэшированием\n",
"x1 = torch.randn(batch_size, 5, emb_size)\n",
"output1, cache1 = mha(x1, use_cache=True)\n",
"print(f\"✅ Test 2 - First output shape: {output1.shape}\") # [2, 5, 512]\n",
"\n",
"x2 = torch.randn(batch_size, 1, emb_size)\n",
"output2, cache2 = mha(x2, use_cache=True, cache=cache1)\n",
"print(f\"✅ Test 2 - Second output shape: {output2.shape}\") # [2, 1, 512]\n",
"\n",
"print(\"\\n✅ Все тесты пройдены!\")"
]
},
{
"cell_type": "markdown",
"id": "b5875022",
"metadata": {},
"source": [
"Вот конвертированный Markdown для твоего HTML:\n",
"\n",
"---\n",
"\n",
"# GeGLU\n",
"\n",
"<p align=\"center\">\n",
" <img src=\"https://ucarecdn.com/292c5458-2544-4e1f-a3f1-a981faeab39d/\" alt=\"geglu\" width=\"1000\" height=\"586\">\n",
"</p>\n",
"\n",
"GeGLU — это гибридная функция активации.\n",
"По сути, это та же **SwiGLU**, которую мы реализовали в **Llama**, просто у неё в качестве базовой функции вместо **SiLU** (как в Llama) используется **GELU** (как в GPT-2).\n"
]
},
{
"cell_type": "markdown",
"id": "a3345832",
"metadata": {},
"source": [
"**GeGLU (разработка)**"
]
},
{
"cell_type": "code",
"execution_count": 65,
"id": "82f52110",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"class GELU(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)\n",
" \n",
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
" return 0.5 * x * (1 + torch.tanh(\n",
" self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))\n",
" ))\n",
"\n",
"class GeGLU(nn.Module):\n",
" def __init__(self, emb_size: int, dropout: float = 0.1):\n",
" super().__init__()\n",
"\n",
" self._gate = nn.Linear(emb_size, 4 * emb_size)\n",
" self._up = nn.Linear(emb_size, 4 * emb_size)\n",
" self._down = nn.Linear(4 * emb_size, emb_size)\n",
" self._activation = GELU()\n",
" self._dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size].\n",
" gate_out = self._gate(x) # [batch, seq, 4*emb]\n",
" activation_out = self._activation(gate_out) # [batch, seq, 4*emb]\n",
" up_out = self._up(x) # [batch, seq, 4*emb]\n",
" out = up_out * activation_out # поэлементное!\n",
" out = self._down(out) # [batch, seq, emb]\n",
" return self._dropout(out)\n"
]
},
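{
"cell_type": "markdown",
"id": "c3d5a1b4",
"metadata": {},
"source": [
"A quick sanity check (an illustrative addition, not part of the original implementation): run the GeGLU block on a random tensor, confirm the output shape, and compare the hand-written tanh-based GELU against `torch.nn.functional.gelu` with `approximate=\"tanh\"`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4c6b2a5",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"# Shape check: GeGLU keeps [batch, seq_len, emb_size]\n",
"x = torch.randn(2, 10, 512)\n",
"ff = GeGLU(emb_size=512, dropout=0.0)  # dropout=0 to make the check deterministic\n",
"out = ff(x)\n",
"assert out.shape == x.shape\n",
"print(\"GeGLU output shape:\", out.shape)\n",
"\n",
"# The tanh approximation above should closely match PyTorch's tanh-approximate GELU\n",
"z = torch.randn(1000)\n",
"diff = (GELU()(z) - F.gelu(z, approximate=\"tanh\")).abs().max()\n",
"print(f\"max |GELU - F.gelu(tanh)| = {diff.item():.2e}\")"
]
},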
{
"cell_type": "markdown",
"id": "db378855",
"metadata": {},
"source": [
"# Full Model"
]
},
{
"cell_type": "code",
"execution_count": 66,
"id": "568437e8",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch import Tensor\n",
"import torch.nn.functional as F\n",
"from math import sqrt\n",
"\n",
"\n",
" \n",
"class RMSNorm(nn.Module):\n",
" def __init__(self, dim: int, eps: float = 1e-6):\n",
" super().__init__()\n",
" self._eps = eps\n",
" self._w = nn.Parameter(torch.ones(dim))\n",
" \n",
" def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]\n",
" rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5\n",
" norm_x = x / rms\n",
" return self._w * norm_x\n",
"\n",
"class TokenEmbeddings(nn.Module):\n",
" def __init__(self, vocab_size: int, emb_size: int):\n",
" super().__init__()\n",
" self._embedding = nn.Embedding(\n",
" num_embeddings=vocab_size,\n",
" embedding_dim=emb_size\n",
" )\n",
"\n",
" def forward(self, x: Tensor) -> Tensor:\n",
" return self._embedding(x)\n",
"\n",
" @property\n",
" def num_embeddings(self) -> int:\n",
" return self._embedding.num_embeddings\n",
"\n",
" @property\n",
" def embedding_dim(self) -> int:\n",
" return self._embedding.embedding_dim\n",
"\n",
"\n",
"\n",
"import torch\n",
"from torch import nn\n",
"from typing import Optional\n",
"\n",
"\n",
"class RoPE(nn.Module):\n",
"\n",
" def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):\n",
" super().__init__()\n",
" assert head_size % 2 == 0, \"head_size должен быть четным\"\n",
"\n",
" # Вычисление частот: θ_i = base^(-2i/d) для i ∈ [0, d/2-1]\n",
" freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size))\n",
"\n",
" # Позиции от 0 до max_seq_len-1\n",
" positions = torch.arange(max_seq_len).float()\n",
"\n",
" # Внешнее произведение: m * θ_i для всех позиций и частот\n",
" freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0)\n",
"\n",
" # Предвычисление матриц косинусов и синусов\n",
" self.register_buffer(\"cos_matrix\", torch.cos(freq_matrix))\n",
" self.register_buffer(\"sin_matrix\", torch.sin(freq_matrix))\n",
"\n",
" def forward(self, x: torch.Tensor) -> torch.Tensor: # [batch_size × seq_len × head_size] [batch_size × num_heads × seq_len × head_size]\n",
" batch_size, num_heads, seq_len, head_size = x.shape\n",
"\n",
" # Берем нужную часть матриц и приводим к типу x\n",
" cos = self.cos_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]\n",
" sin = self.sin_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]\n",
"\n",
" # Явное изменение формы для broadcasting\n",
" cos = cos.reshape(1, 1, seq_len, head_size // 2)\n",
" sin = sin.reshape(1, 1, seq_len, head_size // 2)\n",
"\n",
" # Разделяем на четные и нечетные компоненты по ПОСЛЕДНЕМУ измерению\n",
" x_even = x[..., 0::2] # [batch_size, num_heads, seq_len, head_size//2]\n",
" x_odd = x[..., 1::2] # [batch_size, num_heads, seq_len, head_size//2]\n",
"\n",
" # Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ)\n",
" x_rotated_even = x_even * cos - x_odd * sin\n",
" x_rotated_odd = x_even * sin + x_odd * cos\n",
"\n",
" # Объединяем обратно в исходную размерность\n",
" x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)\n",
" x_rotated = x_rotated.flatten(-2) # [batch_size, seq_len, head_size]\n",
"\n",
" return x_rotated\n",
"\n",
"\n",
"class Decoder(nn.Module):\n",
" def __init__(self, \n",
" num_q_heads: int,\n",
" emb_size: int,\n",
" head_size: int,\n",
" max_seq_len: int,\n",
" rope: RoPE,\n",
" dropout: float = 0.1\n",
" ):\n",
" super().__init__()\n",
" self._heads = MultiQueryAttention(\n",
" num_q_heads=num_q_heads, \n",
" emb_size=emb_size, \n",
" head_size=head_size, \n",
" max_seq_len=max_seq_len,\n",
" rope=rope,\n",
" dropout=dropout\n",
" )\n",
" self._ff = GeGLU(emb_size=emb_size, dropout=dropout)\n",
" self._norm1 = RMSNorm(emb_size)\n",
" self._norm2 = RMSNorm(emb_size)\n",
"\n",
" def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:\n",
" norm1_out = self._norm1(x)\n",
" attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)\n",
" out = attention + x\n",
" \n",
" norm2_out = self._norm2(out)\n",
" ffn_out = self._ff(norm2_out)\n",
"\n",
" if use_cache is True:\n",
" return (ffn_out + out, kv_caches)\n",
" else:\n",
" return (ffn_out + out, None)\n",
"\n",
"\n",
"\n",
"from torch import nn\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"class Gemma(nn.Module):\n",
" def __init__(self,\n",
" vocab_size: int,\n",
" max_seq_len: int,\n",
" emb_size: int,\n",
" num_q_heads: int,\n",
" head_size: int,\n",
" num_layers: int,\n",
" dropout: float = 0.1,\n",
" device: str = 'cpu'\n",
" ):\n",
" super().__init__()\n",
" self._vocab_size = vocab_size\n",
" self._max_seq_len = max_seq_len\n",
" self._emb_size = emb_size\n",
" self._num_q_heads = num_q_heads\n",
" self._head_size = head_size\n",
" self._num_layers = num_layers\n",
" self._dropout = dropout\n",
" self._device = device\n",
" \n",
" self.validation_loss = None\n",
"\n",
" # Инициализация слоев\n",
" self._token_embeddings = TokenEmbeddings(\n",
" vocab_size=vocab_size, \n",
" emb_size=emb_size\n",
" )\n",
" self._position_embeddings = RoPE(\n",
" head_size=head_size,\n",
" max_seq_len=max_seq_len\n",
" )\n",
" #self._position_embeddings = PositionalEmbeddings(\n",
" # max_seq_len=max_seq_len, \n",
" # emb_size=emb_size\n",
" #)\n",
" self._dropout = nn.Dropout(dropout)\n",
" self._decoders = nn.ModuleList([Decoder(\n",
" num_q_heads=num_q_heads,\n",
" emb_size=emb_size,\n",
" head_size=head_size,\n",
" max_seq_len=max_seq_len,\n",
" rope=self._position_embeddings,\n",
" dropout=dropout \n",
" ) for _ in range(num_layers)])\n",
" self._norm = RMSNorm(emb_size)\n",
" self._linear = nn.Linear(emb_size, vocab_size)\n",
"\n",
" def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:\n",
" # Проверка длины последовательности (только при отсутствии кэша)\n",
" if cache is None and x.size(1) > self._max_seq_len:\n",
" raise ValueError(f\"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}\")\n",
" \n",
" # Эмбеддинги токенов и позиций\n",
" tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]\n",
" #pos_out = self._position_embeddings(x) # [batch, seq_len, emb_size]\n",
" \n",
" # Комбинирование\n",
" out = self._dropout(tok_out) # [batch, seq_len, emb_size]\n",
" \n",
" # Стек декодеров с передачей кэша\n",
" new_cache = []\n",
" for i, decoder in enumerate(self._decoders):\n",
" decoder_cache = cache[i] if cache is not None else None\n",
" decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)\n",
"\n",
" # Извлекаем результат из кортежа\n",
" if use_cache:\n",
" out, decoder_new_cache = decoder_result\n",
" new_cache.append(decoder_new_cache)\n",
" else:\n",
" out = decoder_result[0]\n",
"\n",
" out = self._norm(out)\n",
" logits = self._linear(out)\n",
" \n",
" # Возвращаем результат с учетом use_cache\n",
" if use_cache:\n",
" return (logits, new_cache)\n",
" else:\n",
" return (logits, None)\n",
"\n",
" def generate(self,\n",
" x: torch.Tensor, \n",
" max_new_tokens: int, \n",
" do_sample: bool,\n",
" temperature: float = 1.0,\n",
" top_k: int = None,\n",
" top_p: float = None,\n",
" use_cache: bool = True\n",
" ) -> torch.Tensor:\n",
" cache = None\n",
"\n",
" for _ in range(max_new_tokens):\n",
" if use_cache and cache is not None:\n",
" # Используем кэш - передаем только последний токен\n",
" x_input = x[:, -1:] # [batch_size, 1]\n",
" else:\n",
" # Первая итерация или кэш отключен - передаем всю последовательность\n",
" x_input = x\n",
" \n",
" # Прямой проход с кэшем\n",
" logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)\n",
" \n",
" # Обновляем кэш для следующей итерации\n",
" if use_cache:\n",
" cache = new_cache\n",
"\n",
" last_logits = logits[:, -1, :] # [batch_size, vocab_size]\n",
"\n",
" # Масштабируем логиты температурой\n",
" if temperature > 0:\n",
" logits_scaled = last_logits / temperature\n",
" else:\n",
" logits_scaled = last_logits\n",
"\n",
" if do_sample == True and top_k != None:\n",
" _, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)\n",
"\n",
" # # Заменим все НЕ top-k логиты на -inf\n",
" masked_logits = logits_scaled.clone()\n",
" vocab_size = logits_scaled.size(-1)\n",
"\n",
" # создаём маску: 1, если токен НЕ в topk_indices\n",
" mask = torch.ones_like(logits_scaled, dtype=torch.uint8)\n",
" mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы\n",
" masked_logits[mask.byte()] = float('-inf')\n",
"\n",
" logits_scaled = masked_logits\n",
"\n",
" if do_sample == True and top_p != None:\n",
" # 1. Применим softmax, чтобы получить вероятности:\n",
" probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]\n",
" # 2. Отсортируем токены по убыванию вероятностей:\n",
" sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)\n",
" # 3. Посчитаем кумулятивную сумму вероятностей:\n",
" cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]\n",
" # 4. Определим маску: оставить токены, пока сумма < top_p\n",
" sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size]\n",
" # Гарантируем, что хотя бы первый токен останется\n",
" sorted_mask[:, 0] = 1\n",
" # 5. Преобразуем маску обратно в оригинальный порядок:\n",
" # Создаём полную маску из 0\n",
" mask = torch.zeros_like(probs, dtype=torch.uint8)\n",
" # Устанавливаем 1 в местах нужных токенов\n",
" mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)\n",
" # 6. Зануляем логиты токенов вне топ-p:\n",
" logits_scaled[~mask] = float('-inf')\n",
"\n",
" # 4. Применяем Softmax\n",
" probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]\n",
"\n",
"\n",
" if do_sample == True:\n",
" # 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial\n",
" next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]\n",
" else:\n",
" # 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью\n",
" next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]\n",
" \n",
" # 6. Добавляем его к последовательности\n",
" x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]\n",
" return x\n",
"\n",
" def save(self, path):\n",
" torch.save({\n",
" 'model_state_dict': self.state_dict(),\n",
" 'vocab_size': self._vocab_size,\n",
" 'max_seq_len': self._max_seq_len,\n",
" 'emb_size': self._emb_size,\n",
" 'num_heads': self._num_heads,\n",
" 'head_size': self._head_size,\n",
" 'num_layers': self._num_layers\n",
" }, path)\n",
"\n",
" @classmethod\n",
" def load(cls, path, device):\n",
" checkpoint = torch.load(path, map_location=device)\n",
" model = cls(\n",
" vocab_size=checkpoint['vocab_size'],\n",
" max_seq_len=checkpoint['max_seq_len'],\n",
" emb_size=checkpoint['emb_size'],\n",
" num_heads=checkpoint['num_heads'],\n",
" head_size=checkpoint['head_size'],\n",
" num_layers=checkpoint['num_layers']\n",
" )\n",
" model.load_state_dict(checkpoint['model_state_dict'])\n",
" model.to(device)\n",
" return model\n",
"\n",
" @property\n",
" def max_seq_len(self) -> int:\n",
" return self._max_seq_len\n",
"\n",
"\n",
"\n",
"\n"
]
},
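{
"cell_type": "markdown",
"id": "e5b7a3c6",
"metadata": {},
"source": [
"A small smoke test for the assembled model (an illustrative addition with tiny made-up hyperparameters, not the configuration used for training below): check the logits shape, the per-layer KV cache, and that greedy generation runs end to end."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f6a8b4d7",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"# Tiny illustrative configuration\n",
"tiny = Gemma(\n",
"    vocab_size=100,\n",
"    max_seq_len=64,\n",
"    emb_size=128,\n",
"    num_q_heads=4,\n",
"    head_size=32,\n",
"    num_layers=2,\n",
"    dropout=0.0,\n",
")\n",
"\n",
"tokens = torch.randint(0, 100, (2, 16))  # [batch, seq_len]\n",
"logits, cache = tiny(tokens, use_cache=True)\n",
"print(\"logits:\", logits.shape)  # [2, 16, 100]\n",
"print(\"cache layers:\", len(cache), \"| k shape:\", cache[0][0].shape)\n",
"\n",
"# Greedy generation for a few steps, just to check the loop runs end to end\n",
"generated = tiny.generate(tokens[:, :4], max_new_tokens=5, do_sample=False)\n",
"print(\"generated:\", generated.shape)  # [2, 9]"
]
},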
{
"cell_type": "markdown",
"id": "42746fea",
"metadata": {},
"source": [
"## 2. Обучение Gemma\n",
"\n",
"Gemma обучается в два этапа:\n",
"\n",
"- 1⃣ **Предобучение (Unsupervised Pretraining)** \n",
"- 2⃣ **Дообучение (Supervised Fine-Tuning)**"
]
},
{
"cell_type": "markdown",
"id": "f6b0234d",
"metadata": {},
"source": [
"\n",
"\n",
"### 5.1 Предобучение\n",
"\n",
"На первом этапе модель обучается без разметки: она получает большой корпус текстов и учится **предсказывать следующий токен** по предыдущим.\n",
"\n",
"Функция потерь:\n",
"$$\n",
"L = - \\sum_{t=1}^{T} \\log P(x_t | x_1, x_2, ..., x_{t-1})\n",
"$$\n",
"\n",
"Таким образом, модель учится строить вероятностную модель языка, \"угадывая\" продолжение текста.\n"
]
},
{
"cell_type": "markdown",
"id": "82c94641",
"metadata": {},
"source": [
"Во время **предобучения** Mistral учится **предсказывать следующий токен** (language modeling task). \n",
"Формально: \n",
"$$ \n",
"P(x_t ,|, x_1, x_2, \\dots, x_{t-1}) \n",
"$$ \n",
"То есть, если на вход подаётся предложение `\"I love deep\"`, модель должна предсказать `\"learning\"`.\n"
]
},
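{
"cell_type": "markdown",
"id": "a7b9c5e8",
"metadata": {},
"source": [
"The training loop below implements exactly this objective through `F.cross_entropy`. As a tiny illustrative check (toy numbers, added here for clarity), cross-entropy over the next-token targets is the same thing as the average $-\\log P(x_t \\mid x_1, \\dots, x_{t-1})$ from the formula above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b8cad6f9",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"torch.manual_seed(0)\n",
"vocab_size, seq_len = 10, 5\n",
"logits = torch.randn(1, seq_len, vocab_size)          # model outputs for one toy sequence\n",
"targets = torch.randint(0, vocab_size, (1, seq_len))  # the \"next tokens\"\n",
"\n",
"# Cross-entropy over flattened positions (as used in the training loop below)\n",
"ce = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))\n",
"\n",
"# Average negative log-probability of the target tokens, computed by hand\n",
"log_probs = F.log_softmax(logits, dim=-1)\n",
"nll = -log_probs.gather(-1, targets.unsqueeze(-1)).mean()\n",
"\n",
"print(f\"cross_entropy = {ce.item():.4f}, mean -log P = {nll.item():.4f}\")"
]
},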
{
"cell_type": "markdown",
"id": "b064fadc",
"metadata": {},
"source": [
"### ✅ 5.1.1 Подготовка данных\n",
"\n",
"Создадим **датасет** на основе BPE-токенизатора:"
]
},
{
"cell_type": "markdown",
"id": "f1516a37",
"metadata": {},
"source": [
"**BPE Tokenizator**"
]
},
{
"cell_type": "code",
"execution_count": 67,
"id": "8a5a975a",
"metadata": {},
"outputs": [],
"source": [
"class BPE:\n",
" def __init__(self, vocab_size: int):\n",
" self.vocab_size = vocab_size\n",
" self.id2token = {}\n",
" self.token2id = {}\n",
"\n",
" def fit(self, text: str):\n",
" # 1. Получаем уникальные токены (символы)\n",
" unique_tokens = sorted(set(text))\n",
" tokens = unique_tokens.copy()\n",
"\n",
" # 2. Разбиваем текст на токены-символы\n",
" sequence = list(text)\n",
"\n",
" # 3. Объединяем токены до достижения нужного размера словаря\n",
" while len(tokens) < self.vocab_size:\n",
" #print(f'len={len(tokens)} < {self.vocab_size}')\n",
" # Считаем частоты пар\n",
" pair_freq = {}\n",
" for i in range(len(sequence) - 1):\n",
" pair = (sequence[i], sequence[i + 1])\n",
" #print(f'pair = {pair}')\n",
" if pair not in pair_freq:\n",
" pair_freq[pair] = 0\n",
" pair_freq[pair] += 1\n",
"\n",
"\n",
" #print(f'pair_freq = {pair_freq}') \n",
" if not pair_freq:\n",
" break # нет пар — выходим\n",
"\n",
" #for x in pair_freq.items():\n",
" # self.debug(x, sequence)\n",
"\n",
" # Находим самую частую пару (в случае равенства — та, что встретилась первой)\n",
" most_frequent_pair = max(pair_freq.items(), key=lambda x: (x[1], -self._pair_first_index(sequence, x[0])))[0]\n",
" #print(most_frequent_pair)\n",
" # Создаем новый токен\n",
" new_token = most_frequent_pair[0] + most_frequent_pair[1]\n",
" #print(f\"new token={new_token}\")\n",
" tokens.append(new_token)\n",
" #print(f\"tokens={tokens}\")\n",
"\n",
" i = 0\n",
" new_sequence = []\n",
"\n",
" while i < len(sequence):\n",
" if i < len(sequence) - 1 and (sequence[i], sequence[i + 1]) == most_frequent_pair:\n",
" new_sequence.append(new_token)\n",
" i += 2 # пропускаем два символа — заменённую пару\n",
" else:\n",
" new_sequence.append(sequence[i])\n",
" i += 1\n",
" sequence = new_sequence\n",
" #break\n",
" \n",
" # 4. Создаем словари\n",
" self.vocab = tokens.copy()\n",
" self.token2id = dict(zip(tokens, range(self.vocab_size)))\n",
" self.id2token = dict(zip(range(self.vocab_size), tokens))\n",
"\n",
" def _pair_first_index(self, sequence, pair):\n",
" for i in range(len(sequence) - 1):\n",
" if (sequence[i], sequence[i + 1]) == pair:\n",
" return i\n",
" return float('inf') # если пара не найдена (в теории не должно случиться)\n",
"\n",
"\n",
" def encode(self, text: str):\n",
" # 1. Разбиваем текст на токены-символы\n",
" sequence = list(text)\n",
" # 2. Инициализация пустого списка токенов\n",
" tokens = []\n",
" # 3. Установить i = 0\n",
" i = 0\n",
" while i < len(text):\n",
" # 3.1 Найти все токены в словаре, начинающиеся с text[i]\n",
" start_char = text[i]\n",
" result = [token for token in self.vocab if token.startswith(start_char)]\n",
" # 3.2 Выбрать самый длинный подходящий токен\n",
" find_token = self._find_max_matching_token(text[i:], result)\n",
" if find_token is None:\n",
" # Обработка неизвестного символа\n",
" tokens.append(text[i]) # Добавляем сам символ как токен\n",
" i += 1\n",
" else:\n",
" # 3.3 Добавить токен в результат\n",
" tokens.append(find_token)\n",
" # 3.4 Увеличить i на длину токена\n",
" i += len(find_token)\n",
"\n",
" # 4. Заменить токены на их ID\n",
" return self._tokens_to_ids(tokens)\n",
"\n",
" def _find_max_matching_token(self, text: str, tokens: list):\n",
" \"\"\"Находит самый длинный токен из списка, с которого начинается текст\"\"\"\n",
" matching = [token for token in tokens if text.startswith(token)]\n",
" return max(matching, key=len) if matching else None\n",
"\n",
" def _tokens_to_ids(self, tokens):\n",
" \"\"\"Конвертирует список токенов в их ID с обработкой неизвестных токенов\"\"\"\n",
" ids = []\n",
" for token in tokens:\n",
" if token in self.token2id:\n",
" ids.append(self.token2id[token])\n",
" else:\n",
" ids.append(0) # Специальное значение\n",
" return ids\n",
"\n",
"\n",
" def decode(self, ids: list) -> str:\n",
" return ''.join(self._ids_to_tokens(ids))\n",
"\n",
" def _ids_to_tokens(self, ids: list) -> list:\n",
" \"\"\"Конвертирует список Ids в их tokens\"\"\"\n",
" tokens = []\n",
" for id in ids:\n",
" if id in self.id2token:\n",
" tokens.append(self.id2token[id])\n",
" else:\n",
" tokens.append('') # Специальное значение\n",
" return tokens\n",
"\n",
"\n",
" def save(self, filename):\n",
" with open(filename, 'wb') as f:\n",
" dill.dump(self, f)\n",
" print(f\"Объект сохранён в {filename}\")\n",
"\n",
"\n",
" @classmethod\n",
" def load(cls, filename):\n",
" with open(filename, 'rb') as f:\n",
" obj = dill.load(f)\n",
" \n",
" print(f\"Объект загружен из {filename}\")\n",
" return obj"
]
},
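{
"cell_type": "markdown",
"id": "c9dbe7a0",
"metadata": {},
"source": [
"A quick roundtrip check for the tokenizer (an illustrative toy example, not part of the original pipeline; the corpus string and the `toy_bpe` name are made up): fit the BPE on a tiny string, then make sure `encode` followed by `decode` reproduces the input."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d0ecf8b1",
"metadata": {},
"outputs": [],
"source": [
"# Toy BPE roundtrip: fit on a tiny corpus, then encode/decode a string\n",
"toy_bpe = BPE(vocab_size=40)\n",
"toy_bpe.fit(\"low lower lowest low low\")\n",
"\n",
"ids = toy_bpe.encode(\"low lowest\")\n",
"print(\"ids:   \", ids)\n",
"print(\"decode:\", toy_bpe.decode(ids))\n",
"assert toy_bpe.decode(ids) == \"low lowest\""
]
},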
{
"cell_type": "code",
"execution_count": 68,
"id": "1927f6d2",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"class GPTDataset(Dataset):\n",
" def __init__(self, text: str, bpe: BPE, block_size: int):\n",
" self.bpe = bpe\n",
" self.block_size = block_size\n",
" self.data = bpe.encode(text)\n",
" \n",
" def __len__(self):\n",
" return len(self.data) - self.block_size\n",
"\n",
" def __getitem__(self, idx):\n",
" x = torch.tensor(self.data[idx:idx+self.block_size], dtype=torch.long)\n",
" y = torch.tensor(self.data[idx+1:idx+self.block_size+1], dtype=torch.long)\n",
" return x, y"
]
},
{
"cell_type": "markdown",
"id": "a14d0c68",
"metadata": {},
"source": [
"### ✅ 5.1.2 Цикл обучения\n",
"\n",
"Для обучения создадим функцию:"
]
},
{
"cell_type": "code",
"execution_count": 69,
"id": "7c5c57b0",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"from torch import optim\n",
"\n",
"def train_gemma(model, dataset, epochs=5, batch_size=32, lr=3e-4, device='cpu'):\n",
" dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
" optimizer = optim.AdamW(model.parameters(), lr=lr)\n",
"\n",
" model.to(device)\n",
" model.train()\n",
"\n",
" for epoch in range(epochs):\n",
" total_loss = 0\n",
" for x, y in dataloader:\n",
" x, y = x.to(device), y.to(device)\n",
"\n",
" # Прямой проход\n",
" logits, _ = model(x, use_cache=False) # [B, T, vocab_size]\n",
"\n",
" # Перестроим выход под CrossEntropy\n",
" loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))\n",
"\n",
" # Обратное распространение\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" total_loss += loss.item()\n",
"\n",
" avg_loss = total_loss / len(dataloader)\n",
" print(f\"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}\")\n",
"\n",
" return model"
]
},
{
"cell_type": "markdown",
"id": "f96dea80",
"metadata": {},
"source": [
"### ✅ 5.1.3 Пример запуска\n",
"\n",
"\n",
"**🧠 Конфигурация Gemma Mini**\n",
"\n",
"\n",
"| Параметр | Значение | Описание |\n",
"| --------------- | -------- | --------------------------------------------- |\n",
"| **vocab_size** | `50257` | Размер словаря (BPE токенизатор OpenAI) |\n",
"| **max_seq_len** | `512` | Максимальная длина входной последовательности |\n",
"| **emb_size** | `256` | Размер эмбеддингов (векторное пространство) |\n",
"| **num_heads** | `4` | Количество голов в multi-head attention |\n",
"| **head_size** | `64` | Размерность одной головы внимания (768 / 12) |\n",
"| **num_layers** | `4` | Количество блоков (декодеров) |\n",
"| **dropout** | `0.1` | Вероятность дропаута |\n"
]
},
{
"cell_type": "code",
"execution_count": 70,
"id": "cda62fc2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset length: 20\n",
"Epoch 1/100, Loss: 3.8178\n",
"Epoch 2/100, Loss: 1.5683\n",
"Epoch 3/100, Loss: 0.6454\n",
"Epoch 4/100, Loss: 0.3353\n",
"Epoch 5/100, Loss: 0.2306\n",
"Epoch 6/100, Loss: 0.1581\n",
"Epoch 7/100, Loss: 0.1253\n",
"Epoch 8/100, Loss: 0.1063\n",
"Epoch 9/100, Loss: 0.0923\n",
"Epoch 10/100, Loss: 0.0909\n",
"Epoch 11/100, Loss: 0.0761\n",
"Epoch 12/100, Loss: 0.0932\n",
"Epoch 13/100, Loss: 0.0775\n",
"Epoch 14/100, Loss: 0.0797\n",
"Epoch 15/100, Loss: 0.0623\n",
"Epoch 16/100, Loss: 0.0795\n",
"Epoch 17/100, Loss: 0.0703\n",
"Epoch 18/100, Loss: 0.0581\n",
"Epoch 19/100, Loss: 0.0613\n",
"Epoch 20/100, Loss: 0.0660\n",
"Epoch 21/100, Loss: 0.0731\n",
"Epoch 22/100, Loss: 0.0644\n",
"Epoch 23/100, Loss: 0.0602\n",
"Epoch 24/100, Loss: 0.0557\n",
"Epoch 25/100, Loss: 0.0595\n",
"Epoch 26/100, Loss: 0.0688\n",
"Epoch 27/100, Loss: 0.0545\n",
"Epoch 28/100, Loss: 0.0561\n",
"Epoch 29/100, Loss: 0.0581\n",
"Epoch 30/100, Loss: 0.0627\n",
"Epoch 31/100, Loss: 0.0555\n",
"Epoch 32/100, Loss: 0.0538\n",
"Epoch 33/100, Loss: 0.0531\n",
"Epoch 34/100, Loss: 0.0535\n",
"Epoch 35/100, Loss: 0.0474\n",
"Epoch 36/100, Loss: 0.0516\n",
"Epoch 37/100, Loss: 0.0540\n",
"Epoch 38/100, Loss: 0.0533\n",
"Epoch 39/100, Loss: 0.0519\n",
"Epoch 40/100, Loss: 0.0606\n",
"Epoch 41/100, Loss: 0.0489\n",
"Epoch 42/100, Loss: 0.0513\n",
"Epoch 43/100, Loss: 0.0563\n",
"Epoch 44/100, Loss: 0.0522\n",
"Epoch 45/100, Loss: 0.0512\n",
"Epoch 46/100, Loss: 0.0490\n",
"Epoch 47/100, Loss: 0.0469\n",
"Epoch 48/100, Loss: 0.0500\n",
"Epoch 49/100, Loss: 0.0497\n",
"Epoch 50/100, Loss: 0.0532\n",
"Epoch 51/100, Loss: 0.0557\n",
"Epoch 52/100, Loss: 0.0480\n",
"Epoch 53/100, Loss: 0.0593\n",
"Epoch 54/100, Loss: 0.0498\n",
"Epoch 55/100, Loss: 0.0476\n",
"Epoch 56/100, Loss: 0.0496\n",
"Epoch 57/100, Loss: 0.0445\n",
"Epoch 58/100, Loss: 0.0494\n",
"Epoch 59/100, Loss: 0.0572\n",
"Epoch 60/100, Loss: 0.0490\n",
"Epoch 61/100, Loss: 0.0580\n",
"Epoch 62/100, Loss: 0.0499\n",
"Epoch 63/100, Loss: 0.0501\n",
"Epoch 64/100, Loss: 0.0538\n",
"Epoch 65/100, Loss: 0.0484\n",
"Epoch 66/100, Loss: 0.0520\n",
"Epoch 67/100, Loss: 0.0527\n",
"Epoch 68/100, Loss: 0.0501\n",
"Epoch 69/100, Loss: 0.0506\n",
"Epoch 70/100, Loss: 0.0480\n",
"Epoch 71/100, Loss: 0.0470\n",
"Epoch 72/100, Loss: 0.0498\n",
"Epoch 73/100, Loss: 0.0484\n",
"Epoch 74/100, Loss: 0.0435\n",
"Epoch 75/100, Loss: 0.0456\n",
"Epoch 76/100, Loss: 0.0480\n",
"Epoch 77/100, Loss: 0.0477\n",
"Epoch 78/100, Loss: 0.0494\n",
"Epoch 79/100, Loss: 0.0490\n",
"Epoch 80/100, Loss: 0.0474\n",
"Epoch 81/100, Loss: 0.0462\n",
"Epoch 82/100, Loss: 0.0432\n",
"Epoch 83/100, Loss: 0.0447\n",
"Epoch 84/100, Loss: 0.0482\n",
"Epoch 85/100, Loss: 0.0493\n",
"Epoch 86/100, Loss: 0.0452\n",
"Epoch 87/100, Loss: 0.0417\n",
"Epoch 88/100, Loss: 0.0489\n",
"Epoch 89/100, Loss: 0.0487\n",
"Epoch 90/100, Loss: 0.0486\n",
"Epoch 91/100, Loss: 0.0451\n",
"Epoch 92/100, Loss: 0.0443\n",
"Epoch 93/100, Loss: 0.0442\n",
"Epoch 94/100, Loss: 0.0486\n",
"Epoch 95/100, Loss: 0.0464\n",
"Epoch 96/100, Loss: 0.0429\n",
"Epoch 97/100, Loss: 0.0461\n",
"Epoch 98/100, Loss: 0.0496\n",
"Epoch 99/100, Loss: 0.0476\n",
"Epoch 100/100, Loss: 0.0441\n"
]
},
{
"data": {
"text/plain": [
"Gemma(\n",
" (_token_embeddings): TokenEmbeddings(\n",
" (_embedding): Embedding(100, 256)\n",
" )\n",
" (_position_embeddings): RoPE()\n",
" (_dropout): Dropout(p=0.1, inplace=False)\n",
" (_decoders): ModuleList(\n",
" (0-3): 4 x Decoder(\n",
" (_heads): MultiQueryAttention(\n",
" (_rope): RoPE()\n",
" (_q): Linear(in_features=256, out_features=256, bias=True)\n",
" (_k): Linear(in_features=256, out_features=64, bias=True)\n",
" (_v): Linear(in_features=256, out_features=64, bias=True)\n",
" (_layer): Linear(in_features=256, out_features=256, bias=True)\n",
" (_dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (_ff): GeGLU(\n",
" (_gate): Linear(in_features=256, out_features=1024, bias=True)\n",
" (_up): Linear(in_features=256, out_features=1024, bias=True)\n",
" (_down): Linear(in_features=1024, out_features=256, bias=True)\n",
" (_activation): GELU()\n",
" (_dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (_norm1): RMSNorm()\n",
" (_norm2): RMSNorm()\n",
" )\n",
" )\n",
" (_norm): RMSNorm()\n",
" (_linear): Linear(in_features=256, out_features=100, bias=True)\n",
")"
]
},
"execution_count": 70,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 1. Исходный текст\n",
"text = \"Deep learning is amazing. Transformers changed the world. Attention is all you need. GPT models revolutionized NLP.\"\n",
"\n",
"# 2. Обучаем токенизатор\n",
"bpe = BPE(vocab_size=100)\n",
"bpe.fit(text)\n",
"\n",
"# 3. Создаем датасет\n",
"dataset = GPTDataset(text, bpe, block_size=8)\n",
"print(f\"Dataset length: {len(dataset)}\")\n",
"\n",
"# 4. Инициализируем модель\n",
"model = Gemma(\n",
" vocab_size=len(bpe.vocab), # размер словаря BPE\n",
" max_seq_len=512, # GPT-2 использует контекст в 512 токена\n",
" emb_size=256, # размер эмбеддингов\n",
" num_q_heads=4, # количество голов внимания\n",
" head_size=64, # размер каждой головы (256 / 4)\n",
" num_layers=4, # количество блоков Transformer\n",
" dropout=0.1 # стандартный dropout GPT-2\n",
")\n",
"\n",
"# 5. Обучаем\n",
"train_gemma(model, dataset, epochs=100, batch_size=4)"
]
},
{
"cell_type": "markdown",
"id": "f5a37671",
"metadata": {},
"source": [
"\n",
"---\n",
"\n",
"### 5.2 Дообучение\n",
"\n",
"После предобучения Gemma уже знает структуру и грамматику языка. \n",
"На втором этапе она дообучается на конкретных задачах (например, классификация, QA) с помощью размеченных данных.\n",
"\n",
"Технически это почти то же обучение, только:\n",
"\n",
"- Загружаем модель с уже обученными весами.\n",
"- Используем новые данные.\n",
"- Можно уменьшить скорость обучения.\n",
"- Иногда замораживают часть слоёв (например, эмбеддинги).\n"
]
},
{
"cell_type": "code",
"execution_count": 71,
"id": "d062af63",
"metadata": {},
"outputs": [],
"source": [
"def fine_tune_gemma(model, dataset, epochs=3, batch_size=16, lr=1e-5, device='cpu', freeze_embeddings=True):\n",
" if freeze_embeddings:\n",
" for param in model._token_embeddings.parameters():\n",
" param.requires_grad = False\n",
" for param in model._position_embeddings.parameters():\n",
" param.requires_grad = False\n",
"\n",
" dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
" optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)\n",
"\n",
" model.to(device)\n",
" model.train()\n",
"\n",
" for epoch in range(epochs):\n",
" total_loss = 0\n",
" for x, y in dataloader:\n",
" x, y = x.to(device), y.to(device)\n",
" logits, _ = model(x, use_cache=False)\n",
" loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" total_loss += loss.item()\n",
" print(f\"Fine-tune Epoch {epoch+1}/{epochs}, Loss: {total_loss / len(dataloader):.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": 72,
"id": "064dd678",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fine-tune Epoch 1/10, Loss: 4.9095\n",
"Fine-tune Epoch 2/10, Loss: 2.8684\n",
"Fine-tune Epoch 3/10, Loss: 1.7589\n",
"Fine-tune Epoch 4/10, Loss: 1.3044\n",
"Fine-tune Epoch 5/10, Loss: 1.0614\n",
"Fine-tune Epoch 6/10, Loss: 0.8326\n",
"Fine-tune Epoch 7/10, Loss: 0.6908\n",
"Fine-tune Epoch 8/10, Loss: 0.5926\n",
"Fine-tune Epoch 9/10, Loss: 0.5082\n",
"Fine-tune Epoch 10/10, Loss: 0.4758\n"
]
}
],
"source": [
"# Например, мы хотим дообучить модель на стиле коротких технических фраз\n",
"fine_tune_text = \"\"\"\n",
"Transformers revolutionize NLP.\n",
"Deep learning enables self-attention.\n",
"GPT generates text autoregressively.\n",
"\"\"\"\n",
"\n",
"dataset = GPTDataset(fine_tune_text, bpe, block_size=8)\n",
"\n",
"\n",
"# Запуск дообучения\n",
"fine_tune_gemma(model, dataset, epochs=10, batch_size=4, lr=1e-4)"
]
},
{
"cell_type": "markdown",
"id": "a496ddae",
"metadata": {},
"source": [
"## 📝 6. Генерация текста после обучения"
]
},
{
"cell_type": "code",
"execution_count": 73,
"id": "645f777c",
"metadata": {},
"outputs": [],
"source": [
"def generate_text(model, bpe, prompt: str, max_new_tokens=20, device='cpu'):\n",
" model.eval()\n",
" ids = torch.tensor([bpe.encode(prompt)], dtype=torch.long).to(device)\n",
" out = model.generate(ids, max_new_tokens=max_new_tokens, do_sample=True)\n",
" text = bpe.decode(out[0].tolist())\n",
" return text"
]
},
{
"cell_type": "code",
"execution_count": 74,
"id": "14778ecd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Deep learningenena lf lenenssssf \n"
]
}
],
"source": [
"print(generate_text(model, bpe, \"Deep learning\", max_new_tokens=20))"
]
},
{
"cell_type": "markdown",
"id": "1b70d909",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}