fix(rope, attention): корректное позиционирование RoPE при генерации с кэшем

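- Исправлена ошибка расчёта позиции для RoPE (Rotary Position Embedding) при авторегрессивной генерации с использованием KV-кэша: раньше новые токены всегда кодировались начиная с позиции 0, без учёта уже закэшированных.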
- В HeadAttention теперь передаётся start_pos в RoPE, вычисляемый из длины кэша.
- Обновлены сигнатура и логика метода RoPE.forward: добавлен параметр start_pos (по умолчанию 0), задающий смещение при выборке из таблиц cos/sin.
- Обновлён ноутбук llama.ipynb: код ячеек приведён к новому интерфейсу, выводы ячеек пересчитаны.

BREAKING CHANGE: изменена сигнатура метода forward у RoPE (добавлен параметр start_pos: int = 0); если RoPE вызывается вручную совместно с кэшем, необходимо передавать start_pos, равный длине кэша.
This commit is contained in:
Sergey Penkovsky
2025-10-14 12:03:20 +03:00
parent 3e4815fcc6
commit e5706a690d
3 changed files with 155 additions and 140 deletions

View File

@@ -89,10 +89,16 @@ class HeadAttention(nn.Module):
q = self._q(x) # [B, T, hs] q = self._q(x) # [B, T, hs]
v = self._v(x) # [B, T, hs] v = self._v(x) # [B, T, hs]
start_pos = 0
if cache is not None:
k_cache, v_cache = cache
cache_len = k_cache.shape[1]
start_pos = cache_len
if self._rope is not None: if self._rope is not None:
# ✅ Применяем RoPE к Q и K (НЕ к V!) # ✅ Применяем RoPE к Q и K (НЕ к V!)
q = self._rope(q) # [B, T, hs] q = self._rope(q, start_pos=start_pos) # [B, T, hs]
k = self._rope(k) # [B, T, hs] k = self._rope(k, start_pos=start_pos) # [B, T, hs]
if cache is not None: if cache is not None:
k_cache, v_cache = cache k_cache, v_cache = cache

View File

@@ -68,7 +68,7 @@ class RoPE(nn.Module):
self.register_buffer("cos_matrix", torch.cos(freq_matrix)) self.register_buffer("cos_matrix", torch.cos(freq_matrix))
self.register_buffer("sin_matrix", torch.sin(freq_matrix)) self.register_buffer("sin_matrix", torch.sin(freq_matrix))
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
""" """
Применение ротационного позиционного кодирования к входному тензору. Применение ротационного позиционного кодирования к входному тензору.
@@ -83,11 +83,11 @@ class RoPE(nn.Module):
2. Применение вращения через синусы и косинусы 2. Применение вращения через синусы и косинусы
3. Объединение компонент обратно 3. Объединение компонент обратно
""" """
seq_len = x.size(1) batch_size, seq_len, emb_size = x.shape
# Берем нужную часть матриц и приводим к типу x # Берем нужную часть матриц и приводим к типу x
cos = self.cos_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2] cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
sin = self.sin_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2] sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
# Разделяем на четные и нечетные компоненты # Разделяем на четные и нечетные компоненты
x_even = x[:, :, 0::2] # [batch_size, seq_len, head_size//2] x_even = x[:, :, 0::2] # [batch_size, seq_len, head_size//2]

View File

@@ -78,7 +78,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 19,
"id": "873704be", "id": "873704be",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -149,7 +149,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 20,
"id": "0484cf77", "id": "0484cf77",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -173,7 +173,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 21,
"id": "74ca39ba", "id": "74ca39ba",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -432,7 +432,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 22,
"id": "02c300f9", "id": "02c300f9",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -480,7 +480,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 23,
"id": "ac072b9b", "id": "ac072b9b",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -515,7 +515,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 24,
"id": "4fe79d02", "id": "4fe79d02",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -543,11 +543,12 @@
" self.register_buffer('sin_matrix', torch.sin(freq_matrix))\n", " self.register_buffer('sin_matrix', torch.sin(freq_matrix))\n",
"\n", "\n",
"\n", "\n",
" def forward(self, x: torch.Tensor): # Получает на вход тензор x (тип float) размером [batch_size × seq_len × head_size]\n", " def forward(self, x: torch.Tensor, start_pos: int = 0): # Получает на вход тензор x (тип float) размером [batch_size × seq_len × head_size]\n",
" seq_len = x.size(1)\n", " batch_size, seq_len, emb_size = x.shape\n",
"\n",
" # Берем нужную часть матриц и приводим к типу x\n", " # Берем нужную часть матриц и приводим к типу x\n",
" cos = self.cos_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]\n", " cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]\n",
" sin = self.sin_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]\n", " sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]\n",
" \n", " \n",
"\n", "\n",
" # Разделяем на четные и нечетные\n", " # Разделяем на четные и нечетные\n",
@@ -580,7 +581,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 25,
"id": "fe1274b1", "id": "fe1274b1",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -630,7 +631,9 @@
" self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())\n", " self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())\n",
"\n", "\n",
" def forward(self, x: torch.Tensor, use_cache: bool = True, cache: tuple = None) -> tuple:\n", " def forward(self, x: torch.Tensor, use_cache: bool = True, cache: tuple = None) -> tuple:\n",
" seq_len = x.shape[1]\n", " #seq_len = x.shape[1]\n",
" batch_size, seq_len, emb_size = x.shape\n",
"\n",
" if seq_len > self._max_seq_len:\n", " if seq_len > self._max_seq_len:\n",
" raise ValueError(f\"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}\")\n", " raise ValueError(f\"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}\")\n",
"\n", "\n",
@@ -638,9 +641,15 @@
" q = self._q(x) # [B, T, hs]\n", " q = self._q(x) # [B, T, hs]\n",
" v = self._v(x) # [B, T, hs]\n", " v = self._v(x) # [B, T, hs]\n",
"\n", "\n",
" start_pos = 0\n",
" if cache is not None:\n",
" k_cache, v_cache = cache\n",
" cache_len = k_cache.shape[1]\n",
" start_pos = cache_len\n",
"\n",
" # ✅ Применяем RoPE к Q и K (НЕ к V!)\n", " # ✅ Применяем RoPE к Q и K (НЕ к V!)\n",
" q = self._rope(q) # [B, T, hs]\n", " q = self._rope(q, start_pos=start_pos) # [B, T, hs]\n",
" k = self._rope(k) # [B, T, hs]\n", " k = self._rope(k, start_pos=start_pos) # [B, T, hs]\n",
"\n", "\n",
" if cache is not None:\n", " if cache is not None:\n",
" k_cache, v_cache = cache\n", " k_cache, v_cache = cache\n",
@@ -1031,7 +1040,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 26,
"id": "078ed4ce", "id": "078ed4ce",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -1177,7 +1186,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 27,
"id": "0258c483", "id": "0258c483",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -1212,7 +1221,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 28,
"id": "f8f06cf7", "id": "f8f06cf7",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -1275,7 +1284,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 29,
"id": "0c753d52", "id": "0c753d52",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -1284,106 +1293,106 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Dataset length: 20\n", "Dataset length: 20\n",
"Epoch 1/100, Loss: 3.9063\n", "Epoch 1/100, Loss: 3.7312\n",
"Epoch 2/100, Loss: 1.6363\n", "Epoch 2/100, Loss: 1.5022\n",
"Epoch 3/100, Loss: 0.6436\n", "Epoch 3/100, Loss: 0.5940\n",
"Epoch 4/100, Loss: 0.3222\n", "Epoch 4/100, Loss: 0.3368\n",
"Epoch 5/100, Loss: 0.2031\n", "Epoch 5/100, Loss: 0.2121\n",
"Epoch 6/100, Loss: 0.1460\n", "Epoch 6/100, Loss: 0.1507\n",
"Epoch 7/100, Loss: 0.1119\n", "Epoch 7/100, Loss: 0.1095\n",
"Epoch 8/100, Loss: 0.0978\n", "Epoch 8/100, Loss: 0.1023\n",
"Epoch 9/100, Loss: 0.0925\n", "Epoch 9/100, Loss: 0.0831\n",
"Epoch 10/100, Loss: 0.0757\n", "Epoch 10/100, Loss: 0.0805\n",
"Epoch 11/100, Loss: 0.0809\n", "Epoch 11/100, Loss: 0.0789\n",
"Epoch 12/100, Loss: 0.0763\n", "Epoch 12/100, Loss: 0.0802\n",
"Epoch 13/100, Loss: 0.0671\n", "Epoch 13/100, Loss: 0.0758\n",
"Epoch 14/100, Loss: 0.0681\n", "Epoch 14/100, Loss: 0.0676\n",
"Epoch 15/100, Loss: 0.0582\n", "Epoch 15/100, Loss: 0.0655\n",
"Epoch 16/100, Loss: 0.0655\n", "Epoch 16/100, Loss: 0.0701\n",
"Epoch 17/100, Loss: 0.0751\n", "Epoch 17/100, Loss: 0.0631\n",
"Epoch 18/100, Loss: 0.0656\n", "Epoch 18/100, Loss: 0.0691\n",
"Epoch 19/100, Loss: 0.0625\n", "Epoch 19/100, Loss: 0.0687\n",
"Epoch 20/100, Loss: 0.0593\n", "Epoch 20/100, Loss: 0.0623\n",
"Epoch 21/100, Loss: 0.0678\n", "Epoch 21/100, Loss: 0.0574\n",
"Epoch 22/100, Loss: 0.0652\n", "Epoch 22/100, Loss: 0.0603\n",
"Epoch 23/100, Loss: 0.0644\n", "Epoch 23/100, Loss: 0.0592\n",
"Epoch 24/100, Loss: 0.0542\n", "Epoch 24/100, Loss: 0.0587\n",
"Epoch 25/100, Loss: 0.0627\n", "Epoch 25/100, Loss: 0.0598\n",
"Epoch 26/100, Loss: 0.0596\n", "Epoch 26/100, Loss: 0.0589\n",
"Epoch 27/100, Loss: 0.0639\n", "Epoch 27/100, Loss: 0.0589\n",
"Epoch 28/100, Loss: 0.0536\n", "Epoch 28/100, Loss: 0.0546\n",
"Epoch 29/100, Loss: 0.0606\n", "Epoch 29/100, Loss: 0.0570\n",
"Epoch 30/100, Loss: 0.0565\n", "Epoch 30/100, Loss: 0.0673\n",
"Epoch 31/100, Loss: 0.0622\n", "Epoch 31/100, Loss: 0.0601\n",
"Epoch 32/100, Loss: 0.0581\n", "Epoch 32/100, Loss: 0.0702\n",
"Epoch 33/100, Loss: 0.0492\n", "Epoch 33/100, Loss: 0.0528\n",
"Epoch 34/100, Loss: 0.0553\n", "Epoch 34/100, Loss: 0.0508\n",
"Epoch 35/100, Loss: 0.0526\n", "Epoch 35/100, Loss: 0.0522\n",
"Epoch 36/100, Loss: 0.0527\n", "Epoch 36/100, Loss: 0.0537\n",
"Epoch 37/100, Loss: 0.0570\n", "Epoch 37/100, Loss: 0.0570\n",
"Epoch 38/100, Loss: 0.0482\n", "Epoch 38/100, Loss: 0.0580\n",
"Epoch 39/100, Loss: 0.0553\n", "Epoch 39/100, Loss: 0.0445\n",
"Epoch 40/100, Loss: 0.0444\n", "Epoch 40/100, Loss: 0.0516\n",
"Epoch 41/100, Loss: 0.0602\n", "Epoch 41/100, Loss: 0.0518\n",
"Epoch 42/100, Loss: 0.0599\n", "Epoch 42/100, Loss: 0.0545\n",
"Epoch 43/100, Loss: 0.0598\n", "Epoch 43/100, Loss: 0.0466\n",
"Epoch 44/100, Loss: 0.0572\n", "Epoch 44/100, Loss: 0.0523\n",
"Epoch 45/100, Loss: 0.0551\n", "Epoch 45/100, Loss: 0.0523\n",
"Epoch 46/100, Loss: 0.0577\n", "Epoch 46/100, Loss: 0.0547\n",
"Epoch 47/100, Loss: 0.0527\n", "Epoch 47/100, Loss: 0.0497\n",
"Epoch 48/100, Loss: 0.0466\n", "Epoch 48/100, Loss: 0.0512\n",
"Epoch 49/100, Loss: 0.0551\n", "Epoch 49/100, Loss: 0.0481\n",
"Epoch 50/100, Loss: 0.0517\n", "Epoch 50/100, Loss: 0.0498\n",
"Epoch 51/100, Loss: 0.0477\n", "Epoch 51/100, Loss: 0.0672\n",
"Epoch 52/100, Loss: 0.0539\n", "Epoch 52/100, Loss: 0.0530\n",
"Epoch 53/100, Loss: 0.0478\n", "Epoch 53/100, Loss: 0.0562\n",
"Epoch 54/100, Loss: 0.0539\n", "Epoch 54/100, Loss: 0.0536\n",
"Epoch 55/100, Loss: 0.0435\n", "Epoch 55/100, Loss: 0.0482\n",
"Epoch 56/100, Loss: 0.0471\n", "Epoch 56/100, Loss: 0.0438\n",
"Epoch 57/100, Loss: 0.0461\n", "Epoch 57/100, Loss: 0.0467\n",
"Epoch 58/100, Loss: 0.0452\n", "Epoch 58/100, Loss: 0.0501\n",
"Epoch 59/100, Loss: 0.0507\n", "Epoch 59/100, Loss: 0.0445\n",
"Epoch 60/100, Loss: 0.0481\n", "Epoch 60/100, Loss: 0.0471\n",
"Epoch 61/100, Loss: 0.0398\n", "Epoch 61/100, Loss: 0.0502\n",
"Epoch 62/100, Loss: 0.0535\n", "Epoch 62/100, Loss: 0.0474\n",
"Epoch 63/100, Loss: 0.0503\n", "Epoch 63/100, Loss: 0.0420\n",
"Epoch 64/100, Loss: 0.0504\n", "Epoch 64/100, Loss: 0.0541\n",
"Epoch 65/100, Loss: 0.0473\n", "Epoch 65/100, Loss: 0.0491\n",
"Epoch 66/100, Loss: 0.0553\n", "Epoch 66/100, Loss: 0.0489\n",
"Epoch 67/100, Loss: 0.0514\n", "Epoch 67/100, Loss: 0.0498\n",
"Epoch 68/100, Loss: 0.0450\n", "Epoch 68/100, Loss: 0.0511\n",
"Epoch 69/100, Loss: 0.0488\n", "Epoch 69/100, Loss: 0.0463\n",
"Epoch 70/100, Loss: 0.0414\n", "Epoch 70/100, Loss: 0.0480\n",
"Epoch 71/100, Loss: 0.0413\n", "Epoch 71/100, Loss: 0.0460\n",
"Epoch 72/100, Loss: 0.0473\n", "Epoch 72/100, Loss: 0.0533\n",
"Epoch 73/100, Loss: 0.0530\n", "Epoch 73/100, Loss: 0.0515\n",
"Epoch 74/100, Loss: 0.0482\n", "Epoch 74/100, Loss: 0.0419\n",
"Epoch 75/100, Loss: 0.0477\n", "Epoch 75/100, Loss: 0.0491\n",
"Epoch 76/100, Loss: 0.0483\n", "Epoch 76/100, Loss: 0.0471\n",
"Epoch 77/100, Loss: 0.0452\n", "Epoch 77/100, Loss: 0.0479\n",
"Epoch 78/100, Loss: 0.0452\n", "Epoch 78/100, Loss: 0.0444\n",
"Epoch 79/100, Loss: 0.0474\n", "Epoch 79/100, Loss: 0.0520\n",
"Epoch 80/100, Loss: 0.0483\n", "Epoch 80/100, Loss: 0.0520\n",
"Epoch 81/100, Loss: 0.0522\n", "Epoch 81/100, Loss: 0.0489\n",
"Epoch 82/100, Loss: 0.0453\n", "Epoch 82/100, Loss: 0.0467\n",
"Epoch 83/100, Loss: 0.0436\n", "Epoch 83/100, Loss: 0.0464\n",
"Epoch 84/100, Loss: 0.0452\n", "Epoch 84/100, Loss: 0.0451\n",
"Epoch 85/100, Loss: 0.0523\n", "Epoch 85/100, Loss: 0.0526\n",
"Epoch 86/100, Loss: 0.0446\n", "Epoch 86/100, Loss: 0.0501\n",
"Epoch 87/100, Loss: 0.0475\n", "Epoch 87/100, Loss: 0.0438\n",
"Epoch 88/100, Loss: 0.0503\n", "Epoch 88/100, Loss: 0.0476\n",
"Epoch 89/100, Loss: 0.0484\n", "Epoch 89/100, Loss: 0.0442\n",
"Epoch 90/100, Loss: 0.0456\n", "Epoch 90/100, Loss: 0.0432\n",
"Epoch 91/100, Loss: 0.0433\n", "Epoch 91/100, Loss: 0.0469\n",
"Epoch 92/100, Loss: 0.0458\n", "Epoch 92/100, Loss: 0.0494\n",
"Epoch 93/100, Loss: 0.0461\n", "Epoch 93/100, Loss: 0.0487\n",
"Epoch 94/100, Loss: 0.0448\n", "Epoch 94/100, Loss: 0.0445\n",
"Epoch 95/100, Loss: 0.0432\n", "Epoch 95/100, Loss: 0.0442\n",
"Epoch 96/100, Loss: 0.0456\n", "Epoch 96/100, Loss: 0.0417\n",
"Epoch 97/100, Loss: 0.0470\n", "Epoch 97/100, Loss: 0.0441\n",
"Epoch 98/100, Loss: 0.0470\n", "Epoch 98/100, Loss: 0.0417\n",
"Epoch 99/100, Loss: 0.0467\n", "Epoch 99/100, Loss: 0.0435\n",
"Epoch 100/100, Loss: 0.0471\n" "Epoch 100/100, Loss: 0.0433\n"
] ]
}, },
{ {
@@ -1425,7 +1434,7 @@
")" ")"
] ]
}, },
"execution_count": 11, "execution_count": 29,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@@ -1486,7 +1495,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 30,
"id": "8efb1396", "id": "8efb1396",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -1519,7 +1528,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": 31,
"id": "b4a1c9d9", "id": "b4a1c9d9",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -1527,16 +1536,16 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Fine-tune Epoch 1/10, Loss: 4.8250\n", "Fine-tune Epoch 1/10, Loss: 4.7966\n",
"Fine-tune Epoch 2/10, Loss: 2.6400\n", "Fine-tune Epoch 2/10, Loss: 2.6961\n",
"Fine-tune Epoch 3/10, Loss: 1.6920\n", "Fine-tune Epoch 3/10, Loss: 1.7293\n",
"Fine-tune Epoch 4/10, Loss: 1.2630\n", "Fine-tune Epoch 4/10, Loss: 1.2899\n",
"Fine-tune Epoch 5/10, Loss: 1.0183\n", "Fine-tune Epoch 5/10, Loss: 1.0189\n",
"Fine-tune Epoch 6/10, Loss: 0.8237\n", "Fine-tune Epoch 6/10, Loss: 0.8710\n",
"Fine-tune Epoch 7/10, Loss: 0.6869\n", "Fine-tune Epoch 7/10, Loss: 0.7198\n",
"Fine-tune Epoch 8/10, Loss: 0.5854\n", "Fine-tune Epoch 8/10, Loss: 0.6079\n",
"Fine-tune Epoch 9/10, Loss: 0.5155\n", "Fine-tune Epoch 9/10, Loss: 0.5297\n",
"Fine-tune Epoch 10/10, Loss: 0.4489\n" "Fine-tune Epoch 10/10, Loss: 0.4712\n"
] ]
} }
], ],
@@ -1565,7 +1574,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 32,
"id": "d77b8ff5", "id": "d77b8ff5",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -1580,7 +1589,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": 33,
"id": "79fea720", "id": "79fea720",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -1588,7 +1597,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Deep learning ena les selelf les te\n" "Deep learning ena les te autt es re\n"
] ]
} }
], ],