diff --git a/example/example_bpe.py b/example/example_bpe.py index 1d619b2..e0cbd78 100644 --- a/example/example_bpe.py +++ b/example/example_bpe.py @@ -1,77 +1,84 @@ +from simple_llm.tokenizer.bpe import BPE from simple_llm.tokenizer.simple_bpe import SimpleBPE from simple_llm.tokenizer.optimize_bpe import OptimizeBPE import time -def tokenize_manually(text, vocab): - """Простая ручная токенизация по словарю""" - tokens = [] - i = 0 - n = len(text) - while i < n: - found = False - # Ищем самый длинный возможный токен из словаря - for l in range(min(4, n-i), 0, -1): # проверяем токены длиной до 4 символов - if text[i:i+l] in vocab: - tokens.append(text[i:i+l]) - i += l - found = True - break - if not found: # если токен не найден, берем один символ - tokens.append(text[i]) - i += 1 - return tokens - -def run_example(text, vocab_size=30): - print("\n=== Тестирование токенизаторов ===") - print(f"Исходный текст: '{text}'\n") +def compare_tokenizers(text, vocab_size=50): + """Сравнивает разные реализации BPE""" + print(f"\n=== Анализ текста: '{text[:20]}...' ===") - # Simple BPE + # 1. Базовая реализация BPE + start = time.time() + bpe = BPE(vocab_size=vocab_size) + bpe.fit(text) + base_time = time.time() - start + + print("\n[Базовая реализация BPE]") + print(f"Время обучения: {base_time:.4f} сек") + print(f"Размер словаря: {len(bpe.vocab)}") + print("Примеры токенов:", list(bpe.vocab)[:10], "...") + + # 2. SimpleBPE start = time.time() simple_bpe = SimpleBPE(vocab_size=vocab_size) simple_bpe.fit(text) simple_time = time.time() - start - print("SimpleBPE:") + print("\n[SimpleBPE]") print(f"Время обучения: {simple_time:.4f} сек") print(f"Размер словаря: {len(simple_bpe.vocab)}") - print(f"Пример словаря: {simple_bpe.vocab[:5]}...") - # Демонстрация encode/decode - test_phrases = [text, text.split()[0], "неизвестное_слово"] - for phrase in test_phrases: - encoded = simple_bpe.encode(phrase) - decoded = simple_bpe.decode(encoded) - print(f"\nФраза: '{phrase}'") - print(f"Закодировано: {encoded}") - print(f"Декодировано: '{decoded}'") - print(f"Совпадение: {phrase == decoded}") - - # Optimize BPE + # 3. OptimizeBPE start = time.time() opt_bpe = OptimizeBPE(vocab_size=vocab_size) opt_bpe.fit(text) opt_time = time.time() - start - print("\nOptimizeBPE:") + print("\n[OptimizeBPE]") print(f"Время обучения: {opt_time:.4f} сек") print(f"Размер словаря: {len(opt_bpe.vocab)}") - print(f"Пример словаря: {opt_bpe.vocab[:5]}...") - # Демонстрация encode/decode + # Сравнение производительности + if opt_time > 0: + print(f"\nОптимизированная версия быстрее SimpleBPE в {simple_time/opt_time:.1f} раз") + + # Демонстрация работы на примерах + test_phrases = [ + text.split()[0], # первое слово + text[:10], # часть текста + "неизвестное_слово", # OOV + "спецсимволы: 123, !@#" + ] + + print("\n=== Примеры кодирования/декодирования ===") for phrase in test_phrases: + print(f"\nФраза: '{phrase}'") + + encoded = bpe.encode(phrase) + decoded = bpe.decode(encoded) + print(f"BPE: {encoded} -> '{decoded}'") + + encoded = simple_bpe.encode(phrase) + decoded = simple_bpe.decode(encoded) + print(f"SimpleBPE: {encoded} -> '{decoded}'") + encoded = opt_bpe.encode(phrase) decoded = opt_bpe.decode(encoded) - print(f"\nФраза: '{phrase}'") - print(f"Закодировано: {encoded}") - print(f"Декодировано: '{decoded}'") - print(f"Совпадение: {phrase == decoded}") + print(f"OptimizeBPE: {encoded} -> '{decoded}'") + +def main(): + # Тестовые тексты разной сложности + texts = [ + "мама мыла раму, папа пил какао", + "коты бегают быстро, собаки лают громко", + "искусственный интеллект меняет мир вокруг нас", + "BPE (Byte Pair Encoding) - популярный алгоритм токенизации" + ] - if opt_time > 0: - print(f"\nОптимизированная версия быстрее в {simple_time/opt_time:.1f} раз") + for text in texts: + compare_tokenizers(text) + + print("\n=== Тестирование завершено ===") if __name__ == "__main__": - text1 = "мама мыла раму, папа пил какао" - text2 = "коты бегают быстро, собаки лают громко" - - run_example(text1) - run_example(text2) + main() diff --git a/simple_llm/tokenizer/bpe.py b/simple_llm/tokenizer/bpe.py new file mode 100644 index 0000000..30b59ac --- /dev/null +++ b/simple_llm/tokenizer/bpe.py @@ -0,0 +1,226 @@ +import dill + +class BPE: + """Реализация алгоритма Byte Pair Encoding (BPE) для токенизации текста. + + BPE - это алгоритм сжатия данных, адаптированный для токенизации текста в NLP. + Работает путем итеративного объединения наиболее частых пар символов/токенов. + + Пример использования: + >>> tokenizer = BPE(vocab_size=100) + >>> tokenizer.fit("текст для обучения") + >>> encoded = tokenizer.encode("пример текста") + >>> decoded = tokenizer.decode(encoded) + + Args: + vocab_size (int): Максимальный размер словаря токенов + """ + def __init__(self, vocab_size: int): + self.vocab_size = vocab_size + self.id2token = {} + self.token2id = {} + + def fit(self, text: str): + """Обучает токенизатор на заданном тексте. + + Процесс обучения: + 1. Начинает с базовых символов текста + 2. Итеративно находит и объединяет самые частые пары символов + 3. Продолжает пока не достигнет заданного размера словаря + + Args: + text (str): Текст для обучения токенизатора + + Пример: + >>> tokenizer = BPE(vocab_size=100) + >>> tokenizer.fit("Это текст для обучения токенизатора") + """ + # 1. Получаем уникальные токены (символы) + unique_tokens = sorted(set(text)) + tokens = unique_tokens.copy() + + # 2. Разбиваем текст на токены-символы + sequence = list(text) + + # 3. Объединяем токены до достижения нужного размера словаря + while len(tokens) < self.vocab_size: + #print(f'len={len(tokens)} < {self.vocab_size}') + # Считаем частоты пар + pair_freq = {} + for i in range(len(sequence) - 1): + pair = (sequence[i], sequence[i + 1]) + #print(f'pair = {pair}') + if pair not in pair_freq: + pair_freq[pair] = 0 + pair_freq[pair] += 1 + + + #print(f'pair_freq = {pair_freq}') + if not pair_freq: + break # нет пар — выходим + + #for x in pair_freq.items(): + # self.debug(x, sequence) + + # Находим самую частую пару (в случае равенства — та, что встретилась первой) + most_frequent_pair = max(pair_freq.items(), key=lambda x: (x[1], -self._pair_first_index(sequence, x[0])))[0] + #print(most_frequent_pair) + # Создаем новый токен + new_token = most_frequent_pair[0] + most_frequent_pair[1] + #print(f"new token={new_token}") + tokens.append(new_token) + #print(f"tokens={tokens}") + + i = 0 + new_sequence = [] + + while i < len(sequence): + if i < len(sequence) - 1 and (sequence[i], sequence[i + 1]) == most_frequent_pair: + new_sequence.append(new_token) + i += 2 # пропускаем два символа — заменённую пару + else: + new_sequence.append(sequence[i]) + i += 1 + sequence = new_sequence + #break + + # 4. Создаем словари + self.vocab = tokens.copy() + self.token2id = dict(zip(tokens, range(self.vocab_size))) + self.id2token = dict(zip(range(self.vocab_size), tokens)) + + def _pair_first_index(self, sequence, pair): + for i in range(len(sequence) - 1): + if (sequence[i], sequence[i + 1]) == pair: + return i + return float('inf') # если пара не найдена (в теории не должно случиться) + + + def encode(self, text: str): + """Кодирует текст в последовательность ID токенов. + + Использует жадный алгоритм для поиска наиболее длинных совпадений: + 1. Начинает с первого символа + 2. Ищет самый длинный токен из словаря, совпадающий с началом текста + 3. Добавляет ID найденного токена в результат + 4. Сдвигается на длину найденного токена и повторяет + + Args: + text (str): Текст для кодирования + + Returns: + list: Список ID токенов (неизвестные символы кодируются как -1) + + Пример: + >>> encoded = tokenizer.encode("Пример текста") + >>> print(encoded) + [12, 34, 56, 78] + """ + # 1. Разбиваем текст на токены-символы + sequence = list(text) + # 2. Инициализация пустого списка токенов + tokens = [] + # 3. Установить i = 0 + i = 0 + while i < len(text): + # 3.1 Найти все токены в словаре, начинающиеся с text[i] + start_char = text[i] + result = [token for token in self.vocab if token.startswith(start_char)] + # 3.2 Выбрать самый длинный подходящий токен + find_token = self._find_max_matching_token(text[i:], result) + if find_token is None: + # Обработка неизвестного символа + tokens.append(text[i]) # Добавляем сам символ как токен + i += 1 + else: + # 3.3 Добавить токен в результат + tokens.append(find_token) + # 3.4 Увеличить i на длину токена + i += len(find_token) + + # 4. Заменить токены на их ID + return self._tokens_to_ids(tokens) + + def _find_max_matching_token(self, text: str, tokens: list): + """Находит самый длинный токен из списка, с которого начинается текст""" + matching = [token for token in tokens if text.startswith(token)] + return max(matching, key=len) if matching else None + + def _tokens_to_ids(self, tokens): + """Конвертирует список токенов в их ID с обработкой неизвестных токенов""" + ids = [] + for token in tokens: + if token in self.token2id: + ids.append(self.token2id[token]) + else: + ids.append(-1) # Специальное значение + return ids + + + def decode(self, ids: list) -> str: + """Декодирует последовательность ID обратно в текст. + + Args: + ids (list): Список ID токенов + + Returns: + str: Декодированный текст + + Пример: + >>> decoded = tokenizer.decode([12, 34, 56, 78]) + >>> print(decoded) + "Пример текста" + """ + return ''.join(self._ids_to_tokens(ids)) + + def _ids_to_tokens(self, ids: list) -> list: + """Внутренний метод преобразования ID в токены. + + Args: + ids (list): Список ID токенов + + Returns: + list: Список соответствующих токенов (неизвестные ID = '') + """ + """Конвертирует список Ids в их tokens""" + tokens = [] + for id in ids: + if id in self.id2token: + tokens.append(self.id2token[id]) + else: + tokens.append('') # Специальное значение + return tokens + + + def save(self, filename): + with open(filename, 'wb') as f: + dill.dump(self, f) + print(f"Объект сохранён в {filename}") + + + @classmethod + def load(cls, filename): + """Загружает токенизатор из файла. + + Args: + filename (str): Путь к файлу с сохраненным токенизатором + + Returns: + BPE: Загруженный экземпляр токенизатора + + Пример: + >>> tokenizer = BPE.load("bpe_tokenizer.pkl") + """ + """Load trained tokenizer from file. + + Args: + filename (str): Path to saved tokenizer + + Returns: + BPE: Loaded tokenizer instance + """ + with open(filename, 'rb') as f: + obj = dill.load(f) + + print(f"Объект загружен из {filename}") + return obj \ No newline at end of file diff --git a/tests/test_bpe.py b/tests/test_bpe.py index 4c607a9..ace84f0 100644 --- a/tests/test_bpe.py +++ b/tests/test_bpe.py @@ -1,54 +1,24 @@ import pytest -from simple_llm.tokenizer.simple_bpe import SimpleBPE -from simple_llm.tokenizer.optimize_bpe import OptimizeBPE +from simple_llm.tokenizer.bpe import BPE -class TestBPE: - @pytest.fixture(params=[SimpleBPE, OptimizeBPE]) - def bpe_class(self, request): - return request.param +def test_basic_bpe(): + """Базовый тест работы BPE""" + tokenizer = BPE(vocab_size=10) + text = "мама мыла раму" - def test_initialization(self, bpe_class): - """Тест инициализации BPE-токенизатора""" - bpe = bpe_class(vocab_size=100) - assert bpe.vocab_size == 100 - assert bpe.vocab == [] - assert bpe.token2id == {} - assert bpe.id2token == {} + # Обучение + tokenizer.fit(text) - def test_fit_simple_text(self, bpe_class): - """Тест обучения на простом тексте""" - text = "мама мыла раму" - bpe = bpe_class(vocab_size=20) - bpe.fit(text) - - # Проверки словаря - assert isinstance(bpe.vocab, list) - assert len(bpe.vocab) > 0 - assert len(bpe.vocab) <= 20 - assert all(isinstance(token, str) for token in bpe.vocab) - - # Проверка словарей - assert len(bpe.vocab) == len(bpe.token2id) - assert len(bpe.vocab) == len(bpe.id2token) - - # Проверка соответствия токенов и ID - for token in bpe.vocab: - assert bpe.token2id[token] == bpe.vocab.index(token) - assert bpe.id2token[bpe.token2id[token]] == token - - @pytest.mark.parametrize("text,expected_min_size", [ - ("", 0), - ("а", 1), - ("ааааа", 3) # Минимум 3 токена - ]) - def test_edge_cases(self, bpe_class, text, expected_min_size): - """Тест граничных случаев""" - bpe = bpe_class(vocab_size=10) - bpe.fit(text) - assert len(bpe.vocab) >= expected_min_size - - def test_duplicate_protection(self, bpe_class): - """Тест защиты от дубликатов токенов""" - bpe = bpe_class(vocab_size=50) - bpe.fit("аааааааааа" * 100) # Много повторений - assert len(bpe.vocab) == len(set(bpe.vocab)) + # Проверка размера словаря + assert len(tokenizer.vocab) == 10 + + # Кодирование/декодирование + encoded = tokenizer.encode(text) + decoded = tokenizer.decode(encoded) + + assert decoded == text + assert len(encoded) > 0 + + # Проверка неизвестных символов + unknown_encoded = tokenizer.encode("мама мыла окно") + assert -1 in unknown_encoded # Специальный токен для неизвестных символов diff --git a/tests/test_bpe_detailed.py b/tests/test_bpe_detailed.py new file mode 100644 index 0000000..d0d0f8e --- /dev/null +++ b/tests/test_bpe_detailed.py @@ -0,0 +1,74 @@ +import os +import tempfile +import pytest +from simple_llm.tokenizer.bpe import BPE + +class TestBPE: + @pytest.fixture + def sample_text(self): + return "ааабббвввггг аааббб дддд ееее жжжж" + + @pytest.fixture + def bpe(self): + return BPE(vocab_size=20) + + def test_fit(self, bpe, sample_text): + """Тест обучения токенизатора""" + bpe.fit(sample_text) + assert len(bpe.vocab) == bpe.vocab_size + assert len(bpe.token2id) == bpe.vocab_size + assert len(bpe.id2token) == bpe.vocab_size + + def test_encode_decode(self, bpe, sample_text): + """Тест кодирования и декодирования""" + bpe.fit(sample_text) + encoded = bpe.encode(sample_text) + decoded = bpe.decode(encoded) + assert decoded == sample_text + + def test_encode_unknown_chars(self, bpe, sample_text): + """Тест с неизвестными символами""" + bpe.fit(sample_text) + test_text = "ааббцц" # 'цц' нет в обучающем тексте + encoded = bpe.encode(test_text) + assert -1 in encoded # Должен содержать специальный токен для неизвестных символов + decoded = bpe.decode(encoded) + assert "цц" in decoded + + def test_save_load(self, bpe, sample_text): + """Тест сохранения и загрузки""" + bpe.fit(sample_text) + + with tempfile.NamedTemporaryFile(delete=False) as tmp: + try: + bpe.save(tmp.name) + loaded = BPE.load(tmp.name) + + assert loaded.vocab_size == bpe.vocab_size + assert loaded.vocab == bpe.vocab + assert loaded.token2id == bpe.token2id + assert loaded.id2token == bpe.id2token + + # Проверяем работоспособность после загрузки + encoded = loaded.encode(sample_text) + decoded = loaded.decode(encoded) + assert decoded == sample_text + finally: + os.unlink(tmp.name) + + def test_pair_merging(self, bpe, sample_text): + """Тест правильности объединения пар""" + bpe.fit(sample_text) + + # Проверяем, что самые частые пары были объединены + assert 'аа' in bpe.vocab or 'ааа' in bpe.vocab + assert 'бб' in bpe.vocab or 'ббб' in bpe.vocab + + def test_vocab_size(self): + """Тест обработки слишком маленького vocab_size""" + small_bpe = BPE(vocab_size=5) + with pytest.raises(ValueError): + small_bpe.fit("абвгд") # Слишком мало для начальных символов + +if __name__ == "__main__": + pytest.main()