From 50d759302380a5e2af44e0543b7a798052bf644e Mon Sep 17 00:00:00 2001 From: Sergey Penkovsky Date: Wed, 15 Oct 2025 14:35:10 +0300 Subject: [PATCH] fix(gpt2, llama): proper top-k/top-p mask handling in sampling for PyTorch compatibility (bool/uint8) - Refactored token selection logic in methods of GPT2 and Llama classes. - Masks are now created with dtype=torch.bool (or torch.uint8 for legacy PyTorch). - Used True/False for mask/scatter instead of 1/0, ensuring correctness across PyTorch versions. - Fixed RuntimeError: masked_fill_ only supports boolean masks, previously raised by uint8-masks in new PyTorch. - Backward compatibility maintained: code works on PyTorch >=1.2 and for old clusters (via the else branch). Motivation: Fixes sampling errors for all modern PyTorch users while keeping research code usable on old infra. --- llm/src/llm/models/gpt/gpt2.py | 15 +++++++-------- llm/src/llm/models/llama/llama.py | 15 +++++++-------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/llm/src/llm/models/gpt/gpt2.py b/llm/src/llm/models/gpt/gpt2.py index 7e98b69..abc666e 100644 --- a/llm/src/llm/models/gpt/gpt2.py +++ b/llm/src/llm/models/gpt/gpt2.py @@ -219,10 +219,9 @@ class GPT2(BaseModel): vocab_size = logits_scaled.size(-1) # создаём маску: 1, если токен НЕ в topk_indices - mask = torch.ones_like(logits_scaled, dtype=torch.uint8) - mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы - masked_logits[mask.byte()] = float("-inf") - + mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8) + mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы + masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf') logits_scaled = masked_logits if do_sample == True and top_p != None: @@ -235,16 +234,16 @@ class GPT2(BaseModel): # 3. Посчитаем кумулятивную сумму вероятностей: cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size] # 4. Определим маску: оставить токены, пока сумма < top_p - sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size] + sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size] # Гарантируем, что хотя бы первый токен останется - sorted_mask[:, 0] = 1 + sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1 # 5. Преобразуем маску обратно в оригинальный порядок: # Создаём полную маску из 0 - mask = torch.zeros_like(probs, dtype=torch.uint8) + mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8) # Устанавливаем 1 в местах нужных токенов mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask) # 6. Зануляем логиты токенов вне топ-p: - logits_scaled[~mask] = float("-inf") + logits_scaled[~mask] = float('-inf') # 4. Применяем Softmax probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size] diff --git a/llm/src/llm/models/llama/llama.py b/llm/src/llm/models/llama/llama.py index 78486fb..5aaf72c 100644 --- a/llm/src/llm/models/llama/llama.py +++ b/llm/src/llm/models/llama/llama.py @@ -194,10 +194,9 @@ class Llama(BaseModel): vocab_size = logits_scaled.size(-1) # создаём маску: 1, если токен НЕ в topk_indices - mask = torch.ones_like(logits_scaled, dtype=torch.uint8) - mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы - masked_logits[mask.byte()] = float("-inf") - + mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8) + mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы + masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf') logits_scaled = masked_logits if do_sample == True and top_p != None: @@ -210,16 +209,16 @@ class Llama(BaseModel): # 3. Посчитаем кумулятивную сумму вероятностей: cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size] # 4. Определим маску: оставить токены, пока сумма < top_p - sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size] + sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size] # Гарантируем, что хотя бы первый токен останется - sorted_mask[:, 0] = 1 + sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1 # 5. Преобразуем маску обратно в оригинальный порядок: # Создаём полную маску из 0 - mask = torch.zeros_like(probs, dtype=torch.uint8) + mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8) # Устанавливаем 1 в местах нужных токенов mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask) # 6. Зануляем логиты токенов вне топ-p: - logits_scaled[~mask] = float("-inf") + logits_scaled[~mask] = float('-inf') # 4. Применяем Softmax probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]