fix(rope, attention): correct RoPE positioning during generation with a cache

- Fixed the position calculation for RoPE (Rotary Positional Embeddings) during autoregressive generation with a KV cache.
- HeadAttention now passes start_pos (computed from the cache length) to RoPE.
- Updated the signature and logic of RoPE.forward.
- Updated the llama.ipynb notebook for the new interfaces and outputs.

BREAKING CHANGE: RoPE.forward has been redefined; code that called RoPE manually must be updated.
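
For code that called RoPE manually, a minimal migration sketch follows (a standalone
re-implementation for illustration only; the helper names, shapes, and hyperparameters
are assumptions, not this repository's code). It demonstrates the property the fix
restores: rotating a suffix of the sequence with start_pos set to the cache length
matches rotating the full sequence in one pass.

    import torch

    def build_rope_tables(max_seq_len: int, head_size: int, base: float = 10000.0):
        # Standard RoPE angle tables: one row of angles per absolute position.
        inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2).float() / head_size))
        pos = torch.arange(max_seq_len).float()
        freqs = torch.outer(pos, inv_freq)              # [max_seq_len, head_size//2]
        return torch.cos(freqs), torch.sin(freqs)

    def apply_rope(x: torch.Tensor, cos, sin, start_pos: int = 0) -> torch.Tensor:
        # x: [batch, seq_len, head_size]; rotate (even, odd) channel pairs.
        seq_len = x.size(1)
        c = cos[start_pos:start_pos + seq_len]          # absolute positions, not 0-based
        s = sin[start_pos:start_pos + seq_len]
        x_even, x_odd = x[..., 0::2], x[..., 1::2]
        out = torch.empty_like(x)
        out[..., 0::2] = x_even * c - x_odd * s
        out[..., 1::2] = x_even * s + x_odd * c
        return out

    cos, sin = build_rope_tables(max_seq_len=64, head_size=16)
    x = torch.randn(1, 8, 16)

    # Full-sequence rotation vs. "6 tokens already cached, 2 new tokens fed in":
    full = apply_rope(x, cos, sin)
    step = apply_rope(x[:, 6:], cos, sin, start_pos=6)  # start_pos == cache_len
    assert torch.allclose(full[:, 6:], step, atol=1e-6)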
commit e5706a690d
parent 3e4815fcc6
Author: Sergey Penkovsky
Date: 2025-10-14 12:03:20 +03:00
3 changed files with 155 additions and 140 deletions

@@ -89,10 +89,16 @@ class HeadAttention(nn.Module):
         q = self._q(x) # [B, T, hs]
         v = self._v(x) # [B, T, hs]
+        start_pos = 0
+        if cache is not None:
+            k_cache, v_cache = cache
+            cache_len = k_cache.shape[1]
+            start_pos = cache_len
         if self._rope is not None:
             # ✅ Apply RoPE to Q and K (NOT to V!)
-            q = self._rope(q) # [B, T, hs]
-            k = self._rope(k) # [B, T, hs]
+            q = self._rope(q, start_pos=start_pos) # [B, T, hs]
+            k = self._rope(k, start_pos=start_pos) # [B, T, hs]
         if cache is not None:
             k_cache, v_cache = cache

@@ -68,7 +68,7 @@ class RoPE(nn.Module):
         self.register_buffer("cos_matrix", torch.cos(freq_matrix))
         self.register_buffer("sin_matrix", torch.sin(freq_matrix))

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
         """
         Apply rotary positional encoding to the input tensor.
@@ -83,12 +83,12 @@
         2. Applying the rotation via sines and cosines
         3. Merging the components back together
         """
-        seq_len = x.size(1)
+        batch_size, seq_len, emb_size = x.shape

         # Take the required slice of the matrices and cast to x's dtype
-        cos = self.cos_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]
-        sin = self.sin_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]
+        cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
+        sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]

         # Split into even and odd components
         x_even = x[:, :, 0::2] # [batch_size, seq_len, head_size//2]
         x_odd = x[:, :, 1::2] # [batch_size, seq_len, head_size//2]
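
Why the slice matters: during cached decoding the incoming x holds only the new
token(s), so seq_len is typically 1; without start_pos the old code always took
cos_matrix[:seq_len] / sin_matrix[:seq_len] and rotated every generated token as if it
sat at position 0, misaligning it against the cached keys. A quick sketch of the old
failure mode, reusing the hypothetical apply_rope helper from the sketch above:

    # Before the fix: a decode-step token at absolute position 6 was rotated
    # with the angles for position 0 (start_pos defaulted to 0).
    wrong = apply_rope(x[:, 6:7], cos, sin)
    right = apply_rope(x[:, 6:7], cos, sin, start_pos=6)
    print(torch.allclose(wrong, right))  # False: the rotation angles differ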

@@ -78,7 +78,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 1,
+  "execution_count": 19,
   "id": "873704be",
   "metadata": {},
   "outputs": [],
@@ -149,7 +149,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 2,
+  "execution_count": 20,
   "id": "0484cf77",
   "metadata": {},
   "outputs": [],
@@ -173,7 +173,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 3,
+  "execution_count": 21,
   "id": "74ca39ba",
   "metadata": {},
   "outputs": [],
@@ -432,7 +432,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 4,
+  "execution_count": 22,
   "id": "02c300f9",
   "metadata": {},
   "outputs": [
@@ -480,7 +480,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 5,
+  "execution_count": 23,
   "id": "ac072b9b",
   "metadata": {},
   "outputs": [
@@ -515,7 +515,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 6,
+  "execution_count": 24,
   "id": "4fe79d02",
   "metadata": {},
   "outputs": [],
@@ -543,11 +543,12 @@
     "        self.register_buffer('sin_matrix', torch.sin(freq_matrix))\n",
     "\n",
     "\n",
-    "    def forward(self, x: torch.Tensor): # Takes an input tensor x (float) of shape [batch_size × seq_len × head_size]\n",
-    "        seq_len = x.size(1)\n",
+    "    def forward(self, x: torch.Tensor, start_pos: int = 0): # Takes an input tensor x (float) of shape [batch_size × seq_len × head_size]\n",
+    "        batch_size, seq_len, emb_size = x.shape\n",
+    "\n",
     "        # Take the required slice of the matrices and cast to x's dtype\n",
-    "        cos = self.cos_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]\n",
-    "        sin = self.sin_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]\n",
+    "        cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]\n",
+    "        sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]\n",
     "        \n",
     "\n",
     "        # Split into even and odd\n",
@@ -580,7 +581,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 7,
+  "execution_count": 25,
   "id": "fe1274b1",
   "metadata": {},
   "outputs": [],
@@ -630,7 +631,9 @@
     "        self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())\n",
     "\n",
     "    def forward(self, x: torch.Tensor, use_cache: bool = True, cache: tuple = None) -> tuple:\n",
-    "        seq_len = x.shape[1]\n",
+    "        #seq_len = x.shape[1]\n",
+    "        batch_size, seq_len, emb_size = x.shape\n",
+    "\n",
     "        if seq_len > self._max_seq_len:\n",
     "            raise ValueError(f\"Sequence length {seq_len} exceeds the maximum {self._max_seq_len}\")\n",
     "\n",
@@ -638,9 +641,15 @@
     "        q = self._q(x) # [B, T, hs]\n",
     "        v = self._v(x) # [B, T, hs]\n",
     "\n",
+    "        start_pos = 0\n",
+    "        if cache is not None:\n",
+    "            k_cache, v_cache = cache\n",
+    "            cache_len = k_cache.shape[1]\n",
+    "            start_pos = cache_len\n",
+    "\n",
     "        # ✅ Apply RoPE to Q and K (NOT to V!)\n",
-    "        q = self._rope(q) # [B, T, hs]\n",
-    "        k = self._rope(k) # [B, T, hs]\n",
+    "        q = self._rope(q, start_pos=start_pos) # [B, T, hs]\n",
+    "        k = self._rope(k, start_pos=start_pos) # [B, T, hs]\n",
     "\n",
     "        if cache is not None:\n",
     "            k_cache, v_cache = cache\n",
@@ -1031,7 +1040,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 8,
+  "execution_count": 26,
   "id": "078ed4ce",
   "metadata": {},
   "outputs": [],
@@ -1177,7 +1186,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 9,
+  "execution_count": 27,
   "id": "0258c483",
   "metadata": {},
   "outputs": [],
@@ -1212,7 +1221,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 10,
+  "execution_count": 28,
   "id": "f8f06cf7",
   "metadata": {},
   "outputs": [],
@@ -1275,7 +1284,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 11,
+  "execution_count": 29,
   "id": "0c753d52",
   "metadata": {},
   "outputs": [
@@ -1284,106 +1293,106 @@
     "output_type": "stream",
     "text": [
      "Dataset length: 20\n",
-     "Epoch 1/100, Loss: 3.9063\n",
-     "Epoch 2/100, Loss: 1.6363\n",
-     "Epoch 3/100, Loss: 0.6436\n",
-     "Epoch 4/100, Loss: 0.3222\n",
-     "Epoch 5/100, Loss: 0.2031\n",
-     "Epoch 6/100, Loss: 0.1460\n",
-     "Epoch 7/100, Loss: 0.1119\n",
-     "Epoch 8/100, Loss: 0.0978\n",
-     "Epoch 9/100, Loss: 0.0925\n",
-     "Epoch 10/100, Loss: 0.0757\n",
-     "Epoch 11/100, Loss: 0.0809\n",
-     "Epoch 12/100, Loss: 0.0763\n",
-     "Epoch 13/100, Loss: 0.0671\n",
-     "Epoch 14/100, Loss: 0.0681\n",
-     "Epoch 15/100, Loss: 0.0582\n",
-     "Epoch 16/100, Loss: 0.0655\n",
-     "Epoch 17/100, Loss: 0.0751\n",
-     "Epoch 18/100, Loss: 0.0656\n",
-     "Epoch 19/100, Loss: 0.0625\n",
-     "Epoch 20/100, Loss: 0.0593\n",
-     "Epoch 21/100, Loss: 0.0678\n",
-     "Epoch 22/100, Loss: 0.0652\n",
-     "Epoch 23/100, Loss: 0.0644\n",
-     "Epoch 24/100, Loss: 0.0542\n",
-     "Epoch 25/100, Loss: 0.0627\n",
-     "Epoch 26/100, Loss: 0.0596\n",
-     "Epoch 27/100, Loss: 0.0639\n",
-     "Epoch 28/100, Loss: 0.0536\n",
-     "Epoch 29/100, Loss: 0.0606\n",
-     "Epoch 30/100, Loss: 0.0565\n",
-     "Epoch 31/100, Loss: 0.0622\n",
-     "Epoch 32/100, Loss: 0.0581\n",
-     "Epoch 33/100, Loss: 0.0492\n",
-     "Epoch 34/100, Loss: 0.0553\n",
-     "Epoch 35/100, Loss: 0.0526\n",
-     "Epoch 36/100, Loss: 0.0527\n",
+     "Epoch 1/100, Loss: 3.7312\n",
+     "Epoch 2/100, Loss: 1.5022\n",
+     "Epoch 3/100, Loss: 0.5940\n",
+     "Epoch 4/100, Loss: 0.3368\n",
+     "Epoch 5/100, Loss: 0.2121\n",
+     "Epoch 6/100, Loss: 0.1507\n",
+     "Epoch 7/100, Loss: 0.1095\n",
+     "Epoch 8/100, Loss: 0.1023\n",
+     "Epoch 9/100, Loss: 0.0831\n",
+     "Epoch 10/100, Loss: 0.0805\n",
+     "Epoch 11/100, Loss: 0.0789\n",
+     "Epoch 12/100, Loss: 0.0802\n",
+     "Epoch 13/100, Loss: 0.0758\n",
+     "Epoch 14/100, Loss: 0.0676\n",
+     "Epoch 15/100, Loss: 0.0655\n",
+     "Epoch 16/100, Loss: 0.0701\n",
+     "Epoch 17/100, Loss: 0.0631\n",
+     "Epoch 18/100, Loss: 0.0691\n",
+     "Epoch 19/100, Loss: 0.0687\n",
+     "Epoch 20/100, Loss: 0.0623\n",
+     "Epoch 21/100, Loss: 0.0574\n",
+     "Epoch 22/100, Loss: 0.0603\n",
+     "Epoch 23/100, Loss: 0.0592\n",
+     "Epoch 24/100, Loss: 0.0587\n",
+     "Epoch 25/100, Loss: 0.0598\n",
+     "Epoch 26/100, Loss: 0.0589\n",
+     "Epoch 27/100, Loss: 0.0589\n",
+     "Epoch 28/100, Loss: 0.0546\n",
+     "Epoch 29/100, Loss: 0.0570\n",
+     "Epoch 30/100, Loss: 0.0673\n",
+     "Epoch 31/100, Loss: 0.0601\n",
+     "Epoch 32/100, Loss: 0.0702\n",
+     "Epoch 33/100, Loss: 0.0528\n",
+     "Epoch 34/100, Loss: 0.0508\n",
+     "Epoch 35/100, Loss: 0.0522\n",
+     "Epoch 36/100, Loss: 0.0537\n",
      "Epoch 37/100, Loss: 0.0570\n",
-     "Epoch 38/100, Loss: 0.0482\n",
-     "Epoch 39/100, Loss: 0.0553\n",
-     "Epoch 40/100, Loss: 0.0444\n",
-     "Epoch 41/100, Loss: 0.0602\n",
-     "Epoch 42/100, Loss: 0.0599\n",
-     "Epoch 43/100, Loss: 0.0598\n",
-     "Epoch 44/100, Loss: 0.0572\n",
-     "Epoch 45/100, Loss: 0.0551\n",
-     "Epoch 46/100, Loss: 0.0577\n",
-     "Epoch 47/100, Loss: 0.0527\n",
-     "Epoch 48/100, Loss: 0.0466\n",
-     "Epoch 49/100, Loss: 0.0551\n",
-     "Epoch 50/100, Loss: 0.0517\n",
-     "Epoch 51/100, Loss: 0.0477\n",
-     "Epoch 52/100, Loss: 0.0539\n",
-     "Epoch 53/100, Loss: 0.0478\n",
-     "Epoch 54/100, Loss: 0.0539\n",
-     "Epoch 55/100, Loss: 0.0435\n",
-     "Epoch 56/100, Loss: 0.0471\n",
-     "Epoch 57/100, Loss: 0.0461\n",
-     "Epoch 58/100, Loss: 0.0452\n",
-     "Epoch 59/100, Loss: 0.0507\n",
-     "Epoch 60/100, Loss: 0.0481\n",
-     "Epoch 61/100, Loss: 0.0398\n",
-     "Epoch 62/100, Loss: 0.0535\n",
-     "Epoch 63/100, Loss: 0.0503\n",
-     "Epoch 64/100, Loss: 0.0504\n",
-     "Epoch 65/100, Loss: 0.0473\n",
-     "Epoch 66/100, Loss: 0.0553\n",
-     "Epoch 67/100, Loss: 0.0514\n",
-     "Epoch 68/100, Loss: 0.0450\n",
-     "Epoch 69/100, Loss: 0.0488\n",
-     "Epoch 70/100, Loss: 0.0414\n",
-     "Epoch 71/100, Loss: 0.0413\n",
-     "Epoch 72/100, Loss: 0.0473\n",
-     "Epoch 73/100, Loss: 0.0530\n",
-     "Epoch 74/100, Loss: 0.0482\n",
-     "Epoch 75/100, Loss: 0.0477\n",
-     "Epoch 76/100, Loss: 0.0483\n",
-     "Epoch 77/100, Loss: 0.0452\n",
-     "Epoch 78/100, Loss: 0.0452\n",
-     "Epoch 79/100, Loss: 0.0474\n",
-     "Epoch 80/100, Loss: 0.0483\n",
-     "Epoch 81/100, Loss: 0.0522\n",
-     "Epoch 82/100, Loss: 0.0453\n",
-     "Epoch 83/100, Loss: 0.0436\n",
-     "Epoch 84/100, Loss: 0.0452\n",
-     "Epoch 85/100, Loss: 0.0523\n",
-     "Epoch 86/100, Loss: 0.0446\n",
-     "Epoch 87/100, Loss: 0.0475\n",
-     "Epoch 88/100, Loss: 0.0503\n",
-     "Epoch 89/100, Loss: 0.0484\n",
-     "Epoch 90/100, Loss: 0.0456\n",
-     "Epoch 91/100, Loss: 0.0433\n",
-     "Epoch 92/100, Loss: 0.0458\n",
-     "Epoch 93/100, Loss: 0.0461\n",
-     "Epoch 94/100, Loss: 0.0448\n",
-     "Epoch 95/100, Loss: 0.0432\n",
-     "Epoch 96/100, Loss: 0.0456\n",
-     "Epoch 97/100, Loss: 0.0470\n",
-     "Epoch 98/100, Loss: 0.0470\n",
-     "Epoch 99/100, Loss: 0.0467\n",
-     "Epoch 100/100, Loss: 0.0471\n"
+     "Epoch 38/100, Loss: 0.0580\n",
+     "Epoch 39/100, Loss: 0.0445\n",
+     "Epoch 40/100, Loss: 0.0516\n",
+     "Epoch 41/100, Loss: 0.0518\n",
+     "Epoch 42/100, Loss: 0.0545\n",
+     "Epoch 43/100, Loss: 0.0466\n",
+     "Epoch 44/100, Loss: 0.0523\n",
+     "Epoch 45/100, Loss: 0.0523\n",
+     "Epoch 46/100, Loss: 0.0547\n",
+     "Epoch 47/100, Loss: 0.0497\n",
+     "Epoch 48/100, Loss: 0.0512\n",
+     "Epoch 49/100, Loss: 0.0481\n",
+     "Epoch 50/100, Loss: 0.0498\n",
+     "Epoch 51/100, Loss: 0.0672\n",
+     "Epoch 52/100, Loss: 0.0530\n",
+     "Epoch 53/100, Loss: 0.0562\n",
+     "Epoch 54/100, Loss: 0.0536\n",
+     "Epoch 55/100, Loss: 0.0482\n",
+     "Epoch 56/100, Loss: 0.0438\n",
+     "Epoch 57/100, Loss: 0.0467\n",
+     "Epoch 58/100, Loss: 0.0501\n",
+     "Epoch 59/100, Loss: 0.0445\n",
+     "Epoch 60/100, Loss: 0.0471\n",
+     "Epoch 61/100, Loss: 0.0502\n",
+     "Epoch 62/100, Loss: 0.0474\n",
+     "Epoch 63/100, Loss: 0.0420\n",
+     "Epoch 64/100, Loss: 0.0541\n",
+     "Epoch 65/100, Loss: 0.0491\n",
+     "Epoch 66/100, Loss: 0.0489\n",
+     "Epoch 67/100, Loss: 0.0498\n",
+     "Epoch 68/100, Loss: 0.0511\n",
+     "Epoch 69/100, Loss: 0.0463\n",
+     "Epoch 70/100, Loss: 0.0480\n",
+     "Epoch 71/100, Loss: 0.0460\n",
+     "Epoch 72/100, Loss: 0.0533\n",
+     "Epoch 73/100, Loss: 0.0515\n",
+     "Epoch 74/100, Loss: 0.0419\n",
+     "Epoch 75/100, Loss: 0.0491\n",
+     "Epoch 76/100, Loss: 0.0471\n",
+     "Epoch 77/100, Loss: 0.0479\n",
+     "Epoch 78/100, Loss: 0.0444\n",
+     "Epoch 79/100, Loss: 0.0520\n",
+     "Epoch 80/100, Loss: 0.0520\n",
+     "Epoch 81/100, Loss: 0.0489\n",
+     "Epoch 82/100, Loss: 0.0467\n",
+     "Epoch 83/100, Loss: 0.0464\n",
+     "Epoch 84/100, Loss: 0.0451\n",
+     "Epoch 85/100, Loss: 0.0526\n",
+     "Epoch 86/100, Loss: 0.0501\n",
+     "Epoch 87/100, Loss: 0.0438\n",
+     "Epoch 88/100, Loss: 0.0476\n",
+     "Epoch 89/100, Loss: 0.0442\n",
+     "Epoch 90/100, Loss: 0.0432\n",
+     "Epoch 91/100, Loss: 0.0469\n",
+     "Epoch 92/100, Loss: 0.0494\n",
+     "Epoch 93/100, Loss: 0.0487\n",
+     "Epoch 94/100, Loss: 0.0445\n",
+     "Epoch 95/100, Loss: 0.0442\n",
+     "Epoch 96/100, Loss: 0.0417\n",
+     "Epoch 97/100, Loss: 0.0441\n",
+     "Epoch 98/100, Loss: 0.0417\n",
+     "Epoch 99/100, Loss: 0.0435\n",
+     "Epoch 100/100, Loss: 0.0433\n"
     ]
    },
    {
@@ -1425,7 +1434,7 @@
      ")"
     ]
    },
-   "execution_count": 11,
+   "execution_count": 29,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -1486,7 +1495,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 12,
+  "execution_count": 30,
   "id": "8efb1396",
   "metadata": {},
   "outputs": [],
@@ -1519,7 +1528,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 14,
+  "execution_count": 31,
   "id": "b4a1c9d9",
   "metadata": {},
   "outputs": [
@@ -1527,16 +1536,16 @@
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "Fine-tune Epoch 1/10, Loss: 4.8250\n",
-     "Fine-tune Epoch 2/10, Loss: 2.6400\n",
-     "Fine-tune Epoch 3/10, Loss: 1.6920\n",
-     "Fine-tune Epoch 4/10, Loss: 1.2630\n",
-     "Fine-tune Epoch 5/10, Loss: 1.0183\n",
-     "Fine-tune Epoch 6/10, Loss: 0.8237\n",
-     "Fine-tune Epoch 7/10, Loss: 0.6869\n",
-     "Fine-tune Epoch 8/10, Loss: 0.5854\n",
-     "Fine-tune Epoch 9/10, Loss: 0.5155\n",
-     "Fine-tune Epoch 10/10, Loss: 0.4489\n"
+     "Fine-tune Epoch 1/10, Loss: 4.7966\n",
+     "Fine-tune Epoch 2/10, Loss: 2.6961\n",
+     "Fine-tune Epoch 3/10, Loss: 1.7293\n",
+     "Fine-tune Epoch 4/10, Loss: 1.2899\n",
+     "Fine-tune Epoch 5/10, Loss: 1.0189\n",
+     "Fine-tune Epoch 6/10, Loss: 0.8710\n",
+     "Fine-tune Epoch 7/10, Loss: 0.7198\n",
+     "Fine-tune Epoch 8/10, Loss: 0.6079\n",
+     "Fine-tune Epoch 9/10, Loss: 0.5297\n",
+     "Fine-tune Epoch 10/10, Loss: 0.4712\n"
     ]
    }
   ],
@@ -1565,7 +1574,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 15,
+  "execution_count": 32,
   "id": "d77b8ff5",
   "metadata": {},
   "outputs": [],
@@ -1580,7 +1589,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 17,
+  "execution_count": 33,
   "id": "79fea720",
   "metadata": {},
   "outputs": [
@@ -1588,7 +1597,7 @@
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "Deep learning ena les selelf les te\n"
+     "Deep learning ena les te autt es re\n"
     ]
    }
   ],