From 362a7483e68c5d79c09f502f95134b1491085db8 Mon Sep 17 00:00:00 2001 From: Sergey Penkovsky Date: Fri, 11 Jul 2025 12:21:33 +0300 Subject: [PATCH] feat: implement bpe algorithm --- README.md | 78 +++++++++++- doc/bpe_algorithm.drawio | 52 ++++++++ doc/bpe_algorithm.md | 136 ++++++++++++++++++++ example/__init__.py | 0 example/example_bpe.py | 60 +++++++++ pyproject.toml | 41 ++++++ simple_llm/__init__.py | 0 simple_llm/tokenizer/__init__.py | 0 simple_llm/tokenizer/bpe_interface.py | 39 ++++++ simple_llm/tokenizer/optimize_bpe.py | 147 ++++++++++++++++++++++ simple_llm/tokenizer/simple_bpe.py | 60 +++++++++ tests/__init__.py | 0 tests/conftest.py | 13 ++ tests/integration/test_bpe_integration.py | 35 ++++++ tests/test_bpe.py | 54 ++++++++ 15 files changed, 714 insertions(+), 1 deletion(-) create mode 100644 doc/bpe_algorithm.drawio create mode 100644 doc/bpe_algorithm.md create mode 100644 example/__init__.py create mode 100644 example/example_bpe.py create mode 100644 pyproject.toml create mode 100644 simple_llm/__init__.py create mode 100644 simple_llm/tokenizer/__init__.py create mode 100644 simple_llm/tokenizer/bpe_interface.py create mode 100644 simple_llm/tokenizer/optimize_bpe.py create mode 100644 simple_llm/tokenizer/simple_bpe.py create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/integration/test_bpe_integration.py create mode 100644 tests/test_bpe.py diff --git a/README.md b/README.md index 149f2e2..8ccc306 100644 --- a/README.md +++ b/README.md @@ -1 +1,77 @@ -# simple-llm \ No newline at end of file +# Simple LLM Tokenizer + +Простой и эффективный токенизатор для языковых моделей на основе BPE (Byte Pair Encoding) + +## Описание проекта + +Проект предоставляет реализации алгоритма BPE (Byte Pair Encoding) для токенизации текста: +- `SimpleBPE` - базовая версия +- `OptimizeBPE` - оптимизированная версия с улучшенной производительностью + +Основные возможности: +- Обучение на любом тексте (поддержка кириллицы и других алфавитов) +- Гибкая настройка размера словаря +- Простота интеграции в существующие проекты + +## Установка + +1. Склонируйте репозиторий: +```bash +git clone https://github.com/yourusername/simple-llm.git +cd simple-llm +``` + +2. Установите пакет: +```bash +pip install -e . +``` + +## Быстрый старт + +```python +from simple_llm.tokenizer import SimpleBPE + +# Инициализация и обучение +text = "мама мыла раму, папа пил какао" +bpe = SimpleBPE(vocab_size=50) +bpe.fit(text) + +# Токенизация +tokens = bpe.tokenize(text) +print(tokens) +``` + +## Интеграция в проект + +Добавьте в ваш `requirements.txt`: +``` +git+https://github.com/yourusername/simple-llm.git +``` + +Или установите напрямую: +```bash +pip install git+https://github.com/yourusername/simple-llm.git +``` + +## Примеры + +Дополнительные примеры использования смотрите в папке [example](/example): +- Сравнение SimpleBPE и OptimizeBPE +- Работа с разными языками +- Настройка параметров токенизации + +## Разработка + +Для запуска тестов: +```bash +pytest tests/ +``` + +Для внесения изменений установите зависимости разработки: +```bash +pip install -e ".[dev]" +``` + +## Лицензия + +Проект распространяется под лицензией MIT. Подробнее см. [LICENSE](LICENSE). diff --git a/doc/bpe_algorithm.drawio b/doc/bpe_algorithm.drawio new file mode 100644 index 0000000..01fb2c7 --- /dev/null +++ b/doc/bpe_algorithm.drawio @@ -0,0 +1,52 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/doc/bpe_algorithm.md b/doc/bpe_algorithm.md new file mode 100644 index 0000000..4219e74 --- /dev/null +++ b/doc/bpe_algorithm.md @@ -0,0 +1,136 @@ +# Byte Pair Encoding (BPE) Algorithm + +## Введение + +Byte Pair Encoding (BPE) - это алгоритм компрессии данных, адаптированный для токенизации текста в обработке естественного языка. В контексте языковых моделей BPE используется для создания эффективного словаря подстрок (токенов). + +## Основные понятия + +- **Токен** - элементарная единица текста (символ или последовательность символов) +- **Словарь** - набор уникальных токенов, используемых для представления текста +- **Частота пары** - количество раз, когда два токена встречаются вместе в тексте + +## Алгоритм работы + +### 1. Инициализация +```python +Исходный текст → Разбить на символы → Первоначальный словарь +``` +Пример: +``` +"мама" → ['м', 'а', 'м', 'а'] +``` + +### 2. Основной цикл +```mermaid +graph TD + A[Подсчет частот пар] --> B[Выбор наиболее частой пары] + B --> C[Создание нового токена] + C --> D[Обновление последовательности] + D --> E{Достигнут лимит словаря?} + E -->|Нет| A + E -->|Да| F[Конец] +``` + +### 3. Детализация шагов + +#### Шаг 1: Подсчет частот пар +Для текущей последовательности токенов подсчитываем все пары соседних токенов: +``` +Текст: "мама мыла" +Токены: ['м', 'а', 'м', 'а', ' ', 'м', 'ы', 'л', 'а'] +Пары: ('м','а'), ('а','м'), ('м','а'), ('а',' '), (' ','м'), ('м','ы'), ('ы','л'), ('л','а') +``` + +#### Шаг 2: Выбор пары для слияния +Находим пару с максимальной частотой. При равенстве частот выбираем пару, которая встречается раньше в тексте. + +#### Шаг 3: Слияние +Объединяем выбранную пару в новый токен и заменяем все её вхождения в тексте: +``` +Выбранная пара: ('м', 'а') +Новый токен: 'ма' +Обновленная последовательность: ['ма', 'ма', ' ', 'м', 'ы', 'л', 'а'] +``` + +#### Шаг 4: Обновление словаря +Добавляем новый токен в словарь: +``` +Словарь: ['м', 'а', ' ', 'ы', 'л', 'ма'] +``` + +### 4. Критерии остановки + +1. Достижение заданного размера словаря +2. Отсутствие пар для слияния (все возможные пары уже добавлены) +3. Достижение максимального числа итераций + +## Псевдокод + +```python +def train_bpe(text, vocab_size): + # Инициализация + tokens = list(text) + vocab = set(tokens) + + while len(vocab) < vocab_size: + # Подсчет пар + pairs = get_pairs(tokens) + if not pairs: + break + + # Выбор наиболее частой пары + best_pair = max(pairs, key=pairs.get) + + # Слияние + new_tokens = [] + i = 0 + while i < len(tokens): + if i < len(tokens)-1 and (tokens[i], tokens[i+1]) == best_pair: + new_tokens.append(best_pair[0] + best_pair[1]) + i += 2 + else: + new_tokens.append(tokens[i]) + i += 1 + tokens = new_tokens + + # Обновление словаря + vocab.add(best_pair[0] + best_pair[1]) + + return vocab +``` + +## Пример работы + +**Исходный текст**: "мама мыла раму" + +**Итерация 1**: +- Пара ('м','а') встречается 2 раза +- Новый токен: 'ма' +- Текст: ['ма', 'ма', ' ', 'м', 'ы', 'л', 'а', ' ', 'р', 'а', 'м', 'у'] + +**Итерация 2**: +- Пара ('ма',' ') встречается 1 раз +- Новый токен: 'ма ' +- Текст: ['ма ', 'ма', 'мы', 'л', 'а', ' ', 'р', 'а', 'м', 'у'] + +**Результирующий словарь** (частично): +['м', 'а', ' ', 'ы', 'л', 'р', 'у', 'ма', 'ма ', 'мы'] + +## Применение в языковых моделях + +1. Эффективное представление редких слов +2. Снижение размерности входных данных +3. Возможность обработки OOV (Out-of-Vocabulary) слов + +## Ограничения + +1. Чувствительность к регистру (можно решить предварительной нормализацией) +2. Зависимость от обучающего корпуса +3. Не всегда выделяет лингвистически осмысленные морфемы + +## Дополнительные материалы + +1. [Original BPE paper](https://arxiv.org/abs/1508.07909) +2. [BPE in HuggingFace](https://huggingface.co/docs/transformers/tokenizer_summary) +3. [Practical guide to BPE](https://towardsdatascience.com/byte-pair-encoding-subword-based-tokenization-algorithm-77828a70bee0) diff --git a/example/__init__.py b/example/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/example/example_bpe.py b/example/example_bpe.py new file mode 100644 index 0000000..2e31cfc --- /dev/null +++ b/example/example_bpe.py @@ -0,0 +1,60 @@ +from simple_llm.tokenizer.simple_bpe import SimpleBPE +from simple_llm.tokenizer.optimize_bpe import OptimizeBPE +import time + +def tokenize_manually(text, vocab): + """Простая ручная токенизация по словарю""" + tokens = [] + i = 0 + n = len(text) + while i < n: + found = False + # Ищем самый длинный возможный токен из словаря + for l in range(min(4, n-i), 0, -1): # проверяем токены длиной до 4 символов + if text[i:i+l] in vocab: + tokens.append(text[i:i+l]) + i += l + found = True + break + if not found: # если токен не найден, берем один символ + tokens.append(text[i]) + i += 1 + return tokens + +def run_example(text, vocab_size=30): + print("\n=== Тестирование токенизаторов ===") + print(f"Исходный текст: '{text}'\n") + + # Simple BPE + start = time.time() + simple_bpe = SimpleBPE(vocab_size=vocab_size) + simple_bpe.fit(text) + simple_time = time.time() - start + + print("SimpleBPE:") + print(f"Время обучения: {simple_time:.4f} сек") + print(f"Размер словаря: {len(simple_bpe.vocab)}") + print(f"Словарь: {simple_bpe.vocab}") + print(f"Ручная токенизация: {tokenize_manually(text, simple_bpe.vocab)}\n") + + # Optimize BPE + start = time.time() + opt_bpe = OptimizeBPE(vocab_size=vocab_size) + opt_bpe.fit(text) + opt_time = time.time() - start + + print("OptimizeBPE:") + print(f"Время обучения: {opt_time:.4f} сек") + print(f"Размер словаря: {len(opt_bpe.vocab)}") + print(f"Словарь: {opt_bpe.vocab}") + print(f"Ручная токенизация: {tokenize_manually(text, opt_bpe.vocab)}\n") + + if opt_time > 0: + print(f"Оптимизированная версия быстрее в {simple_time/opt_time:.1f} раз\n") + +if __name__ == "__main__": + text1 = "мама мыла раму, папа пил какао" + text2 = "коты бегают быстро, собаки лают громко" + + run_example(text1) + run_example(text2) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..6f560c7 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,41 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "simple-llm" +version = "0.1.0" +description = "Simple BPE tokenizer implementation for educational purposes" +readme = "README.md" +authors = [ + { name = "Sergey Penkovsky", email = "sergey.penkovsky@gmail.com" }, +] +license = { text = "MIT" } +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Education", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] +requires-python = ">=3.8" +dependencies = [] + +[project.urls] +Homepage = "https://github.com/pese-git/simple-llm" + +[tool.setuptools.packages.find] +where = ["."] +include = ["simple_llm*"] +exclude = ["tests*", "example*"] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0", + "black>=23.0", +] diff --git a/simple_llm/__init__.py b/simple_llm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/simple_llm/tokenizer/__init__.py b/simple_llm/tokenizer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/simple_llm/tokenizer/bpe_interface.py b/simple_llm/tokenizer/bpe_interface.py new file mode 100644 index 0000000..d4cfe8d --- /dev/null +++ b/simple_llm/tokenizer/bpe_interface.py @@ -0,0 +1,39 @@ +from abc import ABC, abstractmethod +from typing import List, Dict + +class BPE(ABC): + """ + Реализация алгоритма токенизации Byte Pair Encoding (BPE). + + BPE — это итеративный алгоритм, последовательно объединяющий наиболее частые пары символов/токенов, + чтобы построить эффективный словарь для работы с текстом: токенизации, обучения языковой модели и т.п. + + Аргументы конструктора: + vocab_size (int): Желаемый размер итогового словаря токенов (включая отдельные символы и составные токены). + + Атрибуты: + vocab (List[str]): Список токенов в порядке их получения (сначала символы, затем новые пары). + token2id (Dict[str, int]): Словарь преобразования токена в его индекс. + id2token (Dict[int, str]): Обратный словарь преобразования индекса в токен. + """ + def __init__(self, vocab_size: int): + """ + Инициализация BPE токенизатора. + + Args: + vocab_size (int): Размер словаря, к которому будет расширяться BPE. + """ + self.vocab_size = vocab_size + self.vocab: List[str] = [] + self.token2id: Dict[str, int] = {} + self.id2token: Dict[int, str] = {} + + @abstractmethod + def fit(self, text: str): + pass + + def encode(self, text: str): + raise NotImplementedError("Implement in subclass if needed.") + + def decode(self, ids: list[int]): + raise NotImplementedError("Implement in subclass if needed.") diff --git a/simple_llm/tokenizer/optimize_bpe.py b/simple_llm/tokenizer/optimize_bpe.py new file mode 100644 index 0000000..104eb56 --- /dev/null +++ b/simple_llm/tokenizer/optimize_bpe.py @@ -0,0 +1,147 @@ +from .bpe_interface import BPE + +from collections import Counter +from typing import List, Tuple, Dict + +class OptimizeBPE(BPE): + + def fit(self, text: str) -> None: + """ + Обучает BPE-модель на предоставленном тексте. + + Последовательно расширяет словарь за счёт объединения наиболее частых пар токенов до достижения vocab_size. + + Args: + text (str): Исходная строка для обучения токенизатора. + """ + sequence = list(text) + self._init_vocab(sequence) + pair_freq, pair_first_occurrence = self._get_pair_stats(sequence) + + while len(self.vocab) < self.vocab_size and pair_freq: + pair_to_merge = self._select_pair_to_merge(pair_freq, pair_first_occurrence) + new_token = pair_to_merge[0] + pair_to_merge[1] + + if new_token in self.vocab: + # Защита от зацикливания: пара уже была добавлена как новый токен. + del pair_freq[pair_to_merge] + continue + + self.vocab.append(new_token) + sequence, pair_freq, pair_first_occurrence = self._merge_pair( + sequence, pair_to_merge, new_token, pair_freq + ) + + self._build_token_dicts() + + def _init_vocab(self, sequence: List[str]) -> None: + """ + Формирует стартовый словарь уникальных символов из последовательности, отсортированный по символам. + + Args: + sequence (List[str]): Исходная последовательность символов. + """ + self.vocab = sorted(set(sequence)) + + def _get_pair_stats(self, sequence: List[str]) -> Tuple[Counter, Dict[Tuple[str, str], int]]: + """ + Вычисляет частоты появления и индексы первого появления всех пар соседних токенов в последовательности. + + Args: + sequence (List[str]): Текущая последовательность токенов. + + Returns: + Tuple[Counter, Dict[Tuple[str, str], int]]: + - Counter по всем парам (их частоты), + - Словарь первых индексов появления каждой пары. + """ + pair_freq = Counter() + pair_first_occurrence = {} + for i in range(len(sequence) - 1): + pair = (sequence[i], sequence[i + 1]) + pair_freq[pair] += 1 + if pair not in pair_first_occurrence: + pair_first_occurrence[pair] = i + return pair_freq, pair_first_occurrence + + def _select_pair_to_merge(self, pair_freq: Counter, pair_first_occurrence: Dict[Tuple[str, str], int]) -> Tuple[str, str]: + """ + Выбирает следующую пару для слияния: + приоритет — самая частая; если таких несколько — та, которая встречается раньше других (наименьший индекс появления). + + Args: + pair_freq (Counter): Частоты всех пар. + pair_first_occurrence (Dict[Tuple[str, str], int]): Индексы первых появлений каждой пары. + + Returns: + Tuple[str, str]: Пара для слияния (двойка токенов). + """ + pair_to_merge, _ = max( + pair_freq.items(), + key=lambda x: (x[1], -pair_first_occurrence.get(x[0], float('inf'))) + ) + return pair_to_merge + + def _merge_pair( + self, + sequence: List[str], + pair_to_merge: Tuple[str, str], + new_token: str, + pair_freq: Counter + ) -> Tuple[List[str], Counter, Dict[Tuple[str, str], int]]: + """ + Выполняет слияние заданной пары токенов в новой последовательности, корректирует частоты пар и индексы первых появлений. + + Args: + sequence (List[str]): Текущая последовательность токенов. + pair_to_merge (Tuple[str, str]): Пара для слияния. + new_token (str): Новый токен (результат слияния). + pair_freq (Counter): Частоты текущих пар. + + Returns: + Tuple[List[str], Counter, Dict[Tuple[str, str], int]]: + - Новая последовательность, + - Обновлённые частоты пар, + - Обновлённые индексы первых появлений пар. + """ + new_sequence = [] + i = 0 + pairs_to_decrement = Counter() + pairs_to_increment = Counter() + length = len(sequence) + while i < length: + if i < length - 1 and (sequence[i], sequence[i + 1]) == pair_to_merge: + if i > 0: + pairs_to_decrement[(sequence[i - 1], sequence[i])] += 1 + pairs_to_increment[(sequence[i - 1], new_token)] += 1 + if i + 2 < length: + pairs_to_decrement[(sequence[i + 1], sequence[i + 2])] += 1 + pairs_to_increment[(new_token, sequence[i + 2])] += 1 + new_sequence.append(new_token) + i += 2 + else: + new_sequence.append(sequence[i]) + i += 1 + for pair, dec_count in pairs_to_decrement.items(): + pair_freq[pair] -= dec_count + if pair_freq[pair] <= 0: + del pair_freq[pair] + for pair, inc_count in pairs_to_increment.items(): + pair_freq[pair] += inc_count + # Пересчитываем первый индекс появления пар + pair_first_occurrence = {} + for idx in range(len(new_sequence) - 1): + pair = (new_sequence[idx], new_sequence[idx + 1]) + if pair not in pair_first_occurrence: + pair_first_occurrence[pair] = idx + for pair in list(pair_freq.keys()): + if pair not in pair_first_occurrence: + del pair_freq[pair] + return new_sequence, pair_freq, pair_first_occurrence + + def _build_token_dicts(self) -> None: + """ + Формирует словари вида <токен, id> и по итоговому списку токенов. + """ + self.token2id = {token: idx for idx, token in enumerate(self.vocab)} + self.id2token = {idx: token for idx, token in enumerate(self.vocab)} \ No newline at end of file diff --git a/simple_llm/tokenizer/simple_bpe.py b/simple_llm/tokenizer/simple_bpe.py new file mode 100644 index 0000000..5a99de8 --- /dev/null +++ b/simple_llm/tokenizer/simple_bpe.py @@ -0,0 +1,60 @@ +from .bpe_interface import BPE + +class SimpleBPE(BPE): + def fit(self, text: str): + # 1. Получаем уникальные токены (символы) + unique_tokens = sorted(set(text)) + tokens = unique_tokens.copy() + + # 2. Разбиваем текст на токены-символы + sequence = list(text) + + # 3. Объединяем токены до достижения нужного размера словаря + while len(tokens) < self.vocab_size: + #print(f'len={len(tokens)} < {self.vocab_size}') + # Считаем частоты пар + pair_freq = {} + for i in range(len(sequence) - 1): + pair = (sequence[i], sequence[i + 1]) + #print(f'pair = {pair}') + if pair not in pair_freq: + pair_freq[pair] = 0 + pair_freq[pair] += 1 + + + #print(f'pair_freq = {pair_freq}') + if not pair_freq: + break # нет пар — выходим + + # Находим самую частую пару (в случае равенства — та, что встретилась первой) + most_frequent_pair = max(pair_freq.items(), key=lambda x: (x[1], -self._pair_first_index(sequence, x[0])))[0] + #print(most_frequent_pair) + # Создаем новый токен + new_token = most_frequent_pair[0] + most_frequent_pair[1] + #print(f"new token={new_token}") + tokens.append(new_token) + #print(f"tokens={tokens}") + + i = 0 + new_sequence = [] + + while i < len(sequence): + if i < len(sequence) - 1 and (sequence[i], sequence[i + 1]) == most_frequent_pair: + new_sequence.append(new_token) + i += 2 # пропускаем два символа — заменённую пару + else: + new_sequence.append(sequence[i]) + i += 1 + sequence = new_sequence + #break + + # 4. Создаем словари + self.vocab = tokens.copy() + self.token2id = dict(zip(tokens, range(self.vocab_size))) + self.id2token = dict(zip(range(self.vocab_size), tokens)) + + def _pair_first_index(self, sequence, pair): + for i in range(len(sequence) - 1): + if (sequence[i], sequence[i + 1]) == pair: + return i + return float('inf') \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..473158a --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,13 @@ +import pytest +from simple_llm.tokenizer.simple_bpe import SimpleBPE +from simple_llm.tokenizer.optimize_bpe import OptimizeBPE + +@pytest.fixture(scope="session") +def large_text(): + """Генерирует большой текст для тестирования""" + return " ".join(["мама мыла раму"] * 1000) + +@pytest.fixture(params=[SimpleBPE, OptimizeBPE]) +def bpe_class(request): + """Возвращает классы BPE для тестирования""" + return request.param diff --git a/tests/integration/test_bpe_integration.py b/tests/integration/test_bpe_integration.py new file mode 100644 index 0000000..1be46e0 --- /dev/null +++ b/tests/integration/test_bpe_integration.py @@ -0,0 +1,35 @@ +import pytest +from simple_llm.tokenizer.simple_bpe import SimpleBPE +from simple_llm.tokenizer.optimize_bpe import OptimizeBPE + +def test_large_text_processing(bpe_class, large_text): + """Тест обработки большого текста""" + bpe = bpe_class(vocab_size=100) + bpe.fit(large_text) + + # Проверки + assert 50 < len(bpe.vocab) <= 100 + assert all(len(token) <= 4 for token in bpe.vocab) # Проверка на разумную длину токенов + assert "мама" in bpe.vocab or "ма" in bpe.vocab # Проверка на наличие ожидаемых токенов + +def test_special_characters(bpe_class): + """Тест обработки специальных символов""" + text = "!@#$%^&*()_+1234567890" + bpe = bpe_class(vocab_size=30) + bpe.fit(text) + + # Проверки + assert len(bpe.vocab) > 10 + for char in set(text): + assert any(char in token for token in bpe.vocab) # Каждый символ должен быть в каком-то токене + +def test_unicode_characters(bpe_class): + """Тест обработки unicode-символов""" + text = "日本語 한국어 русский English" + bpe = bpe_class(vocab_size=50) + bpe.fit(text) + + # Проверки + assert len(bpe.vocab) > 20 + assert any("日" in token for token in bpe.vocab) + assert any("한" in token for token in bpe.vocab) diff --git a/tests/test_bpe.py b/tests/test_bpe.py new file mode 100644 index 0000000..bb080ea --- /dev/null +++ b/tests/test_bpe.py @@ -0,0 +1,54 @@ +import pytest +from simple_llm.tokenizer.simple_bpe import SimpleBPE +from simple_llm.tokenizer.optimize_bpe import OptimizeBPE + +class TestBPE: + @pytest.fixture(params=[SimpleBPE, OptimizeBPE]) + def bpe_class(self, request): + return request.param + + def test_initialization(self, bpe_class): + """Тест инициализации BPE-токенизатора""" + bpe = bpe_class(vocab_size=100) + assert bpe.vocab_size == 100 + assert bpe.vocab == [] + assert bpe.token2id == {} + assert bpe.id2token == {} + + def test_fit_simple_text(self, bpe_class): + """Тест обучения на простом тексте""" + text = "мама мыла раму" + bpe = bpe_class(vocab_size=20) + bpe.fit(text) + + # Проверки словаря + assert isinstance(bpe.vocab, list) + assert len(bpe.vocab) > 0 + assert len(bpe.vocab) <= 20 + assert all(isinstance(token, str) for token in bpe.vocab) + + # Проверка словарей + assert len(bpe.vocab) == len(bpe.token2id) + assert len(bpe.vocab) == len(bpe.id2token) + + # Проверка соответствия токенов и ID + for token in bpe.vocab: + assert bpe.token2id[token] == bpe.vocab.index(token) + assert bpe.id2token[bpe.token2id[token]] == token + + @pytest.mark.parametrize("text,expected_size", [ + ("", 0), + ("а", 1), + ("ааааа", 2) # Должны быть 'а' и 'аа' + ]) + def test_edge_cases(self, bpe_class, text, expected_size): + """Тест граничных случаев""" + bpe = bpe_class(vocab_size=10) + bpe.fit(text) + assert len(bpe.vocab) == expected_size + + def test_duplicate_protection(self, bpe_class): + """Тест защиты от дубликатов токенов""" + bpe = bpe_class(vocab_size=50) + bpe.fit("аааааааааа" * 100) # Много повторений + assert len(bpe.vocab) == len(set(bpe.vocab))