mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
fix(gpt2, llama): proper top-k/top-p mask handling in sampling for PyTorch compatibility (bool/uint8)
- Refactored token selection logic in methods of GPT2 and Llama classes. - Masks are now created with dtype=torch.bool (or torch.uint8 for legacy PyTorch). - Used True/False for mask/scatter instead of 1/0, ensuring correctness across PyTorch versions. - Fixed RuntimeError: masked_fill_ only supports boolean masks, previously raised by uint8-masks in new PyTorch. - Backward compatibility maintained: code works on PyTorch >=1.2 and for old clusters (via the else branch). Motivation: Fixes sampling errors for all modern PyTorch users while keeping research code usable on old infra.
This commit is contained in:
@@ -219,10 +219,9 @@ class GPT2(BaseModel):
|
||||
vocab_size = logits_scaled.size(-1)
|
||||
|
||||
# создаём маску: 1, если токен НЕ в topk_indices
|
||||
mask = torch.ones_like(logits_scaled, dtype=torch.uint8)
|
||||
mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы
|
||||
masked_logits[mask.byte()] = float("-inf")
|
||||
|
||||
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||
mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы
|
||||
masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf')
|
||||
logits_scaled = masked_logits
|
||||
|
||||
if do_sample == True and top_p != None:
|
||||
@@ -235,16 +234,16 @@ class GPT2(BaseModel):
|
||||
# 3. Посчитаем кумулятивную сумму вероятностей:
|
||||
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
|
||||
# 4. Определим маску: оставить токены, пока сумма < top_p
|
||||
sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size]
|
||||
sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size]
|
||||
# Гарантируем, что хотя бы первый токен останется
|
||||
sorted_mask[:, 0] = 1
|
||||
sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1
|
||||
# 5. Преобразуем маску обратно в оригинальный порядок:
|
||||
# Создаём полную маску из 0
|
||||
mask = torch.zeros_like(probs, dtype=torch.uint8)
|
||||
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||
# Устанавливаем 1 в местах нужных токенов
|
||||
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
||||
# 6. Зануляем логиты токенов вне топ-p:
|
||||
logits_scaled[~mask] = float("-inf")
|
||||
logits_scaled[~mask] = float('-inf')
|
||||
|
||||
# 4. Применяем Softmax
|
||||
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
|
||||
|
||||
@@ -194,10 +194,9 @@ class Llama(BaseModel):
|
||||
vocab_size = logits_scaled.size(-1)
|
||||
|
||||
# создаём маску: 1, если токен НЕ в topk_indices
|
||||
mask = torch.ones_like(logits_scaled, dtype=torch.uint8)
|
||||
mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы
|
||||
masked_logits[mask.byte()] = float("-inf")
|
||||
|
||||
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||
mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы
|
||||
masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf')
|
||||
logits_scaled = masked_logits
|
||||
|
||||
if do_sample == True and top_p != None:
|
||||
@@ -210,16 +209,16 @@ class Llama(BaseModel):
|
||||
# 3. Посчитаем кумулятивную сумму вероятностей:
|
||||
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
|
||||
# 4. Определим маску: оставить токены, пока сумма < top_p
|
||||
sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size]
|
||||
sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size]
|
||||
# Гарантируем, что хотя бы первый токен останется
|
||||
sorted_mask[:, 0] = 1
|
||||
sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1
|
||||
# 5. Преобразуем маску обратно в оригинальный порядок:
|
||||
# Создаём полную маску из 0
|
||||
mask = torch.zeros_like(probs, dtype=torch.uint8)
|
||||
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
|
||||
# Устанавливаем 1 в местах нужных токенов
|
||||
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
||||
# 6. Зануляем логиты токенов вне топ-p:
|
||||
logits_scaled[~mask] = float("-inf")
|
||||
logits_scaled[~mask] = float('-inf')
|
||||
|
||||
# 4. Применяем Softmax
|
||||
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
|
||||
|
||||
Reference in New Issue
Block a user