mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 13:03:55 +00:00
Обновление BPE: добавлена документация, тесты и улучшен пример использования
This commit is contained in:
@@ -1,77 +1,84 @@
|
||||
from simple_llm.tokenizer.bpe import BPE
|
||||
from simple_llm.tokenizer.simple_bpe import SimpleBPE
|
||||
from simple_llm.tokenizer.optimize_bpe import OptimizeBPE
|
||||
import time
|
||||
|
||||
def tokenize_manually(text, vocab):
|
||||
"""Простая ручная токенизация по словарю"""
|
||||
tokens = []
|
||||
i = 0
|
||||
n = len(text)
|
||||
while i < n:
|
||||
found = False
|
||||
# Ищем самый длинный возможный токен из словаря
|
||||
for l in range(min(4, n-i), 0, -1): # проверяем токены длиной до 4 символов
|
||||
if text[i:i+l] in vocab:
|
||||
tokens.append(text[i:i+l])
|
||||
i += l
|
||||
found = True
|
||||
break
|
||||
if not found: # если токен не найден, берем один символ
|
||||
tokens.append(text[i])
|
||||
i += 1
|
||||
return tokens
|
||||
|
||||
def run_example(text, vocab_size=30):
|
||||
print("\n=== Тестирование токенизаторов ===")
|
||||
print(f"Исходный текст: '{text}'\n")
|
||||
def compare_tokenizers(text, vocab_size=50):
|
||||
"""Сравнивает разные реализации BPE"""
|
||||
print(f"\n=== Анализ текста: '{text[:20]}...' ===")
|
||||
|
||||
# Simple BPE
|
||||
# 1. Базовая реализация BPE
|
||||
start = time.time()
|
||||
bpe = BPE(vocab_size=vocab_size)
|
||||
bpe.fit(text)
|
||||
base_time = time.time() - start
|
||||
|
||||
print("\n[Базовая реализация BPE]")
|
||||
print(f"Время обучения: {base_time:.4f} сек")
|
||||
print(f"Размер словаря: {len(bpe.vocab)}")
|
||||
print("Примеры токенов:", list(bpe.vocab)[:10], "...")
|
||||
|
||||
# 2. SimpleBPE
|
||||
start = time.time()
|
||||
simple_bpe = SimpleBPE(vocab_size=vocab_size)
|
||||
simple_bpe.fit(text)
|
||||
simple_time = time.time() - start
|
||||
|
||||
print("SimpleBPE:")
|
||||
print("\n[SimpleBPE]")
|
||||
print(f"Время обучения: {simple_time:.4f} сек")
|
||||
print(f"Размер словаря: {len(simple_bpe.vocab)}")
|
||||
print(f"Пример словаря: {simple_bpe.vocab[:5]}...")
|
||||
|
||||
# Демонстрация encode/decode
|
||||
test_phrases = [text, text.split()[0], "неизвестное_слово"]
|
||||
for phrase in test_phrases:
|
||||
encoded = simple_bpe.encode(phrase)
|
||||
decoded = simple_bpe.decode(encoded)
|
||||
print(f"\nФраза: '{phrase}'")
|
||||
print(f"Закодировано: {encoded}")
|
||||
print(f"Декодировано: '{decoded}'")
|
||||
print(f"Совпадение: {phrase == decoded}")
|
||||
|
||||
# Optimize BPE
|
||||
# 3. OptimizeBPE
|
||||
start = time.time()
|
||||
opt_bpe = OptimizeBPE(vocab_size=vocab_size)
|
||||
opt_bpe.fit(text)
|
||||
opt_time = time.time() - start
|
||||
|
||||
print("\nOptimizeBPE:")
|
||||
print("\n[OptimizeBPE]")
|
||||
print(f"Время обучения: {opt_time:.4f} сек")
|
||||
print(f"Размер словаря: {len(opt_bpe.vocab)}")
|
||||
print(f"Пример словаря: {opt_bpe.vocab[:5]}...")
|
||||
|
||||
# Демонстрация encode/decode
|
||||
# Сравнение производительности
|
||||
if opt_time > 0:
|
||||
print(f"\nОптимизированная версия быстрее SimpleBPE в {simple_time/opt_time:.1f} раз")
|
||||
|
||||
# Демонстрация работы на примерах
|
||||
test_phrases = [
|
||||
text.split()[0], # первое слово
|
||||
text[:10], # часть текста
|
||||
"неизвестное_слово", # OOV
|
||||
"спецсимволы: 123, !@#"
|
||||
]
|
||||
|
||||
print("\n=== Примеры кодирования/декодирования ===")
|
||||
for phrase in test_phrases:
|
||||
print(f"\nФраза: '{phrase}'")
|
||||
|
||||
encoded = bpe.encode(phrase)
|
||||
decoded = bpe.decode(encoded)
|
||||
print(f"BPE: {encoded} -> '{decoded}'")
|
||||
|
||||
encoded = simple_bpe.encode(phrase)
|
||||
decoded = simple_bpe.decode(encoded)
|
||||
print(f"SimpleBPE: {encoded} -> '{decoded}'")
|
||||
|
||||
encoded = opt_bpe.encode(phrase)
|
||||
decoded = opt_bpe.decode(encoded)
|
||||
print(f"\nФраза: '{phrase}'")
|
||||
print(f"Закодировано: {encoded}")
|
||||
print(f"Декодировано: '{decoded}'")
|
||||
print(f"Совпадение: {phrase == decoded}")
|
||||
print(f"OptimizeBPE: {encoded} -> '{decoded}'")
|
||||
|
||||
def main():
|
||||
# Тестовые тексты разной сложности
|
||||
texts = [
|
||||
"мама мыла раму, папа пил какао",
|
||||
"коты бегают быстро, собаки лают громко",
|
||||
"искусственный интеллект меняет мир вокруг нас",
|
||||
"BPE (Byte Pair Encoding) - популярный алгоритм токенизации"
|
||||
]
|
||||
|
||||
if opt_time > 0:
|
||||
print(f"\nОптимизированная версия быстрее в {simple_time/opt_time:.1f} раз")
|
||||
for text in texts:
|
||||
compare_tokenizers(text)
|
||||
|
||||
print("\n=== Тестирование завершено ===")
|
||||
|
||||
if __name__ == "__main__":
|
||||
text1 = "мама мыла раму, папа пил какао"
|
||||
text2 = "коты бегают быстро, собаки лают громко"
|
||||
|
||||
run_example(text1)
|
||||
run_example(text2)
|
||||
main()
|
||||
|
||||
226
simple_llm/tokenizer/bpe.py
Normal file
226
simple_llm/tokenizer/bpe.py
Normal file
@@ -0,0 +1,226 @@
|
||||
import dill
|
||||
|
||||
class BPE:
|
||||
"""Реализация алгоритма Byte Pair Encoding (BPE) для токенизации текста.
|
||||
|
||||
BPE - это алгоритм сжатия данных, адаптированный для токенизации текста в NLP.
|
||||
Работает путем итеративного объединения наиболее частых пар символов/токенов.
|
||||
|
||||
Пример использования:
|
||||
>>> tokenizer = BPE(vocab_size=100)
|
||||
>>> tokenizer.fit("текст для обучения")
|
||||
>>> encoded = tokenizer.encode("пример текста")
|
||||
>>> decoded = tokenizer.decode(encoded)
|
||||
|
||||
Args:
|
||||
vocab_size (int): Максимальный размер словаря токенов
|
||||
"""
|
||||
def __init__(self, vocab_size: int):
|
||||
self.vocab_size = vocab_size
|
||||
self.id2token = {}
|
||||
self.token2id = {}
|
||||
|
||||
def fit(self, text: str):
|
||||
"""Обучает токенизатор на заданном тексте.
|
||||
|
||||
Процесс обучения:
|
||||
1. Начинает с базовых символов текста
|
||||
2. Итеративно находит и объединяет самые частые пары символов
|
||||
3. Продолжает пока не достигнет заданного размера словаря
|
||||
|
||||
Args:
|
||||
text (str): Текст для обучения токенизатора
|
||||
|
||||
Пример:
|
||||
>>> tokenizer = BPE(vocab_size=100)
|
||||
>>> tokenizer.fit("Это текст для обучения токенизатора")
|
||||
"""
|
||||
# 1. Получаем уникальные токены (символы)
|
||||
unique_tokens = sorted(set(text))
|
||||
tokens = unique_tokens.copy()
|
||||
|
||||
# 2. Разбиваем текст на токены-символы
|
||||
sequence = list(text)
|
||||
|
||||
# 3. Объединяем токены до достижения нужного размера словаря
|
||||
while len(tokens) < self.vocab_size:
|
||||
#print(f'len={len(tokens)} < {self.vocab_size}')
|
||||
# Считаем частоты пар
|
||||
pair_freq = {}
|
||||
for i in range(len(sequence) - 1):
|
||||
pair = (sequence[i], sequence[i + 1])
|
||||
#print(f'pair = {pair}')
|
||||
if pair not in pair_freq:
|
||||
pair_freq[pair] = 0
|
||||
pair_freq[pair] += 1
|
||||
|
||||
|
||||
#print(f'pair_freq = {pair_freq}')
|
||||
if not pair_freq:
|
||||
break # нет пар — выходим
|
||||
|
||||
#for x in pair_freq.items():
|
||||
# self.debug(x, sequence)
|
||||
|
||||
# Находим самую частую пару (в случае равенства — та, что встретилась первой)
|
||||
most_frequent_pair = max(pair_freq.items(), key=lambda x: (x[1], -self._pair_first_index(sequence, x[0])))[0]
|
||||
#print(most_frequent_pair)
|
||||
# Создаем новый токен
|
||||
new_token = most_frequent_pair[0] + most_frequent_pair[1]
|
||||
#print(f"new token={new_token}")
|
||||
tokens.append(new_token)
|
||||
#print(f"tokens={tokens}")
|
||||
|
||||
i = 0
|
||||
new_sequence = []
|
||||
|
||||
while i < len(sequence):
|
||||
if i < len(sequence) - 1 and (sequence[i], sequence[i + 1]) == most_frequent_pair:
|
||||
new_sequence.append(new_token)
|
||||
i += 2 # пропускаем два символа — заменённую пару
|
||||
else:
|
||||
new_sequence.append(sequence[i])
|
||||
i += 1
|
||||
sequence = new_sequence
|
||||
#break
|
||||
|
||||
# 4. Создаем словари
|
||||
self.vocab = tokens.copy()
|
||||
self.token2id = dict(zip(tokens, range(self.vocab_size)))
|
||||
self.id2token = dict(zip(range(self.vocab_size), tokens))
|
||||
|
||||
def _pair_first_index(self, sequence, pair):
|
||||
for i in range(len(sequence) - 1):
|
||||
if (sequence[i], sequence[i + 1]) == pair:
|
||||
return i
|
||||
return float('inf') # если пара не найдена (в теории не должно случиться)
|
||||
|
||||
|
||||
def encode(self, text: str):
|
||||
"""Кодирует текст в последовательность ID токенов.
|
||||
|
||||
Использует жадный алгоритм для поиска наиболее длинных совпадений:
|
||||
1. Начинает с первого символа
|
||||
2. Ищет самый длинный токен из словаря, совпадающий с началом текста
|
||||
3. Добавляет ID найденного токена в результат
|
||||
4. Сдвигается на длину найденного токена и повторяет
|
||||
|
||||
Args:
|
||||
text (str): Текст для кодирования
|
||||
|
||||
Returns:
|
||||
list: Список ID токенов (неизвестные символы кодируются как -1)
|
||||
|
||||
Пример:
|
||||
>>> encoded = tokenizer.encode("Пример текста")
|
||||
>>> print(encoded)
|
||||
[12, 34, 56, 78]
|
||||
"""
|
||||
# 1. Разбиваем текст на токены-символы
|
||||
sequence = list(text)
|
||||
# 2. Инициализация пустого списка токенов
|
||||
tokens = []
|
||||
# 3. Установить i = 0
|
||||
i = 0
|
||||
while i < len(text):
|
||||
# 3.1 Найти все токены в словаре, начинающиеся с text[i]
|
||||
start_char = text[i]
|
||||
result = [token for token in self.vocab if token.startswith(start_char)]
|
||||
# 3.2 Выбрать самый длинный подходящий токен
|
||||
find_token = self._find_max_matching_token(text[i:], result)
|
||||
if find_token is None:
|
||||
# Обработка неизвестного символа
|
||||
tokens.append(text[i]) # Добавляем сам символ как токен
|
||||
i += 1
|
||||
else:
|
||||
# 3.3 Добавить токен в результат
|
||||
tokens.append(find_token)
|
||||
# 3.4 Увеличить i на длину токена
|
||||
i += len(find_token)
|
||||
|
||||
# 4. Заменить токены на их ID
|
||||
return self._tokens_to_ids(tokens)
|
||||
|
||||
def _find_max_matching_token(self, text: str, tokens: list):
|
||||
"""Находит самый длинный токен из списка, с которого начинается текст"""
|
||||
matching = [token for token in tokens if text.startswith(token)]
|
||||
return max(matching, key=len) if matching else None
|
||||
|
||||
def _tokens_to_ids(self, tokens):
|
||||
"""Конвертирует список токенов в их ID с обработкой неизвестных токенов"""
|
||||
ids = []
|
||||
for token in tokens:
|
||||
if token in self.token2id:
|
||||
ids.append(self.token2id[token])
|
||||
else:
|
||||
ids.append(-1) # Специальное значение
|
||||
return ids
|
||||
|
||||
|
||||
def decode(self, ids: list) -> str:
|
||||
"""Декодирует последовательность ID обратно в текст.
|
||||
|
||||
Args:
|
||||
ids (list): Список ID токенов
|
||||
|
||||
Returns:
|
||||
str: Декодированный текст
|
||||
|
||||
Пример:
|
||||
>>> decoded = tokenizer.decode([12, 34, 56, 78])
|
||||
>>> print(decoded)
|
||||
"Пример текста"
|
||||
"""
|
||||
return ''.join(self._ids_to_tokens(ids))
|
||||
|
||||
def _ids_to_tokens(self, ids: list) -> list:
|
||||
"""Внутренний метод преобразования ID в токены.
|
||||
|
||||
Args:
|
||||
ids (list): Список ID токенов
|
||||
|
||||
Returns:
|
||||
list: Список соответствующих токенов (неизвестные ID = '')
|
||||
"""
|
||||
"""Конвертирует список Ids в их tokens"""
|
||||
tokens = []
|
||||
for id in ids:
|
||||
if id in self.id2token:
|
||||
tokens.append(self.id2token[id])
|
||||
else:
|
||||
tokens.append('') # Специальное значение
|
||||
return tokens
|
||||
|
||||
|
||||
def save(self, filename):
|
||||
with open(filename, 'wb') as f:
|
||||
dill.dump(self, f)
|
||||
print(f"Объект сохранён в {filename}")
|
||||
|
||||
|
||||
@classmethod
|
||||
def load(cls, filename):
|
||||
"""Загружает токенизатор из файла.
|
||||
|
||||
Args:
|
||||
filename (str): Путь к файлу с сохраненным токенизатором
|
||||
|
||||
Returns:
|
||||
BPE: Загруженный экземпляр токенизатора
|
||||
|
||||
Пример:
|
||||
>>> tokenizer = BPE.load("bpe_tokenizer.pkl")
|
||||
"""
|
||||
"""Load trained tokenizer from file.
|
||||
|
||||
Args:
|
||||
filename (str): Path to saved tokenizer
|
||||
|
||||
Returns:
|
||||
BPE: Loaded tokenizer instance
|
||||
"""
|
||||
with open(filename, 'rb') as f:
|
||||
obj = dill.load(f)
|
||||
|
||||
print(f"Объект загружен из {filename}")
|
||||
return obj
|
||||
@@ -1,54 +1,24 @@
|
||||
import pytest
|
||||
from simple_llm.tokenizer.simple_bpe import SimpleBPE
|
||||
from simple_llm.tokenizer.optimize_bpe import OptimizeBPE
|
||||
from simple_llm.tokenizer.bpe import BPE
|
||||
|
||||
class TestBPE:
|
||||
@pytest.fixture(params=[SimpleBPE, OptimizeBPE])
|
||||
def bpe_class(self, request):
|
||||
return request.param
|
||||
def test_basic_bpe():
|
||||
"""Базовый тест работы BPE"""
|
||||
tokenizer = BPE(vocab_size=10)
|
||||
text = "мама мыла раму"
|
||||
|
||||
def test_initialization(self, bpe_class):
|
||||
"""Тест инициализации BPE-токенизатора"""
|
||||
bpe = bpe_class(vocab_size=100)
|
||||
assert bpe.vocab_size == 100
|
||||
assert bpe.vocab == []
|
||||
assert bpe.token2id == {}
|
||||
assert bpe.id2token == {}
|
||||
# Обучение
|
||||
tokenizer.fit(text)
|
||||
|
||||
def test_fit_simple_text(self, bpe_class):
|
||||
"""Тест обучения на простом тексте"""
|
||||
text = "мама мыла раму"
|
||||
bpe = bpe_class(vocab_size=20)
|
||||
bpe.fit(text)
|
||||
|
||||
# Проверки словаря
|
||||
assert isinstance(bpe.vocab, list)
|
||||
assert len(bpe.vocab) > 0
|
||||
assert len(bpe.vocab) <= 20
|
||||
assert all(isinstance(token, str) for token in bpe.vocab)
|
||||
|
||||
# Проверка словарей
|
||||
assert len(bpe.vocab) == len(bpe.token2id)
|
||||
assert len(bpe.vocab) == len(bpe.id2token)
|
||||
|
||||
# Проверка соответствия токенов и ID
|
||||
for token in bpe.vocab:
|
||||
assert bpe.token2id[token] == bpe.vocab.index(token)
|
||||
assert bpe.id2token[bpe.token2id[token]] == token
|
||||
|
||||
@pytest.mark.parametrize("text,expected_min_size", [
|
||||
("", 0),
|
||||
("а", 1),
|
||||
("ааааа", 3) # Минимум 3 токена
|
||||
])
|
||||
def test_edge_cases(self, bpe_class, text, expected_min_size):
|
||||
"""Тест граничных случаев"""
|
||||
bpe = bpe_class(vocab_size=10)
|
||||
bpe.fit(text)
|
||||
assert len(bpe.vocab) >= expected_min_size
|
||||
|
||||
def test_duplicate_protection(self, bpe_class):
|
||||
"""Тест защиты от дубликатов токенов"""
|
||||
bpe = bpe_class(vocab_size=50)
|
||||
bpe.fit("аааааааааа" * 100) # Много повторений
|
||||
assert len(bpe.vocab) == len(set(bpe.vocab))
|
||||
# Проверка размера словаря
|
||||
assert len(tokenizer.vocab) == 10
|
||||
|
||||
# Кодирование/декодирование
|
||||
encoded = tokenizer.encode(text)
|
||||
decoded = tokenizer.decode(encoded)
|
||||
|
||||
assert decoded == text
|
||||
assert len(encoded) > 0
|
||||
|
||||
# Проверка неизвестных символов
|
||||
unknown_encoded = tokenizer.encode("мама мыла окно")
|
||||
assert -1 in unknown_encoded # Специальный токен для неизвестных символов
|
||||
|
||||
74
tests/test_bpe_detailed.py
Normal file
74
tests/test_bpe_detailed.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import os
|
||||
import tempfile
|
||||
import pytest
|
||||
from simple_llm.tokenizer.bpe import BPE
|
||||
|
||||
class TestBPE:
|
||||
@pytest.fixture
|
||||
def sample_text(self):
|
||||
return "ааабббвввггг аааббб дддд ееее жжжж"
|
||||
|
||||
@pytest.fixture
|
||||
def bpe(self):
|
||||
return BPE(vocab_size=20)
|
||||
|
||||
def test_fit(self, bpe, sample_text):
|
||||
"""Тест обучения токенизатора"""
|
||||
bpe.fit(sample_text)
|
||||
assert len(bpe.vocab) == bpe.vocab_size
|
||||
assert len(bpe.token2id) == bpe.vocab_size
|
||||
assert len(bpe.id2token) == bpe.vocab_size
|
||||
|
||||
def test_encode_decode(self, bpe, sample_text):
|
||||
"""Тест кодирования и декодирования"""
|
||||
bpe.fit(sample_text)
|
||||
encoded = bpe.encode(sample_text)
|
||||
decoded = bpe.decode(encoded)
|
||||
assert decoded == sample_text
|
||||
|
||||
def test_encode_unknown_chars(self, bpe, sample_text):
|
||||
"""Тест с неизвестными символами"""
|
||||
bpe.fit(sample_text)
|
||||
test_text = "ааббцц" # 'цц' нет в обучающем тексте
|
||||
encoded = bpe.encode(test_text)
|
||||
assert -1 in encoded # Должен содержать специальный токен для неизвестных символов
|
||||
decoded = bpe.decode(encoded)
|
||||
assert "цц" in decoded
|
||||
|
||||
def test_save_load(self, bpe, sample_text):
|
||||
"""Тест сохранения и загрузки"""
|
||||
bpe.fit(sample_text)
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tmp:
|
||||
try:
|
||||
bpe.save(tmp.name)
|
||||
loaded = BPE.load(tmp.name)
|
||||
|
||||
assert loaded.vocab_size == bpe.vocab_size
|
||||
assert loaded.vocab == bpe.vocab
|
||||
assert loaded.token2id == bpe.token2id
|
||||
assert loaded.id2token == bpe.id2token
|
||||
|
||||
# Проверяем работоспособность после загрузки
|
||||
encoded = loaded.encode(sample_text)
|
||||
decoded = loaded.decode(encoded)
|
||||
assert decoded == sample_text
|
||||
finally:
|
||||
os.unlink(tmp.name)
|
||||
|
||||
def test_pair_merging(self, bpe, sample_text):
|
||||
"""Тест правильности объединения пар"""
|
||||
bpe.fit(sample_text)
|
||||
|
||||
# Проверяем, что самые частые пары были объединены
|
||||
assert 'аа' in bpe.vocab or 'ааа' in bpe.vocab
|
||||
assert 'бб' in bpe.vocab or 'ббб' in bpe.vocab
|
||||
|
||||
def test_vocab_size(self):
|
||||
"""Тест обработки слишком маленького vocab_size"""
|
||||
small_bpe = BPE(vocab_size=5)
|
||||
with pytest.raises(ValueError):
|
||||
small_bpe.fit("абвгд") # Слишком мало для начальных символов
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
||||
Reference in New Issue
Block a user