From f4bdc81829440ac7486401ee7d9da23c4a495f83 Mon Sep 17 00:00:00 2001
From: Sergey Penkovsky
Date: Sun, 5 Oct 2025 08:09:30 +0300
Subject: [PATCH] fix: update PyTorch mask types and BPE tokenizer serialization

- Replace deprecated torch.uint8 and .byte() with torch.bool in GPT.generate
- Add save/load methods to BPETokenizer for proper merges and vocab_list serialization
- Update dependencies in pyproject.toml
---
 llm/pyproject.toml                      |  7 +++
 llm/src/llm/models/gpt/gpt.py           | 18 +++++++---
 llm/src/llm/tokenizers/bpe_tokenizer.py | 71 +++++++++++++++++++++++++
 uv.lock                                 | 23 ++++++++
 4 files changed, 110 insertions(+), 9 deletions(-)

diff --git a/llm/pyproject.toml b/llm/pyproject.toml
index 72646c6..451e11b 100644
--- a/llm/pyproject.toml
+++ b/llm/pyproject.toml
@@ -12,6 +12,13 @@ dependencies = [
     "numpy>=1.24.0",
 ]
 
+[project.optional-dependencies]
+test = [
+    "pytest>=7.0.0",
+    "pytest-cov>=4.0.0",
+    "pytest-mock>=3.0.0",
+]
+
 [build-system]
 requires = ["uv_build>=0.8.22,<0.9.0"]
 build-backend = "uv_build"

diff --git a/llm/src/llm/models/gpt/gpt.py b/llm/src/llm/models/gpt/gpt.py
index 243e635..c184a27 100644
--- a/llm/src/llm/models/gpt/gpt.py
+++ b/llm/src/llm/models/gpt/gpt.py
@@ -215,10 +215,10 @@ class GPT(BaseModel):
         masked_logits = logits_scaled.clone()
         vocab_size = logits_scaled.size(-1)
 
-        # build the mask: 1 where the token is NOT in topk_indices
-        mask = torch.ones_like(logits_scaled, dtype=torch.uint8)
-        mask.scatter_(1, topk_indices, 0)  # 0 at the top-k indices
-        masked_logits[mask.byte()] = float('-inf')
+        # build the mask: True where the token is NOT in topk_indices
+        mask = torch.ones_like(logits_scaled, dtype=torch.bool)
+        mask.scatter_(1, topk_indices, False)  # False at the top-k indices
+        masked_logits[mask] = float('-inf')
 
         logits_scaled = masked_logits
 
@@ -230,13 +230,13 @@ class GPT(BaseModel):
         # 3. Compute the cumulative sum of the probabilities:
         cum_probs = torch.cumsum(sorted_probs, dim=-1)  # [B, vocab_size]
         # 4. Build the mask: keep tokens while the cumulative sum < top_p
-        sorted_mask = (cum_probs <= top_p).byte()  # [B, vocab_size]
+        sorted_mask = (cum_probs <= top_p)  # [B, vocab_size]
         # Guarantee that at least the first token survives
-        sorted_mask[:, 0] = 1
+        sorted_mask[:, 0] = True
         # 5. Map the mask back onto the original token order:
-        # Create a full mask of zeros
-        mask = torch.zeros_like(probs, dtype=torch.uint8)
-        # Set 1 at the positions of the tokens to keep
+        # Create a full mask of False
+        mask = torch.zeros_like(probs, dtype=torch.bool)
+        # Set True at the positions of the tokens to keep
         mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
         # 6. Exclude tokens outside top-p by setting their logits to -inf:
         logits_scaled[~mask] = float('-inf')
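
(Reviewer note, not part of the patch: the boolean-mask filtering above, condensed
into a standalone sketch for reference. filter_logits and its signature are
illustrative only, not an API of this repo; logits is assumed to be a float tensor
of shape [B, vocab_size].)

    import torch

    def filter_logits(logits, top_k=50, top_p=0.9):
        # Top-k: mark everything except the k largest logits per row, then mask.
        _, topk_indices = torch.topk(logits, top_k, dim=-1)
        drop = torch.ones_like(logits, dtype=torch.bool)
        drop.scatter_(1, topk_indices, False)
        logits = logits.masked_fill(drop, float('-inf'))

        # Top-p: keep the shortest prefix of sorted tokens whose mass reaches top_p.
        probs = torch.softmax(logits, dim=-1)
        sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
        sorted_keep = torch.cumsum(sorted_probs, dim=-1) <= top_p
        sorted_keep[:, 0] = True  # never drop the highest-probability token
        keep = torch.zeros_like(sorted_keep)
        keep.scatter_(1, sorted_indices, sorted_keep)
        return logits.masked_fill(~keep, float('-inf'))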
diff --git a/llm/src/llm/tokenizers/bpe_tokenizer.py b/llm/src/llm/tokenizers/bpe_tokenizer.py
index a16cdcd..7e55e3e 100644
--- a/llm/src/llm/tokenizers/bpe_tokenizer.py
+++ b/llm/src/llm/tokenizers/bpe_tokenizer.py
@@ -197,6 +197,77 @@ class BPETokenizer(BaseTokenizer):
             else:
                 tokens.append(self.unk_token)  # Special placeholder value
         return tokens
+
+    def save(self, filepath: str):
+        """
+        Save the tokenizer to a file.
+
+        Args:
+            filepath: Destination path
+        """
+        import json
+
+        # Convert tuple keys to strings for JSON serialization (assumes tokens contain no commas)
+        merges_serializable = {f"{k[0]},{k[1]}": v for k, v in self.merges.items()}
+
+        config = {
+            'vocab': self.vocab,
+            'vocab_size': self.vocab_size,
+            'pad_token': self.pad_token,
+            'unk_token': self.unk_token,
+            'bos_token': self.bos_token,
+            'eos_token': self.eos_token,
+            'tokenizer_type': self.__class__.__name__,
+            'merges': merges_serializable,
+            'vocab_list': self.vocab_list
+        }
+
+        with open(filepath, 'w', encoding='utf-8') as f:
+            json.dump(config, f, ensure_ascii=False, indent=2)
+
+    @classmethod
+    def load(cls, filepath: str):
+        """
+        Load a tokenizer from a file.
+
+        Args:
+            filepath: Path to the saved file
+
+        Returns:
+            BPETokenizer: The loaded tokenizer
+        """
+        import json
+
+        with open(filepath, 'r', encoding='utf-8') as f:
+            config = json.load(f)
+
+        # Create a tokenizer instance
+        tokenizer = cls()
+        tokenizer.vocab = config['vocab']
+        tokenizer.vocab_size = config['vocab_size']
+        tokenizer.pad_token = config['pad_token']
+        tokenizer.unk_token = config['unk_token']
+        tokenizer.bos_token = config['bos_token']
+        tokenizer.eos_token = config['eos_token']
+        tokenizer.vocab_list = config['vocab_list']
+
+        # Restore the tuple keys from their string form
+        tokenizer.merges = {}
+        for k, v in config['merges'].items():
+            parts = k.split(',')
+            if len(parts) == 2:
+                tokenizer.merges[(parts[0], parts[1])] = v
+
+        # Rebuild the inverse vocabulary
+        tokenizer.inverse_vocab = {v: k for k, v in tokenizer.vocab.items()}
+
+        # Refresh the special-token IDs
+        tokenizer.pad_token_id = tokenizer.vocab.get(tokenizer.pad_token)
+        tokenizer.unk_token_id = tokenizer.vocab.get(tokenizer.unk_token)
+        tokenizer.bos_token_id = tokenizer.vocab.get(tokenizer.bos_token)
+        tokenizer.eos_token_id = tokenizer.vocab.get(tokenizer.eos_token)
+
+        return tokenizer
 
 
 class SimpleBPETokenizer(BPETokenizer):
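
(Reviewer note, not part of the patch: the new methods make persistence a plain
JSON round trip. A minimal usage sketch, assuming a tokenizer whose vocab and
merges are already populated; the file name is arbitrary:)

    from llm.tokenizers.bpe_tokenizer import BPETokenizer

    tokenizer = BPETokenizer()
    # ... train or otherwise populate tokenizer.vocab / tokenizer.merges ...
    tokenizer.save("bpe_tokenizer.json")

    restored = BPETokenizer.load("bpe_tokenizer.json")
    assert restored.vocab == tokenizer.vocab
    assert restored.merges == tokenizer.merges  # tuple keys survive the round trip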
hash = "sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d", size = 10095, upload-time = "2025-09-16T16:37:25.734Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0"