mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
fix(rope, attention): корректное позиционирование RoPE при генерации с кэшем
- Исправлена ошибка расчёта позиции для RoPE (Rotary Positional Embeddings) при автодополнении с использованием кэша. - В HeadAttention теперь передаётся start_pos в RoPE, вычисляемый из длины кэша. - Обновлена сигнатура и логика метода RoPE.forward. - Обновлен ноутбук llama.ipynb под новые интерфейсы и выводы. BREAKING CHANGE: переопределён метод forward у RoPE, требуется обновить код, если RoPE использовался вручную.
This commit is contained in:
@@ -89,10 +89,16 @@ class HeadAttention(nn.Module):
|
||||
q = self._q(x) # [B, T, hs]
|
||||
v = self._v(x) # [B, T, hs]
|
||||
|
||||
start_pos = 0
|
||||
if cache is not None:
|
||||
k_cache, v_cache = cache
|
||||
cache_len = k_cache.shape[1]
|
||||
start_pos = cache_len
|
||||
|
||||
if self._rope is not None:
|
||||
# ✅ Применяем RoPE к Q и K (НЕ к V!)
|
||||
q = self._rope(q) # [B, T, hs]
|
||||
k = self._rope(k) # [B, T, hs]
|
||||
q = self._rope(q, start_pos=start_pos) # [B, T, hs]
|
||||
k = self._rope(k, start_pos=start_pos) # [B, T, hs]
|
||||
|
||||
if cache is not None:
|
||||
k_cache, v_cache = cache
|
||||
|
||||
@@ -68,7 +68,7 @@ class RoPE(nn.Module):
|
||||
self.register_buffer("cos_matrix", torch.cos(freq_matrix))
|
||||
self.register_buffer("sin_matrix", torch.sin(freq_matrix))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
|
||||
"""
|
||||
Применение ротационного позиционного кодирования к входному тензору.
|
||||
|
||||
@@ -83,12 +83,12 @@ class RoPE(nn.Module):
|
||||
2. Применение вращения через синусы и косинусы
|
||||
3. Объединение компонент обратно
|
||||
"""
|
||||
seq_len = x.size(1)
|
||||
batch_size, seq_len, emb_size = x.shape
|
||||
|
||||
# Берем нужную часть матриц и приводим к типу x
|
||||
cos = self.cos_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]
|
||||
sin = self.sin_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]
|
||||
|
||||
cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
||||
sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
||||
|
||||
# Разделяем на четные и нечетные компоненты
|
||||
x_even = x[:, :, 0::2] # [batch_size, seq_len, head_size//2]
|
||||
x_odd = x[:, :, 1::2] # [batch_size, seq_len, head_size//2]
|
||||
|
||||
Reference in New Issue
Block a user