\n",
"\n",
"Gemma 1 вышла в феврале 2024 года.\n",
"\n",
"По архитектуре модель больше всего похожа на Llama'у. Содержит уже знакомые нам RoPE и RMSNorm. Но есть и новинки:\n",
"\n",
"* **Multi-Query Attention (MQA)** — крайне экономный вариант механизма внимания.\n",
"* **GeGLU** — гибридная функция активации. Почти клон SwiGLU :)\n",
"\n",
"Обе довольно легкие для внедрения, по сравнению с прошлыми новинками :)\n"
]
},
{
"cell_type": "markdown",
"id": "cea30169",
"metadata": {},
"source": [
"# Multi-Query Attention\n",
"\n",
"По своей сути, Multi-Query Attention (MQA) — это частный случай Grouped Query Attention (GQA), который мы реализовали в уроке про Mistral.\n",
"\n",
"
\n",
" \n",
"
\n",
"\n",
"В GQA на каждую голову приходится один вектор запроса (query). При этом каждый вектор ключа (key) и значения (value) обслуживает ( n )-голов.\n",
"Так вот, в MQA на все головы (в одном блоке декодера) приходится всего по одному вектору ключа (key) и одному вектору значения (value). Это такая радикальная форма экономии :)\n"
]
},
{
"cell_type": "markdown",
"id": "539550fe",
"metadata": {},
"source": [
"**Multi-Query Attention (разработка)**"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "6be61c63",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from typing import Optional\n",
"\n",
"\n",
"class RoPE(nn.Module):\n",
"\n",
" def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):\n",
" super().__init__()\n",
" assert head_size % 2 == 0, \"head_size должен быть четным\"\n",
"\n",
" # Вычисление частот: θ_i = base^(-2i/d) для i ∈ [0, d/2-1]\n",
" freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size))\n",
"\n",
" # Позиции от 0 до max_seq_len-1\n",
" positions = torch.arange(max_seq_len).float()\n",
"\n",
" # Внешнее произведение: m * θ_i для всех позиций и частот\n",
" freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0)\n",
"\n",
" # Предвычисление матриц косинусов и синусов\n",
" self.register_buffer(\"cos_matrix\", torch.cos(freq_matrix))\n",
" self.register_buffer(\"sin_matrix\", torch.sin(freq_matrix))\n",
"\n",
" def forward(self, x: torch.Tensor) -> torch.Tensor: # [batch_size × seq_len × head_size] [batch_size × num_heads × seq_len × head_size]\n",
" batch_size, num_heads, seq_len, head_size = x.shape\n",
"\n",
" # Берем нужную часть матриц и приводим к типу x\n",
" cos = self.cos_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]\n",
" sin = self.sin_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]\n",
"\n",
" # Явное изменение формы для broadcasting\n",
" cos = cos.reshape(1, 1, seq_len, head_size // 2)\n",
" sin = sin.reshape(1, 1, seq_len, head_size // 2)\n",
"\n",
" # Разделяем на четные и нечетные компоненты по ПОСЛЕДНЕМУ измерению\n",
" x_even = x[..., 0::2] # [batch_size, num_heads, seq_len, head_size//2]\n",
" x_odd = x[..., 1::2] # [batch_size, num_heads, seq_len, head_size//2]\n",
"\n",
" # Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ)\n",
" x_rotated_even = x_even * cos - x_odd * sin\n",
" x_rotated_odd = x_even * sin + x_odd * cos\n",
"\n",
" # Объединяем обратно в исходную размерность\n",
" x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)\n",
" x_rotated = x_rotated.flatten(-2) # [batch_size, seq_len, head_size]\n",
"\n",
" return x_rotated"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "811921b1",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"\n",
"class MultiQueryAttention(nn.Module):\n",
" def __init__(\n",
" self,\n",
" num_q_heads: int,\n",
" emb_size: int,\n",
" head_size: int,\n",
" max_seq_len: int,\n",
" rope: RoPE = None,\n",
" dropout: float = 0.1,\n",
" ):\n",
" super().__init__()\n",
" self._num_q_heads = num_q_heads\n",
" self._head_size = head_size\n",
" self._max_seq_len = max_seq_len\n",
" self._rope = rope\n",
" \n",
" self._q = nn.Linear(emb_size, num_q_heads * head_size)\n",
" self._k = nn.Linear(emb_size, head_size)\n",
" self._v = nn.Linear(emb_size, head_size)\n",
"\n",
" # Создание causal маски\n",
" mask = torch.tril(torch.ones(max_seq_len, max_seq_len))\n",
" self.register_buffer(\n",
" \"_tril_mask\", mask.bool() if hasattr(torch, \"bool\") else mask.byte()\n",
" )\n",
" \n",
" self._layer = nn.Linear(num_q_heads * head_size, emb_size)\n",
" self._dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(\n",
" self,\n",
" x: torch.Tensor,\n",
" mask: torch.Tensor = None,\n",
" use_cache: bool = True,\n",
" cache: list = None,\n",
" ):\n",
" batch_size, seq_len, emb_size = x.shape\n",
" if seq_len > self._max_seq_len:\n",
" raise ValueError(\n",
" f\"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}\"\n",
" )\n",
"\n",
" # Пропустите тензор x через матрицы Wq, Wk , Wv, чтобы получить матрицы запроса, ключа и значения.\n",
" k = self._k(x) # [B, T, hs]\n",
" q = self._q(x) # [B, T, hs]\n",
" v = self._v(x) # [B, T, hs]\n",
"\n",
" # Шаг 2: Изменение формы для multi-head\n",
" # [batch_size, seq_len, num_heads * head_size] \n",
" # -> [batch_size, seq_len, num_heads, head_size]\n",
" q = q.reshape(batch_size, seq_len, self._num_q_heads, self._head_size)\n",
" k = k.reshape(batch_size, seq_len, 1, self._head_size)\n",
" v = v.reshape(batch_size, seq_len, 1, self._head_size)\n",
"\n",
" # 3. Transpose: [B, T, H, hs] -> [B, H, T, hs]\n",
" q = q.transpose(1, 2)\n",
" k = k.transpose(1, 2)\n",
" v = v.transpose(1, 2)\n",
"\n",
" # Пропустите матрицы запроса и ключа через экземпляр rope, чтобы выполнить поворот.\n",
" if self._rope is not None:\n",
" # Применяем RoPE к Q и K (НЕ к V!)\n",
" q = self._rope(q) # [B, T, hs]\n",
" k = self._rope(k) # [B, T, hs]\n",
"\n",
"\n",
" # Если cache пришел, то объединяем кэш и одну строку из ключа и значения. Это будут новые key и value для последующих вычислений.\n",
" # 5. Кэширование (для autoregressive generation)\n",
" if cache is not None:\n",
" k_cache, v_cache = cache\n",
" k = torch.cat([k_cache, k], dim=2) # Concat по seq_len (dim=2)\n",
" v = torch.cat([v_cache, v], dim=2)\n",
"\n",
"\n",
" # Перемножим матрицы запроса и ключа (транспонированную), чтобы вычислить матрицу внимания.\n",
" # И разделить все значения в матрице внимания на корень из head_size.\n",
" scores = q @ k.transpose(-2, -1) / (self._head_size ** 0.5)\n",
"\n",
" # Если cache пришел, то маску не накладываем. Иначе наложите на матрицу внимания треугольную маску, созданную при инициализации. Все скрытые значения должны быть приведены к минус бесконечности: float('-inf').\n",
" if cache is None:\n",
" scores = scores.masked_fill(\n",
" ~self._tril_mask[:seq_len, :seq_len], float(\"-inf\")\n",
" )\n",
"\n",
" # Применить к матрице внимания (построчно) функцию Softmax.\n",
" weights = F.softmax(scores, dim=-1)\n",
"\n",
" # Перемножим матрицу внимания и матрицу значения.\n",
" x_out = weights @ v # [B, T, hs]\n",
"\n",
"\n",
" # Измените форму тензора на batch_size × seq_len × num_heads*head_size.\n",
" # Transpose обратно и concatenate heads\n",
" x_out = x_out.transpose(1, 2) # [B, T_q, H, hs]\n",
" x_out = x_out.contiguous() # Важно для reshape!\n",
" concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_q_heads * self._head_size)\n",
"\n",
"\n",
" # Пропустите получившийся тензор через последний линейный слой.\n",
" # 3. Проецируем в пространство эмбеддингов\n",
" projected_output = self._layer(concatenated_attention)\n",
"\n",
"\n",
" # 4. Применяем dropout для регуляризации\n",
" final_output = self._dropout(projected_output)\n",
"\n",
" if use_cache is True:\n",
" return (final_output, (k, v))\n",
" else:\n",
" return (final_output, None)\n"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "97771d9a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ Test 1 - Output shape: torch.Size([2, 10, 512])\n",
"✅ Test 2 - First output shape: torch.Size([2, 5, 512])\n",
"✅ Test 2 - Second output shape: torch.Size([2, 1, 512])\n",
"\n",
"✅ Все тесты пройдены!\n"
]
}
],
"source": [
"\n",
"# Параметры\n",
"batch_size = 2\n",
"seq_len = 10\n",
"emb_size = 512\n",
"num_q_heads = 8\n",
"head_size = 64\n",
"max_seq_len = 512\n",
"\n",
" # Создание модели\n",
"rope = RoPE(head_size=head_size, max_seq_len=max_seq_len)\n",
"mha = MultiQueryAttention(\n",
" num_q_heads=num_q_heads,\n",
" emb_size=emb_size,\n",
" head_size=head_size,\n",
" max_seq_len=max_seq_len,\n",
" rope=rope,\n",
" dropout=0.1,\n",
")\n",
"\n",
" # Тест 1: Обычный forward pass\n",
"x = torch.randn(batch_size, seq_len, emb_size)\n",
"output, cache = mha(x, use_cache=False)\n",
"print(f\"✅ Test 1 - Output shape: {output.shape}\") # [2, 10, 512]\n",
"assert output.shape == (batch_size, seq_len, emb_size)\n",
"\n",
" # Тест 2: С кэшированием\n",
"x1 = torch.randn(batch_size, 5, emb_size)\n",
"output1, cache1 = mha(x1, use_cache=True)\n",
"print(f\"✅ Test 2 - First output shape: {output1.shape}\") # [2, 5, 512]\n",
"\n",
"x2 = torch.randn(batch_size, 1, emb_size)\n",
"output2, cache2 = mha(x2, use_cache=True, cache=cache1)\n",
"print(f\"✅ Test 2 - Second output shape: {output2.shape}\") # [2, 1, 512]\n",
"\n",
"print(\"\\n✅ Все тесты пройдены!\")"
]
},
{
"cell_type": "markdown",
"id": "b5875022",
"metadata": {},
"source": [
"Вот конвертированный Markdown для твоего HTML:\n",
"\n",
"---\n",
"\n",
"# GeGLU\n",
"\n",
"