fix(rope, attention): корректное позиционирование RoPE при генерации с кэшем

- Исправлена ошибка расчёта позиции для RoPE (Rotary Positional Embeddings) при автодополнении с использованием кэша.
- В HeadAttention теперь передаётся start_pos в RoPE, вычисляемый из длины кэша.
- Обновлена сигнатура и логика метода RoPE.forward.
- Обновлен ноутбук llama.ipynb под новые интерфейсы и выводы.

BREAKING CHANGE: переопределён метод forward у RoPE, требуется обновить код, если RoPE использовался вручную.
This commit is contained in:
Sergey Penkovsky
2025-10-14 12:03:20 +03:00
parent 3e4815fcc6
commit e5706a690d
3 changed files with 155 additions and 140 deletions

View File

@@ -89,10 +89,16 @@ class HeadAttention(nn.Module):
q = self._q(x) # [B, T, hs]
v = self._v(x) # [B, T, hs]
start_pos = 0
if cache is not None:
k_cache, v_cache = cache
cache_len = k_cache.shape[1]
start_pos = cache_len
if self._rope is not None:
# ✅ Применяем RoPE к Q и K (НЕ к V!)
q = self._rope(q) # [B, T, hs]
k = self._rope(k) # [B, T, hs]
q = self._rope(q, start_pos=start_pos) # [B, T, hs]
k = self._rope(k, start_pos=start_pos) # [B, T, hs]
if cache is not None:
k_cache, v_cache = cache

View File

@@ -68,7 +68,7 @@ class RoPE(nn.Module):
self.register_buffer("cos_matrix", torch.cos(freq_matrix))
self.register_buffer("sin_matrix", torch.sin(freq_matrix))
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
"""
Применение ротационного позиционного кодирования к входному тензору.
@@ -83,12 +83,12 @@ class RoPE(nn.Module):
2. Применение вращения через синусы и косинусы
3. Объединение компонент обратно
"""
seq_len = x.size(1)
batch_size, seq_len, emb_size = x.shape
# Берем нужную часть матриц и приводим к типу x
cos = self.cos_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]
sin = self.sin_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]
cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
# Разделяем на четные и нечетные компоненты
x_even = x[:, :, 0::2] # [batch_size, seq_len, head_size//2]
x_odd = x[:, :, 1::2] # [batch_size, seq_len, head_size//2]