fix: update PyTorch mask types and BPE tokenizer serialization

- Replace deprecated torch.uint8 and .byte() with torch.bool in GPT.generate
- Add save/load methods to BPETokenizer for proper merges and vocab_list serialization
- Update dependencies in pyproject.toml
This commit is contained in:
Sergey Penkovsky
2025-10-05 08:09:30 +03:00
parent ec07546ea8
commit f4bdc81829
4 changed files with 110 additions and 9 deletions

View File

@@ -12,6 +12,13 @@ dependencies = [
"numpy>=1.24.0",
]
[project.optional-dependencies]
test = [
"pytest>=7.0.0",
"pytest-cov>=4.0.0",
"pytest-mock>=3.0.0",
]
[build-system]
requires = ["uv_build>=0.8.22,<0.9.0"]
build-backend = "uv_build"

View File

@@ -215,10 +215,10 @@ class GPT(BaseModel):
masked_logits = logits_scaled.clone()
vocab_size = logits_scaled.size(-1)
# создаём маску: 1, если токен НЕ в topk_indices
mask = torch.ones_like(logits_scaled, dtype=torch.uint8)
mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы
masked_logits[mask.byte()] = float('-inf')
# создаём маску: True, если токен НЕ в topk_indices
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, 'bool') else torch.uint8)
mask.scatter_(1, topk_indices, False if hasattr(torch, 'bool') else 0) # False там, где top-k индексы
masked_logits[mask] = float('-inf')
logits_scaled = masked_logits
@@ -230,13 +230,13 @@ class GPT(BaseModel):
# 3. Посчитаем кумулятивную сумму вероятностей:
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
# 4. Определим маску: оставить токены, пока сумма < top_p
sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size]
sorted_mask = (cum_probs <= top_p) # [B, vocab_size]
# Гарантируем, что хотя бы первый токен останется
sorted_mask[:, 0] = 1
sorted_mask[:, 0] = True
# 5. Преобразуем маску обратно в оригинальный порядок:
# Создаём полную маску из 0
mask = torch.zeros_like(probs, dtype=torch.uint8)
# Устанавливаем 1 в местах нужных токенов
# Создаём полную маску из False
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, 'bool') else torch.uint8)
# Устанавливаем True в местах нужных токенов
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
# 6. Зануляем логиты токенов вне топ-p:
logits_scaled[~mask] = float('-inf')

View File

@@ -198,6 +198,77 @@ class BPETokenizer(BaseTokenizer):
tokens.append(self.unk_token) # Специальное значение
return tokens
def save(self, filepath: str):
"""
Сохраняет токенизатор в файл.
Args:
filepath: Путь для сохранения
"""
import json
# Преобразуем кортежи в строки для JSON сериализации
merges_serializable = {f"{k[0]},{k[1]}": v for k, v in self.merges.items()}
config = {
'vocab': self.vocab,
'vocab_size': self.vocab_size,
'pad_token': self.pad_token,
'unk_token': self.unk_token,
'bos_token': self.bos_token,
'eos_token': self.eos_token,
'tokenizer_type': self.__class__.__name__,
'merges': merges_serializable,
'vocab_list': self.vocab_list
}
with open(filepath, 'w', encoding='utf-8') as f:
json.dump(config, f, ensure_ascii=False, indent=2)
@classmethod
def load(cls, filepath: str):
"""
Загружает токенизатор из файла.
Args:
filepath: Путь к файлу
Returns:
BPETokenizer: Загруженный токенизатор
"""
import json
with open(filepath, 'r', encoding='utf-8') as f:
config = json.load(f)
# Создаем экземпляр токенизатора
tokenizer = cls()
tokenizer.vocab = config['vocab']
tokenizer.vocab_size = config['vocab_size']
tokenizer.pad_token = config['pad_token']
tokenizer.unk_token = config['unk_token']
tokenizer.bos_token = config['bos_token']
tokenizer.eos_token = config['eos_token']
tokenizer.vocab_list = config['vocab_list']
# Восстанавливаем кортежи из строк
tokenizer.merges = {}
for k, v in config['merges'].items():
parts = k.split(',')
if len(parts) == 2:
tokenizer.merges[(parts[0], parts[1])] = v
# Создаем обратный словарь
tokenizer.inverse_vocab = {v: k for k, v in tokenizer.vocab.items()}
# Обновляем ID специальных токенов
tokenizer.pad_token_id = tokenizer.vocab.get(tokenizer.pad_token)
tokenizer.unk_token_id = tokenizer.vocab.get(tokenizer.unk_token)
tokenizer.bos_token_id = tokenizer.vocab.get(tokenizer.bos_token)
tokenizer.eos_token_id = tokenizer.vocab.get(tokenizer.eos_token)
return tokenizer
class SimpleBPETokenizer(BPETokenizer):
"""

23
uv.lock generated
View File

@@ -1393,11 +1393,22 @@ dependencies = [
{ name = "torch" },
]
[package.optional-dependencies]
test = [
{ name = "pytest" },
{ name = "pytest-cov" },
{ name = "pytest-mock" },
]
[package.metadata]
requires-dist = [
{ name = "numpy", specifier = ">=1.24.0" },
{ name = "pytest", marker = "extra == 'test'", specifier = ">=7.0.0" },
{ name = "pytest-cov", marker = "extra == 'test'", specifier = ">=4.0.0" },
{ name = "pytest-mock", marker = "extra == 'test'", specifier = ">=3.0.0" },
{ name = "torch", specifier = ">=2.3.0" },
]
provides-extras = ["test"]
[[package]]
name = "llm-arch-research"
@@ -2500,6 +2511,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" },
]
[[package]]
name = "pytest-mock"
version = "3.15.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pytest" },
]
sdist = { url = "https://files.pythonhosted.org/packages/68/14/eb014d26be205d38ad5ad20d9a80f7d201472e08167f0bb4361e251084a9/pytest_mock-3.15.1.tar.gz", hash = "sha256:1849a238f6f396da19762269de72cb1814ab44416fa73a8686deac10b0d87a0f", size = 34036, upload-time = "2025-09-16T16:37:27.081Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/5a/cc/06253936f4a7fa2e0f48dfe6d851d9c56df896a9ab09ac019d70b760619c/pytest_mock-3.15.1-py3-none-any.whl", hash = "sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d", size = 10095, upload-time = "2025-09-16T16:37:25.734Z" },
]
[[package]]
name = "python-dateutil"
version = "2.9.0.post0"