mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
fix: update PyTorch mask types and BPE tokenizer serialization
- Replace deprecated torch.uint8 and .byte() with torch.bool in GPT.generate - Add save/load methods to BPETokenizer for proper merges and vocab_list serialization - Update dependencies in pyproject.toml
This commit is contained in:
@@ -12,6 +12,13 @@ dependencies = [
|
|||||||
"numpy>=1.24.0",
|
"numpy>=1.24.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
test = [
|
||||||
|
"pytest>=7.0.0",
|
||||||
|
"pytest-cov>=4.0.0",
|
||||||
|
"pytest-mock>=3.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["uv_build>=0.8.22,<0.9.0"]
|
requires = ["uv_build>=0.8.22,<0.9.0"]
|
||||||
build-backend = "uv_build"
|
build-backend = "uv_build"
|
||||||
|
|||||||
@@ -215,10 +215,10 @@ class GPT(BaseModel):
|
|||||||
masked_logits = logits_scaled.clone()
|
masked_logits = logits_scaled.clone()
|
||||||
vocab_size = logits_scaled.size(-1)
|
vocab_size = logits_scaled.size(-1)
|
||||||
|
|
||||||
# создаём маску: 1, если токен НЕ в topk_indices
|
# создаём маску: True, если токен НЕ в topk_indices
|
||||||
mask = torch.ones_like(logits_scaled, dtype=torch.uint8)
|
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, 'bool') else torch.uint8)
|
||||||
mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы
|
mask.scatter_(1, topk_indices, False if hasattr(torch, 'bool') else 0) # False там, где top-k индексы
|
||||||
masked_logits[mask.byte()] = float('-inf')
|
masked_logits[mask] = float('-inf')
|
||||||
|
|
||||||
logits_scaled = masked_logits
|
logits_scaled = masked_logits
|
||||||
|
|
||||||
@@ -230,13 +230,13 @@ class GPT(BaseModel):
|
|||||||
# 3. Посчитаем кумулятивную сумму вероятностей:
|
# 3. Посчитаем кумулятивную сумму вероятностей:
|
||||||
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
|
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
|
||||||
# 4. Определим маску: оставить токены, пока сумма < top_p
|
# 4. Определим маску: оставить токены, пока сумма < top_p
|
||||||
sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size]
|
sorted_mask = (cum_probs <= top_p) # [B, vocab_size]
|
||||||
# Гарантируем, что хотя бы первый токен останется
|
# Гарантируем, что хотя бы первый токен останется
|
||||||
sorted_mask[:, 0] = 1
|
sorted_mask[:, 0] = True
|
||||||
# 5. Преобразуем маску обратно в оригинальный порядок:
|
# 5. Преобразуем маску обратно в оригинальный порядок:
|
||||||
# Создаём полную маску из 0
|
# Создаём полную маску из False
|
||||||
mask = torch.zeros_like(probs, dtype=torch.uint8)
|
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, 'bool') else torch.uint8)
|
||||||
# Устанавливаем 1 в местах нужных токенов
|
# Устанавливаем True в местах нужных токенов
|
||||||
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
||||||
# 6. Зануляем логиты токенов вне топ-p:
|
# 6. Зануляем логиты токенов вне топ-p:
|
||||||
logits_scaled[~mask] = float('-inf')
|
logits_scaled[~mask] = float('-inf')
|
||||||
|
|||||||
@@ -198,6 +198,77 @@ class BPETokenizer(BaseTokenizer):
|
|||||||
tokens.append(self.unk_token) # Специальное значение
|
tokens.append(self.unk_token) # Специальное значение
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
|
def save(self, filepath: str):
|
||||||
|
"""
|
||||||
|
Сохраняет токенизатор в файл.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filepath: Путь для сохранения
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
|
||||||
|
# Преобразуем кортежи в строки для JSON сериализации
|
||||||
|
merges_serializable = {f"{k[0]},{k[1]}": v for k, v in self.merges.items()}
|
||||||
|
|
||||||
|
config = {
|
||||||
|
'vocab': self.vocab,
|
||||||
|
'vocab_size': self.vocab_size,
|
||||||
|
'pad_token': self.pad_token,
|
||||||
|
'unk_token': self.unk_token,
|
||||||
|
'bos_token': self.bos_token,
|
||||||
|
'eos_token': self.eos_token,
|
||||||
|
'tokenizer_type': self.__class__.__name__,
|
||||||
|
'merges': merges_serializable,
|
||||||
|
'vocab_list': self.vocab_list
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(filepath, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(config, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, filepath: str):
|
||||||
|
"""
|
||||||
|
Загружает токенизатор из файла.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filepath: Путь к файлу
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BPETokenizer: Загруженный токенизатор
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
|
||||||
|
with open(filepath, 'r', encoding='utf-8') as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
# Создаем экземпляр токенизатора
|
||||||
|
tokenizer = cls()
|
||||||
|
tokenizer.vocab = config['vocab']
|
||||||
|
tokenizer.vocab_size = config['vocab_size']
|
||||||
|
tokenizer.pad_token = config['pad_token']
|
||||||
|
tokenizer.unk_token = config['unk_token']
|
||||||
|
tokenizer.bos_token = config['bos_token']
|
||||||
|
tokenizer.eos_token = config['eos_token']
|
||||||
|
tokenizer.vocab_list = config['vocab_list']
|
||||||
|
|
||||||
|
# Восстанавливаем кортежи из строк
|
||||||
|
tokenizer.merges = {}
|
||||||
|
for k, v in config['merges'].items():
|
||||||
|
parts = k.split(',')
|
||||||
|
if len(parts) == 2:
|
||||||
|
tokenizer.merges[(parts[0], parts[1])] = v
|
||||||
|
|
||||||
|
# Создаем обратный словарь
|
||||||
|
tokenizer.inverse_vocab = {v: k for k, v in tokenizer.vocab.items()}
|
||||||
|
|
||||||
|
# Обновляем ID специальных токенов
|
||||||
|
tokenizer.pad_token_id = tokenizer.vocab.get(tokenizer.pad_token)
|
||||||
|
tokenizer.unk_token_id = tokenizer.vocab.get(tokenizer.unk_token)
|
||||||
|
tokenizer.bos_token_id = tokenizer.vocab.get(tokenizer.bos_token)
|
||||||
|
tokenizer.eos_token_id = tokenizer.vocab.get(tokenizer.eos_token)
|
||||||
|
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
class SimpleBPETokenizer(BPETokenizer):
|
class SimpleBPETokenizer(BPETokenizer):
|
||||||
"""
|
"""
|
||||||
|
|||||||
23
uv.lock
generated
23
uv.lock
generated
@@ -1393,11 +1393,22 @@ dependencies = [
|
|||||||
{ name = "torch" },
|
{ name = "torch" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[package.optional-dependencies]
|
||||||
|
test = [
|
||||||
|
{ name = "pytest" },
|
||||||
|
{ name = "pytest-cov" },
|
||||||
|
{ name = "pytest-mock" },
|
||||||
|
]
|
||||||
|
|
||||||
[package.metadata]
|
[package.metadata]
|
||||||
requires-dist = [
|
requires-dist = [
|
||||||
{ name = "numpy", specifier = ">=1.24.0" },
|
{ name = "numpy", specifier = ">=1.24.0" },
|
||||||
|
{ name = "pytest", marker = "extra == 'test'", specifier = ">=7.0.0" },
|
||||||
|
{ name = "pytest-cov", marker = "extra == 'test'", specifier = ">=4.0.0" },
|
||||||
|
{ name = "pytest-mock", marker = "extra == 'test'", specifier = ">=3.0.0" },
|
||||||
{ name = "torch", specifier = ">=2.3.0" },
|
{ name = "torch", specifier = ">=2.3.0" },
|
||||||
]
|
]
|
||||||
|
provides-extras = ["test"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "llm-arch-research"
|
name = "llm-arch-research"
|
||||||
@@ -2500,6 +2511,18 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" },
|
{ url = "https://files.pythonhosted.org/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pytest-mock"
|
||||||
|
version = "3.15.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "pytest" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/68/14/eb014d26be205d38ad5ad20d9a80f7d201472e08167f0bb4361e251084a9/pytest_mock-3.15.1.tar.gz", hash = "sha256:1849a238f6f396da19762269de72cb1814ab44416fa73a8686deac10b0d87a0f", size = 34036, upload-time = "2025-09-16T16:37:27.081Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/5a/cc/06253936f4a7fa2e0f48dfe6d851d9c56df896a9ab09ac019d70b760619c/pytest_mock-3.15.1-py3-none-any.whl", hash = "sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d", size = 10095, upload-time = "2025-09-16T16:37:25.734Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "python-dateutil"
|
name = "python-dateutil"
|
||||||
version = "2.9.0.post0"
|
version = "2.9.0.post0"
|
||||||
|
|||||||
Reference in New Issue
Block a user