mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
fix(mistral): fix top-k/top-p mask handling for PyTorch >=1.2
This commit is contained in:
@@ -18,51 +18,6 @@ from torch import nn
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
#class RoPE(nn.Module):
|
|
||||||
#
|
|
||||||
# def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):
|
|
||||||
# super().__init__()
|
|
||||||
# assert head_size % 2 == 0, "head_size должен быть четным"
|
|
||||||
#
|
|
||||||
# # Вычисление частот: θ_i = base^(-2i/d) для i ∈ [0, d/2-1]
|
|
||||||
# freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size))
|
|
||||||
#
|
|
||||||
# # Позиции от 0 до max_seq_len-1
|
|
||||||
# positions = torch.arange(max_seq_len).float()
|
|
||||||
#
|
|
||||||
# # Внешнее произведение: m * θ_i для всех позиций и частот
|
|
||||||
# freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0)
|
|
||||||
#
|
|
||||||
# # Предвычисление матриц косинусов и синусов
|
|
||||||
# self.register_buffer("cos_matrix", torch.cos(freq_matrix))
|
|
||||||
# self.register_buffer("sin_matrix", torch.sin(freq_matrix))
|
|
||||||
#
|
|
||||||
# def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor: # [batch_size × seq_len × head_size] [batch_size × num_heads × seq_len × head_size]
|
|
||||||
# batch_size, num_heads, seq_len, head_size = x.shape
|
|
||||||
#
|
|
||||||
# # Берем нужную часть матриц и приводим к типу x
|
|
||||||
# cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
|
||||||
# sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
|
||||||
#
|
|
||||||
# # Явное изменение формы для broadcasting
|
|
||||||
# cos = cos.reshape(1, 1, seq_len, head_size // 2)
|
|
||||||
# sin = sin.reshape(1, 1, seq_len, head_size // 2)
|
|
||||||
#
|
|
||||||
# # Разделяем на четные и нечетные компоненты по ПОСЛЕДНЕМУ измерению
|
|
||||||
# x_even = x[..., 0::2] # [batch_size, num_heads, seq_len, head_size//2]
|
|
||||||
# x_odd = x[..., 1::2] # [batch_size, num_heads, seq_len, head_size//2]
|
|
||||||
#
|
|
||||||
# # Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ)
|
|
||||||
# x_rotated_even = x_even * cos - x_odd * sin
|
|
||||||
# x_rotated_odd = x_even * sin + x_odd * cos
|
|
||||||
#
|
|
||||||
# # Объединяем обратно в исходную размерность
|
|
||||||
# x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)
|
|
||||||
# x_rotated = x_rotated.flatten(-2) # [batch_size, seq_len, head_size]
|
|
||||||
#
|
|
||||||
# return x_rotated
|
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -456,9 +411,9 @@ class Mistral(BaseModel):
|
|||||||
vocab_size = logits_scaled.size(-1)
|
vocab_size = logits_scaled.size(-1)
|
||||||
|
|
||||||
# создаём маску: 1, если токен НЕ в topk_indices
|
# создаём маску: 1, если токен НЕ в topk_indices
|
||||||
mask = torch.ones_like(logits_scaled, dtype=torch.uint8)
|
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||||
mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы
|
mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы
|
||||||
masked_logits[mask.byte()] = float('-inf')
|
masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf')
|
||||||
|
|
||||||
logits_scaled = masked_logits
|
logits_scaled = masked_logits
|
||||||
|
|
||||||
@@ -470,12 +425,12 @@ class Mistral(BaseModel):
|
|||||||
# 3. Посчитаем кумулятивную сумму вероятностей:
|
# 3. Посчитаем кумулятивную сумму вероятностей:
|
||||||
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
|
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
|
||||||
# 4. Определим маску: оставить токены, пока сумма < top_p
|
# 4. Определим маску: оставить токены, пока сумма < top_p
|
||||||
sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size]
|
sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size]
|
||||||
# Гарантируем, что хотя бы первый токен останется
|
# Гарантируем, что хотя бы первый токен останется
|
||||||
sorted_mask[:, 0] = 1
|
sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1
|
||||||
# 5. Преобразуем маску обратно в оригинальный порядок:
|
# 5. Преобразуем маску обратно в оригинальный порядок:
|
||||||
# Создаём полную маску из 0
|
# Создаём полную маску из 0
|
||||||
mask = torch.zeros_like(probs, dtype=torch.uint8)
|
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||||
# Устанавливаем 1 в местах нужных токенов
|
# Устанавливаем 1 в местах нужных токенов
|
||||||
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
||||||
# 6. Зануляем логиты токенов вне топ-p:
|
# 6. Зануляем логиты токенов вне топ-p:
|
||||||
|
|||||||
Reference in New Issue
Block a user