fix: update PyTorch mask types and BPE tokenizer serialization

- Replace deprecated torch.uint8 and .byte() with torch.bool in GPT.generate
- Add save/load methods to BPETokenizer for proper merges and vocab_list serialization
- Update dependencies in pyproject.toml
This commit is contained in:
Sergey Penkovsky
2025-10-05 08:09:30 +03:00
parent ec07546ea8
commit f4bdc81829
4 changed files with 110 additions and 9 deletions

View File

@@ -12,6 +12,13 @@ dependencies = [
"numpy>=1.24.0",
]
[project.optional-dependencies]
test = [
"pytest>=7.0.0",
"pytest-cov>=4.0.0",
"pytest-mock>=3.0.0",
]
[build-system]
requires = ["uv_build>=0.8.22,<0.9.0"]
build-backend = "uv_build"

View File

@@ -215,10 +215,10 @@ class GPT(BaseModel):
masked_logits = logits_scaled.clone()
vocab_size = logits_scaled.size(-1)
# создаём маску: 1, если токен НЕ в topk_indices
mask = torch.ones_like(logits_scaled, dtype=torch.uint8)
mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы
masked_logits[mask.byte()] = float('-inf')
# создаём маску: True, если токен НЕ в topk_indices
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, 'bool') else torch.uint8)
mask.scatter_(1, topk_indices, False if hasattr(torch, 'bool') else 0) # False там, где top-k индексы
masked_logits[mask] = float('-inf')
logits_scaled = masked_logits
@@ -230,13 +230,13 @@ class GPT(BaseModel):
# 3. Посчитаем кумулятивную сумму вероятностей:
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
# 4. Определим маску: оставить токены, пока сумма < top_p
sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size]
sorted_mask = (cum_probs <= top_p) # [B, vocab_size]
# Гарантируем, что хотя бы первый токен останется
sorted_mask[:, 0] = 1
sorted_mask[:, 0] = True
# 5. Преобразуем маску обратно в оригинальный порядок:
# Создаём полную маску из 0
mask = torch.zeros_like(probs, dtype=torch.uint8)
# Устанавливаем 1 в местах нужных токенов
# Создаём полную маску из False
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, 'bool') else torch.uint8)
# Устанавливаем True в местах нужных токенов
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
# 6. Зануляем логиты токенов вне топ-p:
logits_scaled[~mask] = float('-inf')

View File

@@ -197,6 +197,77 @@ class BPETokenizer(BaseTokenizer):
else:
tokens.append(self.unk_token) # Специальное значение
return tokens
def save(self, filepath: str):
"""
Сохраняет токенизатор в файл.
Args:
filepath: Путь для сохранения
"""
import json
# Преобразуем кортежи в строки для JSON сериализации
merges_serializable = {f"{k[0]},{k[1]}": v for k, v in self.merges.items()}
config = {
'vocab': self.vocab,
'vocab_size': self.vocab_size,
'pad_token': self.pad_token,
'unk_token': self.unk_token,
'bos_token': self.bos_token,
'eos_token': self.eos_token,
'tokenizer_type': self.__class__.__name__,
'merges': merges_serializable,
'vocab_list': self.vocab_list
}
with open(filepath, 'w', encoding='utf-8') as f:
json.dump(config, f, ensure_ascii=False, indent=2)
@classmethod
def load(cls, filepath: str):
"""
Загружает токенизатор из файла.
Args:
filepath: Путь к файлу
Returns:
BPETokenizer: Загруженный токенизатор
"""
import json
with open(filepath, 'r', encoding='utf-8') as f:
config = json.load(f)
# Создаем экземпляр токенизатора
tokenizer = cls()
tokenizer.vocab = config['vocab']
tokenizer.vocab_size = config['vocab_size']
tokenizer.pad_token = config['pad_token']
tokenizer.unk_token = config['unk_token']
tokenizer.bos_token = config['bos_token']
tokenizer.eos_token = config['eos_token']
tokenizer.vocab_list = config['vocab_list']
# Восстанавливаем кортежи из строк
tokenizer.merges = {}
for k, v in config['merges'].items():
parts = k.split(',')
if len(parts) == 2:
tokenizer.merges[(parts[0], parts[1])] = v
# Создаем обратный словарь
tokenizer.inverse_vocab = {v: k for k, v in tokenizer.vocab.items()}
# Обновляем ID специальных токенов
tokenizer.pad_token_id = tokenizer.vocab.get(tokenizer.pad_token)
tokenizer.unk_token_id = tokenizer.vocab.get(tokenizer.unk_token)
tokenizer.bos_token_id = tokenizer.vocab.get(tokenizer.bos_token)
tokenizer.eos_token_id = tokenizer.vocab.get(tokenizer.eos_token)
return tokenizer
class SimpleBPETokenizer(BPETokenizer):