From e791f7cd93fb3efbc000f99315866ab3810425b3 Mon Sep 17 00:00:00 2001 From: Sergey Penkovsky Date: Wed, 15 Oct 2025 13:20:30 +0300 Subject: [PATCH] fix(mistral): fix top-k/top-p mask handling for PyTorch >=1.2 --- llm/src/llm/models/mistral/mistral.py | 57 +++------------------------ 1 file changed, 6 insertions(+), 51 deletions(-) diff --git a/llm/src/llm/models/mistral/mistral.py b/llm/src/llm/models/mistral/mistral.py index 607d2d4..76559db 100644 --- a/llm/src/llm/models/mistral/mistral.py +++ b/llm/src/llm/models/mistral/mistral.py @@ -18,51 +18,6 @@ from torch import nn from typing import Optional -#class RoPE(nn.Module): -# -# def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000): -# super().__init__() -# assert head_size % 2 == 0, "head_size должен быть четным" -# -# # Вычисление частот: θ_i = base^(-2i/d) для i ∈ [0, d/2-1] -# freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size)) -# -# # Позиции от 0 до max_seq_len-1 -# positions = torch.arange(max_seq_len).float() -# -# # Внешнее произведение: m * θ_i для всех позиций и частот -# freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0) -# -# # Предвычисление матриц косинусов и синусов -# self.register_buffer("cos_matrix", torch.cos(freq_matrix)) -# self.register_buffer("sin_matrix", torch.sin(freq_matrix)) -# -# def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor: # [batch_size × seq_len × head_size] [batch_size × num_heads × seq_len × head_size] -# batch_size, num_heads, seq_len, head_size = x.shape -# -# # Берем нужную часть матриц и приводим к типу x -# cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2] -# sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2] -# -# # Явное изменение формы для broadcasting -# cos = cos.reshape(1, 1, seq_len, head_size // 2) -# sin = sin.reshape(1, 1, seq_len, head_size // 2) -# -# # Разделяем на четные и нечетные компоненты по ПОСЛЕДНЕМУ измерению -# x_even = x[..., 0::2] # [batch_size, num_heads, seq_len, head_size//2] -# x_odd = x[..., 1::2] # [batch_size, num_heads, seq_len, head_size//2] -# -# # Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ) -# x_rotated_even = x_even * cos - x_odd * sin -# x_rotated_odd = x_even * sin + x_odd * cos -# -# # Объединяем обратно в исходную размерность -# x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1) -# x_rotated = x_rotated.flatten(-2) # [batch_size, seq_len, head_size] -# -# return x_rotated - - import torch from torch import nn import torch.nn.functional as F @@ -456,9 +411,9 @@ class Mistral(BaseModel): vocab_size = logits_scaled.size(-1) # создаём маску: 1, если токен НЕ в topk_indices - mask = torch.ones_like(logits_scaled, dtype=torch.uint8) - mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы - masked_logits[mask.byte()] = float('-inf') + mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8) + mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы + masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf') logits_scaled = masked_logits @@ -470,12 +425,12 @@ class Mistral(BaseModel): # 3. Посчитаем кумулятивную сумму вероятностей: cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size] # 4. Определим маску: оставить токены, пока сумма < top_p - sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size] + sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size] # Гарантируем, что хотя бы первый токен останется - sorted_mask[:, 0] = 1 + sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1 # 5. Преобразуем маску обратно в оригинальный порядок: # Создаём полную маску из 0 - mask = torch.zeros_like(probs, dtype=torch.uint8) + mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8) # Устанавливаем 1 в местах нужных токенов mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask) # 6. Зануляем логиты токенов вне топ-p: