mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 13:03:55 +00:00
feat: implement bpe algorithm
This commit is contained in:
78
README.md
78
README.md
@@ -1 +1,77 @@
|
||||
# simple-llm
|
||||
# Simple LLM Tokenizer
|
||||
|
||||
Простой и эффективный токенизатор для языковых моделей на основе BPE (Byte Pair Encoding)
|
||||
|
||||
## Описание проекта
|
||||
|
||||
Проект предоставляет реализации алгоритма BPE (Byte Pair Encoding) для токенизации текста:
|
||||
- `SimpleBPE` - базовая версия
|
||||
- `OptimizeBPE` - оптимизированная версия с улучшенной производительностью
|
||||
|
||||
Основные возможности:
|
||||
- Обучение на любом тексте (поддержка кириллицы и других алфавитов)
|
||||
- Гибкая настройка размера словаря
|
||||
- Простота интеграции в существующие проекты
|
||||
|
||||
## Установка
|
||||
|
||||
1. Склонируйте репозиторий:
|
||||
```bash
|
||||
git clone https://github.com/yourusername/simple-llm.git
|
||||
cd simple-llm
|
||||
```
|
||||
|
||||
2. Установите пакет:
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## Быстрый старт
|
||||
|
||||
```python
|
||||
from simple_llm.tokenizer import SimpleBPE
|
||||
|
||||
# Инициализация и обучение
|
||||
text = "мама мыла раму, папа пил какао"
|
||||
bpe = SimpleBPE(vocab_size=50)
|
||||
bpe.fit(text)
|
||||
|
||||
# Токенизация
|
||||
tokens = bpe.tokenize(text)
|
||||
print(tokens)
|
||||
```
|
||||
|
||||
## Интеграция в проект
|
||||
|
||||
Добавьте в ваш `requirements.txt`:
|
||||
```
|
||||
git+https://github.com/yourusername/simple-llm.git
|
||||
```
|
||||
|
||||
Или установите напрямую:
|
||||
```bash
|
||||
pip install git+https://github.com/yourusername/simple-llm.git
|
||||
```
|
||||
|
||||
## Примеры
|
||||
|
||||
Дополнительные примеры использования смотрите в папке [example](/example):
|
||||
- Сравнение SimpleBPE и OptimizeBPE
|
||||
- Работа с разными языками
|
||||
- Настройка параметров токенизации
|
||||
|
||||
## Разработка
|
||||
|
||||
Для запуска тестов:
|
||||
```bash
|
||||
pytest tests/
|
||||
```
|
||||
|
||||
Для внесения изменений установите зависимости разработки:
|
||||
```bash
|
||||
pip install -e ".[dev]"
|
||||
```
|
||||
|
||||
## Лицензия
|
||||
|
||||
Проект распространяется под лицензией MIT. Подробнее см. [LICENSE](LICENSE).
|
||||
|
||||
52
doc/bpe_algorithm.drawio
Normal file
52
doc/bpe_algorithm.drawio
Normal file
@@ -0,0 +1,52 @@
|
||||
<mxfile>
|
||||
<diagram name="Page-1">
|
||||
<mxGraphModel dx="1200" dy="580">
|
||||
<root>
|
||||
<mxCell id="0"/>
|
||||
<mxCell id="1" parent="0"/>
|
||||
<mxCell id="2" value="Начало" style="rounded=1;whiteSpace=wrap;html=1;" parent="1" vertex="1">
|
||||
<mxGeometry x="100" y="60" width="120" height="60"/>
|
||||
</mxCell>
|
||||
<mxCell id="3" value="Разбить текст на символы" style="rhombus;whiteSpace=wrap;html=1;" parent="1" vertex="1">
|
||||
<mxGeometry x="100" y="160" width="120" height="80"/>
|
||||
</mxCell>
|
||||
<mxCell id="4" value="" style="edgeStyle=none;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;" parent="1" source="2" target="3" edge="1">
|
||||
<mxGeometry relative="1"/>
|
||||
</mxCell>
|
||||
<mxCell id="5" value="Подсчитать частоты пар" style="whiteSpace=wrap;html=1;" parent="1" vertex="1">
|
||||
<mxGeometry x="100" y="280" width="120" height="60"/>
|
||||
</mxCell>
|
||||
<mxCell id="6" value="" style="edgeStyle=none;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;" parent="1" source="3" target="5" edge="1">
|
||||
<mxGeometry relative="1"/>
|
||||
</mxCell>
|
||||
<mxCell id="7" value="Выбрать наиболее частую пару" style="rhombus;whiteSpace=wrap;html=1;" parent="1" vertex="1">
|
||||
<mxGeometry x="100" y="380" width="120" height="80"/>
|
||||
</mxCell>
|
||||
<mxCell id="8" value="" style="edgeStyle=none;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;" parent="1" source="5" target="7" edge="1">
|
||||
<mxGeometry relative="1"/>
|
||||
</mxCell>
|
||||
<mxCell id="9" value="Заменить пару новым токеном" style="whiteSpace=wrap;html=1;" parent="1" vertex="1">
|
||||
<mxGeometry x="100" y="500" width="120" height="60"/>
|
||||
</mxCell>
|
||||
<mxCell id="10" value="" style="edgeStyle=none;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;" parent="1" source="7" target="9" edge="1">
|
||||
<mxGeometry relative="1"/>
|
||||
</mxCell>
|
||||
<mxCell id="11" value="Достигнут лимит словаря?" style="rhombus;whiteSpace=wrap;html=1;" parent="1" vertex="1">
|
||||
<mxGeometry x="100" y="600" width="120" height="80"/>
|
||||
</mxCell>
|
||||
<mxCell id="12" value="" style="edgeStyle=none;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;" parent="1" source="9" target="11" edge="1">
|
||||
<mxGeometry relative="1"/>
|
||||
</mxCell>
|
||||
<mxCell id="13" value="Конец" style="rounded=1;whiteSpace=wrap;html=1;" parent="1" vertex="1">
|
||||
<mxGeometry x="100" y="720" width="120" height="60"/>
|
||||
</mxCell>
|
||||
<mxCell id="14" value="Да" style="edgeStyle=none;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;" parent="1" source="11" target="13" edge="1">
|
||||
<mxGeometry relative="1"/>
|
||||
</mxCell>
|
||||
<mxCell id="15" value="Нет" style="edgeStyle=none;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;" parent="1" source="11" target="5" edge="1">
|
||||
<mxGeometry relative="1"/>
|
||||
</mxCell>
|
||||
</root>
|
||||
</mxGraphModel>
|
||||
</diagram>
|
||||
</mxfile>
|
||||
136
doc/bpe_algorithm.md
Normal file
136
doc/bpe_algorithm.md
Normal file
@@ -0,0 +1,136 @@
|
||||
# Byte Pair Encoding (BPE) Algorithm
|
||||
|
||||
## Введение
|
||||
|
||||
Byte Pair Encoding (BPE) - это алгоритм компрессии данных, адаптированный для токенизации текста в обработке естественного языка. В контексте языковых моделей BPE используется для создания эффективного словаря подстрок (токенов).
|
||||
|
||||
## Основные понятия
|
||||
|
||||
- **Токен** - элементарная единица текста (символ или последовательность символов)
|
||||
- **Словарь** - набор уникальных токенов, используемых для представления текста
|
||||
- **Частота пары** - количество раз, когда два токена встречаются вместе в тексте
|
||||
|
||||
## Алгоритм работы
|
||||
|
||||
### 1. Инициализация
|
||||
```python
|
||||
Исходный текст → Разбить на символы → Первоначальный словарь
|
||||
```
|
||||
Пример:
|
||||
```
|
||||
"мама" → ['м', 'а', 'м', 'а']
|
||||
```
|
||||
|
||||
### 2. Основной цикл
|
||||
```mermaid
|
||||
graph TD
|
||||
A[Подсчет частот пар] --> B[Выбор наиболее частой пары]
|
||||
B --> C[Создание нового токена]
|
||||
C --> D[Обновление последовательности]
|
||||
D --> E{Достигнут лимит словаря?}
|
||||
E -->|Нет| A
|
||||
E -->|Да| F[Конец]
|
||||
```
|
||||
|
||||
### 3. Детализация шагов
|
||||
|
||||
#### Шаг 1: Подсчет частот пар
|
||||
Для текущей последовательности токенов подсчитываем все пары соседних токенов:
|
||||
```
|
||||
Текст: "мама мыла"
|
||||
Токены: ['м', 'а', 'м', 'а', ' ', 'м', 'ы', 'л', 'а']
|
||||
Пары: ('м','а'), ('а','м'), ('м','а'), ('а',' '), (' ','м'), ('м','ы'), ('ы','л'), ('л','а')
|
||||
```
|
||||
|
||||
#### Шаг 2: Выбор пары для слияния
|
||||
Находим пару с максимальной частотой. При равенстве частот выбираем пару, которая встречается раньше в тексте.
|
||||
|
||||
#### Шаг 3: Слияние
|
||||
Объединяем выбранную пару в новый токен и заменяем все её вхождения в тексте:
|
||||
```
|
||||
Выбранная пара: ('м', 'а')
|
||||
Новый токен: 'ма'
|
||||
Обновленная последовательность: ['ма', 'ма', ' ', 'м', 'ы', 'л', 'а']
|
||||
```
|
||||
|
||||
#### Шаг 4: Обновление словаря
|
||||
Добавляем новый токен в словарь:
|
||||
```
|
||||
Словарь: ['м', 'а', ' ', 'ы', 'л', 'ма']
|
||||
```
|
||||
|
||||
### 4. Критерии остановки
|
||||
|
||||
1. Достижение заданного размера словаря
|
||||
2. Отсутствие пар для слияния (все возможные пары уже добавлены)
|
||||
3. Достижение максимального числа итераций
|
||||
|
||||
## Псевдокод
|
||||
|
||||
```python
|
||||
def train_bpe(text, vocab_size):
|
||||
# Инициализация
|
||||
tokens = list(text)
|
||||
vocab = set(tokens)
|
||||
|
||||
while len(vocab) < vocab_size:
|
||||
# Подсчет пар
|
||||
pairs = get_pairs(tokens)
|
||||
if not pairs:
|
||||
break
|
||||
|
||||
# Выбор наиболее частой пары
|
||||
best_pair = max(pairs, key=pairs.get)
|
||||
|
||||
# Слияние
|
||||
new_tokens = []
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
if i < len(tokens)-1 and (tokens[i], tokens[i+1]) == best_pair:
|
||||
new_tokens.append(best_pair[0] + best_pair[1])
|
||||
i += 2
|
||||
else:
|
||||
new_tokens.append(tokens[i])
|
||||
i += 1
|
||||
tokens = new_tokens
|
||||
|
||||
# Обновление словаря
|
||||
vocab.add(best_pair[0] + best_pair[1])
|
||||
|
||||
return vocab
|
||||
```
|
||||
|
||||
## Пример работы
|
||||
|
||||
**Исходный текст**: "мама мыла раму"
|
||||
|
||||
**Итерация 1**:
|
||||
- Пара ('м','а') встречается 2 раза
|
||||
- Новый токен: 'ма'
|
||||
- Текст: ['ма', 'ма', ' ', 'м', 'ы', 'л', 'а', ' ', 'р', 'а', 'м', 'у']
|
||||
|
||||
**Итерация 2**:
|
||||
- Пара ('ма',' ') встречается 1 раз
|
||||
- Новый токен: 'ма '
|
||||
- Текст: ['ма ', 'ма', 'мы', 'л', 'а', ' ', 'р', 'а', 'м', 'у']
|
||||
|
||||
**Результирующий словарь** (частично):
|
||||
['м', 'а', ' ', 'ы', 'л', 'р', 'у', 'ма', 'ма ', 'мы']
|
||||
|
||||
## Применение в языковых моделях
|
||||
|
||||
1. Эффективное представление редких слов
|
||||
2. Снижение размерности входных данных
|
||||
3. Возможность обработки OOV (Out-of-Vocabulary) слов
|
||||
|
||||
## Ограничения
|
||||
|
||||
1. Чувствительность к регистру (можно решить предварительной нормализацией)
|
||||
2. Зависимость от обучающего корпуса
|
||||
3. Не всегда выделяет лингвистически осмысленные морфемы
|
||||
|
||||
## Дополнительные материалы
|
||||
|
||||
1. [Original BPE paper](https://arxiv.org/abs/1508.07909)
|
||||
2. [BPE in HuggingFace](https://huggingface.co/docs/transformers/tokenizer_summary)
|
||||
3. [Practical guide to BPE](https://towardsdatascience.com/byte-pair-encoding-subword-based-tokenization-algorithm-77828a70bee0)
|
||||
0
example/__init__.py
Normal file
0
example/__init__.py
Normal file
60
example/example_bpe.py
Normal file
60
example/example_bpe.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from simple_llm.tokenizer.simple_bpe import SimpleBPE
|
||||
from simple_llm.tokenizer.optimize_bpe import OptimizeBPE
|
||||
import time
|
||||
|
||||
def tokenize_manually(text, vocab):
|
||||
"""Простая ручная токенизация по словарю"""
|
||||
tokens = []
|
||||
i = 0
|
||||
n = len(text)
|
||||
while i < n:
|
||||
found = False
|
||||
# Ищем самый длинный возможный токен из словаря
|
||||
for l in range(min(4, n-i), 0, -1): # проверяем токены длиной до 4 символов
|
||||
if text[i:i+l] in vocab:
|
||||
tokens.append(text[i:i+l])
|
||||
i += l
|
||||
found = True
|
||||
break
|
||||
if not found: # если токен не найден, берем один символ
|
||||
tokens.append(text[i])
|
||||
i += 1
|
||||
return tokens
|
||||
|
||||
def run_example(text, vocab_size=30):
|
||||
print("\n=== Тестирование токенизаторов ===")
|
||||
print(f"Исходный текст: '{text}'\n")
|
||||
|
||||
# Simple BPE
|
||||
start = time.time()
|
||||
simple_bpe = SimpleBPE(vocab_size=vocab_size)
|
||||
simple_bpe.fit(text)
|
||||
simple_time = time.time() - start
|
||||
|
||||
print("SimpleBPE:")
|
||||
print(f"Время обучения: {simple_time:.4f} сек")
|
||||
print(f"Размер словаря: {len(simple_bpe.vocab)}")
|
||||
print(f"Словарь: {simple_bpe.vocab}")
|
||||
print(f"Ручная токенизация: {tokenize_manually(text, simple_bpe.vocab)}\n")
|
||||
|
||||
# Optimize BPE
|
||||
start = time.time()
|
||||
opt_bpe = OptimizeBPE(vocab_size=vocab_size)
|
||||
opt_bpe.fit(text)
|
||||
opt_time = time.time() - start
|
||||
|
||||
print("OptimizeBPE:")
|
||||
print(f"Время обучения: {opt_time:.4f} сек")
|
||||
print(f"Размер словаря: {len(opt_bpe.vocab)}")
|
||||
print(f"Словарь: {opt_bpe.vocab}")
|
||||
print(f"Ручная токенизация: {tokenize_manually(text, opt_bpe.vocab)}\n")
|
||||
|
||||
if opt_time > 0:
|
||||
print(f"Оптимизированная версия быстрее в {simple_time/opt_time:.1f} раз\n")
|
||||
|
||||
if __name__ == "__main__":
|
||||
text1 = "мама мыла раму, папа пил какао"
|
||||
text2 = "коты бегают быстро, собаки лают громко"
|
||||
|
||||
run_example(text1)
|
||||
run_example(text2)
|
||||
41
pyproject.toml
Normal file
41
pyproject.toml
Normal file
@@ -0,0 +1,41 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "simple-llm"
|
||||
version = "0.1.0"
|
||||
description = "Simple BPE tokenizer implementation for educational purposes"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "Sergey Penkovsky", email = "sergey.penkovsky@gmail.com" },
|
||||
]
|
||||
license = { text = "MIT" }
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Education",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
]
|
||||
requires-python = ">=3.8"
|
||||
dependencies = []
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/pese-git/simple-llm"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["."]
|
||||
include = ["simple_llm*"]
|
||||
exclude = ["tests*", "example*"]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=7.0",
|
||||
"black>=23.0",
|
||||
]
|
||||
0
simple_llm/__init__.py
Normal file
0
simple_llm/__init__.py
Normal file
0
simple_llm/tokenizer/__init__.py
Normal file
0
simple_llm/tokenizer/__init__.py
Normal file
39
simple_llm/tokenizer/bpe_interface.py
Normal file
39
simple_llm/tokenizer/bpe_interface.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict
|
||||
|
||||
class BPE(ABC):
|
||||
"""
|
||||
Реализация алгоритма токенизации Byte Pair Encoding (BPE).
|
||||
|
||||
BPE — это итеративный алгоритм, последовательно объединяющий наиболее частые пары символов/токенов,
|
||||
чтобы построить эффективный словарь для работы с текстом: токенизации, обучения языковой модели и т.п.
|
||||
|
||||
Аргументы конструктора:
|
||||
vocab_size (int): Желаемый размер итогового словаря токенов (включая отдельные символы и составные токены).
|
||||
|
||||
Атрибуты:
|
||||
vocab (List[str]): Список токенов в порядке их получения (сначала символы, затем новые пары).
|
||||
token2id (Dict[str, int]): Словарь преобразования токена в его индекс.
|
||||
id2token (Dict[int, str]): Обратный словарь преобразования индекса в токен.
|
||||
"""
|
||||
def __init__(self, vocab_size: int):
|
||||
"""
|
||||
Инициализация BPE токенизатора.
|
||||
|
||||
Args:
|
||||
vocab_size (int): Размер словаря, к которому будет расширяться BPE.
|
||||
"""
|
||||
self.vocab_size = vocab_size
|
||||
self.vocab: List[str] = []
|
||||
self.token2id: Dict[str, int] = {}
|
||||
self.id2token: Dict[int, str] = {}
|
||||
|
||||
@abstractmethod
|
||||
def fit(self, text: str):
|
||||
pass
|
||||
|
||||
def encode(self, text: str):
|
||||
raise NotImplementedError("Implement in subclass if needed.")
|
||||
|
||||
def decode(self, ids: list[int]):
|
||||
raise NotImplementedError("Implement in subclass if needed.")
|
||||
147
simple_llm/tokenizer/optimize_bpe.py
Normal file
147
simple_llm/tokenizer/optimize_bpe.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from .bpe_interface import BPE
|
||||
|
||||
from collections import Counter
|
||||
from typing import List, Tuple, Dict
|
||||
|
||||
class OptimizeBPE(BPE):
|
||||
|
||||
def fit(self, text: str) -> None:
|
||||
"""
|
||||
Обучает BPE-модель на предоставленном тексте.
|
||||
|
||||
Последовательно расширяет словарь за счёт объединения наиболее частых пар токенов до достижения vocab_size.
|
||||
|
||||
Args:
|
||||
text (str): Исходная строка для обучения токенизатора.
|
||||
"""
|
||||
sequence = list(text)
|
||||
self._init_vocab(sequence)
|
||||
pair_freq, pair_first_occurrence = self._get_pair_stats(sequence)
|
||||
|
||||
while len(self.vocab) < self.vocab_size and pair_freq:
|
||||
pair_to_merge = self._select_pair_to_merge(pair_freq, pair_first_occurrence)
|
||||
new_token = pair_to_merge[0] + pair_to_merge[1]
|
||||
|
||||
if new_token in self.vocab:
|
||||
# Защита от зацикливания: пара уже была добавлена как новый токен.
|
||||
del pair_freq[pair_to_merge]
|
||||
continue
|
||||
|
||||
self.vocab.append(new_token)
|
||||
sequence, pair_freq, pair_first_occurrence = self._merge_pair(
|
||||
sequence, pair_to_merge, new_token, pair_freq
|
||||
)
|
||||
|
||||
self._build_token_dicts()
|
||||
|
||||
def _init_vocab(self, sequence: List[str]) -> None:
|
||||
"""
|
||||
Формирует стартовый словарь уникальных символов из последовательности, отсортированный по символам.
|
||||
|
||||
Args:
|
||||
sequence (List[str]): Исходная последовательность символов.
|
||||
"""
|
||||
self.vocab = sorted(set(sequence))
|
||||
|
||||
def _get_pair_stats(self, sequence: List[str]) -> Tuple[Counter, Dict[Tuple[str, str], int]]:
|
||||
"""
|
||||
Вычисляет частоты появления и индексы первого появления всех пар соседних токенов в последовательности.
|
||||
|
||||
Args:
|
||||
sequence (List[str]): Текущая последовательность токенов.
|
||||
|
||||
Returns:
|
||||
Tuple[Counter, Dict[Tuple[str, str], int]]:
|
||||
- Counter по всем парам (их частоты),
|
||||
- Словарь первых индексов появления каждой пары.
|
||||
"""
|
||||
pair_freq = Counter()
|
||||
pair_first_occurrence = {}
|
||||
for i in range(len(sequence) - 1):
|
||||
pair = (sequence[i], sequence[i + 1])
|
||||
pair_freq[pair] += 1
|
||||
if pair not in pair_first_occurrence:
|
||||
pair_first_occurrence[pair] = i
|
||||
return pair_freq, pair_first_occurrence
|
||||
|
||||
def _select_pair_to_merge(self, pair_freq: Counter, pair_first_occurrence: Dict[Tuple[str, str], int]) -> Tuple[str, str]:
|
||||
"""
|
||||
Выбирает следующую пару для слияния:
|
||||
приоритет — самая частая; если таких несколько — та, которая встречается раньше других (наименьший индекс появления).
|
||||
|
||||
Args:
|
||||
pair_freq (Counter): Частоты всех пар.
|
||||
pair_first_occurrence (Dict[Tuple[str, str], int]): Индексы первых появлений каждой пары.
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: Пара для слияния (двойка токенов).
|
||||
"""
|
||||
pair_to_merge, _ = max(
|
||||
pair_freq.items(),
|
||||
key=lambda x: (x[1], -pair_first_occurrence.get(x[0], float('inf')))
|
||||
)
|
||||
return pair_to_merge
|
||||
|
||||
def _merge_pair(
|
||||
self,
|
||||
sequence: List[str],
|
||||
pair_to_merge: Tuple[str, str],
|
||||
new_token: str,
|
||||
pair_freq: Counter
|
||||
) -> Tuple[List[str], Counter, Dict[Tuple[str, str], int]]:
|
||||
"""
|
||||
Выполняет слияние заданной пары токенов в новой последовательности, корректирует частоты пар и индексы первых появлений.
|
||||
|
||||
Args:
|
||||
sequence (List[str]): Текущая последовательность токенов.
|
||||
pair_to_merge (Tuple[str, str]): Пара для слияния.
|
||||
new_token (str): Новый токен (результат слияния).
|
||||
pair_freq (Counter): Частоты текущих пар.
|
||||
|
||||
Returns:
|
||||
Tuple[List[str], Counter, Dict[Tuple[str, str], int]]:
|
||||
- Новая последовательность,
|
||||
- Обновлённые частоты пар,
|
||||
- Обновлённые индексы первых появлений пар.
|
||||
"""
|
||||
new_sequence = []
|
||||
i = 0
|
||||
pairs_to_decrement = Counter()
|
||||
pairs_to_increment = Counter()
|
||||
length = len(sequence)
|
||||
while i < length:
|
||||
if i < length - 1 and (sequence[i], sequence[i + 1]) == pair_to_merge:
|
||||
if i > 0:
|
||||
pairs_to_decrement[(sequence[i - 1], sequence[i])] += 1
|
||||
pairs_to_increment[(sequence[i - 1], new_token)] += 1
|
||||
if i + 2 < length:
|
||||
pairs_to_decrement[(sequence[i + 1], sequence[i + 2])] += 1
|
||||
pairs_to_increment[(new_token, sequence[i + 2])] += 1
|
||||
new_sequence.append(new_token)
|
||||
i += 2
|
||||
else:
|
||||
new_sequence.append(sequence[i])
|
||||
i += 1
|
||||
for pair, dec_count in pairs_to_decrement.items():
|
||||
pair_freq[pair] -= dec_count
|
||||
if pair_freq[pair] <= 0:
|
||||
del pair_freq[pair]
|
||||
for pair, inc_count in pairs_to_increment.items():
|
||||
pair_freq[pair] += inc_count
|
||||
# Пересчитываем первый индекс появления пар
|
||||
pair_first_occurrence = {}
|
||||
for idx in range(len(new_sequence) - 1):
|
||||
pair = (new_sequence[idx], new_sequence[idx + 1])
|
||||
if pair not in pair_first_occurrence:
|
||||
pair_first_occurrence[pair] = idx
|
||||
for pair in list(pair_freq.keys()):
|
||||
if pair not in pair_first_occurrence:
|
||||
del pair_freq[pair]
|
||||
return new_sequence, pair_freq, pair_first_occurrence
|
||||
|
||||
def _build_token_dicts(self) -> None:
|
||||
"""
|
||||
Формирует словари вида <токен, id> и <id, токен> по итоговому списку токенов.
|
||||
"""
|
||||
self.token2id = {token: idx for idx, token in enumerate(self.vocab)}
|
||||
self.id2token = {idx: token for idx, token in enumerate(self.vocab)}
|
||||
60
simple_llm/tokenizer/simple_bpe.py
Normal file
60
simple_llm/tokenizer/simple_bpe.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from .bpe_interface import BPE
|
||||
|
||||
class SimpleBPE(BPE):
|
||||
def fit(self, text: str):
|
||||
# 1. Получаем уникальные токены (символы)
|
||||
unique_tokens = sorted(set(text))
|
||||
tokens = unique_tokens.copy()
|
||||
|
||||
# 2. Разбиваем текст на токены-символы
|
||||
sequence = list(text)
|
||||
|
||||
# 3. Объединяем токены до достижения нужного размера словаря
|
||||
while len(tokens) < self.vocab_size:
|
||||
#print(f'len={len(tokens)} < {self.vocab_size}')
|
||||
# Считаем частоты пар
|
||||
pair_freq = {}
|
||||
for i in range(len(sequence) - 1):
|
||||
pair = (sequence[i], sequence[i + 1])
|
||||
#print(f'pair = {pair}')
|
||||
if pair not in pair_freq:
|
||||
pair_freq[pair] = 0
|
||||
pair_freq[pair] += 1
|
||||
|
||||
|
||||
#print(f'pair_freq = {pair_freq}')
|
||||
if not pair_freq:
|
||||
break # нет пар — выходим
|
||||
|
||||
# Находим самую частую пару (в случае равенства — та, что встретилась первой)
|
||||
most_frequent_pair = max(pair_freq.items(), key=lambda x: (x[1], -self._pair_first_index(sequence, x[0])))[0]
|
||||
#print(most_frequent_pair)
|
||||
# Создаем новый токен
|
||||
new_token = most_frequent_pair[0] + most_frequent_pair[1]
|
||||
#print(f"new token={new_token}")
|
||||
tokens.append(new_token)
|
||||
#print(f"tokens={tokens}")
|
||||
|
||||
i = 0
|
||||
new_sequence = []
|
||||
|
||||
while i < len(sequence):
|
||||
if i < len(sequence) - 1 and (sequence[i], sequence[i + 1]) == most_frequent_pair:
|
||||
new_sequence.append(new_token)
|
||||
i += 2 # пропускаем два символа — заменённую пару
|
||||
else:
|
||||
new_sequence.append(sequence[i])
|
||||
i += 1
|
||||
sequence = new_sequence
|
||||
#break
|
||||
|
||||
# 4. Создаем словари
|
||||
self.vocab = tokens.copy()
|
||||
self.token2id = dict(zip(tokens, range(self.vocab_size)))
|
||||
self.id2token = dict(zip(range(self.vocab_size), tokens))
|
||||
|
||||
def _pair_first_index(self, sequence, pair):
|
||||
for i in range(len(sequence) - 1):
|
||||
if (sequence[i], sequence[i + 1]) == pair:
|
||||
return i
|
||||
return float('inf')
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
13
tests/conftest.py
Normal file
13
tests/conftest.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import pytest
|
||||
from simple_llm.tokenizer.simple_bpe import SimpleBPE
|
||||
from simple_llm.tokenizer.optimize_bpe import OptimizeBPE
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def large_text():
|
||||
"""Генерирует большой текст для тестирования"""
|
||||
return " ".join(["мама мыла раму"] * 1000)
|
||||
|
||||
@pytest.fixture(params=[SimpleBPE, OptimizeBPE])
|
||||
def bpe_class(request):
|
||||
"""Возвращает классы BPE для тестирования"""
|
||||
return request.param
|
||||
35
tests/integration/test_bpe_integration.py
Normal file
35
tests/integration/test_bpe_integration.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import pytest
|
||||
from simple_llm.tokenizer.simple_bpe import SimpleBPE
|
||||
from simple_llm.tokenizer.optimize_bpe import OptimizeBPE
|
||||
|
||||
def test_large_text_processing(bpe_class, large_text):
|
||||
"""Тест обработки большого текста"""
|
||||
bpe = bpe_class(vocab_size=100)
|
||||
bpe.fit(large_text)
|
||||
|
||||
# Проверки
|
||||
assert 50 < len(bpe.vocab) <= 100
|
||||
assert all(len(token) <= 4 for token in bpe.vocab) # Проверка на разумную длину токенов
|
||||
assert "мама" in bpe.vocab or "ма" in bpe.vocab # Проверка на наличие ожидаемых токенов
|
||||
|
||||
def test_special_characters(bpe_class):
|
||||
"""Тест обработки специальных символов"""
|
||||
text = "!@#$%^&*()_+1234567890"
|
||||
bpe = bpe_class(vocab_size=30)
|
||||
bpe.fit(text)
|
||||
|
||||
# Проверки
|
||||
assert len(bpe.vocab) > 10
|
||||
for char in set(text):
|
||||
assert any(char in token for token in bpe.vocab) # Каждый символ должен быть в каком-то токене
|
||||
|
||||
def test_unicode_characters(bpe_class):
|
||||
"""Тест обработки unicode-символов"""
|
||||
text = "日本語 한국어 русский English"
|
||||
bpe = bpe_class(vocab_size=50)
|
||||
bpe.fit(text)
|
||||
|
||||
# Проверки
|
||||
assert len(bpe.vocab) > 20
|
||||
assert any("日" in token for token in bpe.vocab)
|
||||
assert any("한" in token for token in bpe.vocab)
|
||||
54
tests/test_bpe.py
Normal file
54
tests/test_bpe.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import pytest
|
||||
from simple_llm.tokenizer.simple_bpe import SimpleBPE
|
||||
from simple_llm.tokenizer.optimize_bpe import OptimizeBPE
|
||||
|
||||
class TestBPE:
|
||||
@pytest.fixture(params=[SimpleBPE, OptimizeBPE])
|
||||
def bpe_class(self, request):
|
||||
return request.param
|
||||
|
||||
def test_initialization(self, bpe_class):
|
||||
"""Тест инициализации BPE-токенизатора"""
|
||||
bpe = bpe_class(vocab_size=100)
|
||||
assert bpe.vocab_size == 100
|
||||
assert bpe.vocab == []
|
||||
assert bpe.token2id == {}
|
||||
assert bpe.id2token == {}
|
||||
|
||||
def test_fit_simple_text(self, bpe_class):
|
||||
"""Тест обучения на простом тексте"""
|
||||
text = "мама мыла раму"
|
||||
bpe = bpe_class(vocab_size=20)
|
||||
bpe.fit(text)
|
||||
|
||||
# Проверки словаря
|
||||
assert isinstance(bpe.vocab, list)
|
||||
assert len(bpe.vocab) > 0
|
||||
assert len(bpe.vocab) <= 20
|
||||
assert all(isinstance(token, str) for token in bpe.vocab)
|
||||
|
||||
# Проверка словарей
|
||||
assert len(bpe.vocab) == len(bpe.token2id)
|
||||
assert len(bpe.vocab) == len(bpe.id2token)
|
||||
|
||||
# Проверка соответствия токенов и ID
|
||||
for token in bpe.vocab:
|
||||
assert bpe.token2id[token] == bpe.vocab.index(token)
|
||||
assert bpe.id2token[bpe.token2id[token]] == token
|
||||
|
||||
@pytest.mark.parametrize("text,expected_size", [
|
||||
("", 0),
|
||||
("а", 1),
|
||||
("ааааа", 2) # Должны быть 'а' и 'аа'
|
||||
])
|
||||
def test_edge_cases(self, bpe_class, text, expected_size):
|
||||
"""Тест граничных случаев"""
|
||||
bpe = bpe_class(vocab_size=10)
|
||||
bpe.fit(text)
|
||||
assert len(bpe.vocab) == expected_size
|
||||
|
||||
def test_duplicate_protection(self, bpe_class):
|
||||
"""Тест защиты от дубликатов токенов"""
|
||||
bpe = bpe_class(vocab_size=50)
|
||||
bpe.fit("аааааааааа" * 100) # Много повторений
|
||||
assert len(bpe.vocab) == len(set(bpe.vocab))
|
||||
Reference in New Issue
Block a user