fix(mistral): fix top-k/top-p mask handling for PyTorch >=1.2

This commit is contained in:
Sergey Penkovsky
2025-10-15 13:20:30 +03:00
parent d10044e4a7
commit e791f7cd93

View File

@@ -18,51 +18,6 @@ from torch import nn
from typing import Optional from typing import Optional
#class RoPE(nn.Module):
#
# def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):
# super().__init__()
# assert head_size % 2 == 0, "head_size должен быть четным"
#
# # Вычисление частот: θ_i = base^(-2i/d) для i ∈ [0, d/2-1]
# freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size))
#
# # Позиции от 0 до max_seq_len-1
# positions = torch.arange(max_seq_len).float()
#
# # Внешнее произведение: m * θ_i для всех позиций и частот
# freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0)
#
# # Предвычисление матриц косинусов и синусов
# self.register_buffer("cos_matrix", torch.cos(freq_matrix))
# self.register_buffer("sin_matrix", torch.sin(freq_matrix))
#
# def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor: # [batch_size × seq_len × head_size] [batch_size × num_heads × seq_len × head_size]
# batch_size, num_heads, seq_len, head_size = x.shape
#
# # Берем нужную часть матриц и приводим к типу x
# cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
# sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
#
# # Явное изменение формы для broadcasting
# cos = cos.reshape(1, 1, seq_len, head_size // 2)
# sin = sin.reshape(1, 1, seq_len, head_size // 2)
#
# # Разделяем на четные и нечетные компоненты по ПОСЛЕДНЕМУ измерению
# x_even = x[..., 0::2] # [batch_size, num_heads, seq_len, head_size//2]
# x_odd = x[..., 1::2] # [batch_size, num_heads, seq_len, head_size//2]
#
# # Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ)
# x_rotated_even = x_even * cos - x_odd * sin
# x_rotated_odd = x_even * sin + x_odd * cos
#
# # Объединяем обратно в исходную размерность
# x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)
# x_rotated = x_rotated.flatten(-2) # [batch_size, seq_len, head_size]
#
# return x_rotated
import torch import torch
from torch import nn from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
@@ -456,9 +411,9 @@ class Mistral(BaseModel):
vocab_size = logits_scaled.size(-1) vocab_size = logits_scaled.size(-1)
# создаём маску: 1, если токен НЕ в topk_indices # создаём маску: 1, если токен НЕ в topk_indices
mask = torch.ones_like(logits_scaled, dtype=torch.uint8) mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы
masked_logits[mask.byte()] = float('-inf') masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf')
logits_scaled = masked_logits logits_scaled = masked_logits
@@ -470,12 +425,12 @@ class Mistral(BaseModel):
# 3. Посчитаем кумулятивную сумму вероятностей: # 3. Посчитаем кумулятивную сумму вероятностей:
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size] cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
# 4. Определим маску: оставить токены, пока сумма < top_p # 4. Определим маску: оставить токены, пока сумма < top_p
sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size] sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size]
# Гарантируем, что хотя бы первый токен останется # Гарантируем, что хотя бы первый токен останется
sorted_mask[:, 0] = 1 sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1
# 5. Преобразуем маску обратно в оригинальный порядок: # 5. Преобразуем маску обратно в оригинальный порядок:
# Создаём полную маску из 0 # Создаём полную маску из 0
mask = torch.zeros_like(probs, dtype=torch.uint8) mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
# Устанавливаем 1 в местах нужных токенов # Устанавливаем 1 в местах нужных токенов
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask) mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
# 6. Зануляем логиты токенов вне топ-p: # 6. Зануляем логиты токенов вне топ-p: