mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
feat: implement bpe algorithm
This commit is contained in:
0
simple_llm/__init__.py
Normal file
0
simple_llm/__init__.py
Normal file
0
simple_llm/tokenizer/__init__.py
Normal file
0
simple_llm/tokenizer/__init__.py
Normal file
39
simple_llm/tokenizer/bpe_interface.py
Normal file
39
simple_llm/tokenizer/bpe_interface.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict
|
||||
|
||||
class BPE(ABC):
|
||||
"""
|
||||
Реализация алгоритма токенизации Byte Pair Encoding (BPE).
|
||||
|
||||
BPE — это итеративный алгоритм, последовательно объединяющий наиболее частые пары символов/токенов,
|
||||
чтобы построить эффективный словарь для работы с текстом: токенизации, обучения языковой модели и т.п.
|
||||
|
||||
Аргументы конструктора:
|
||||
vocab_size (int): Желаемый размер итогового словаря токенов (включая отдельные символы и составные токены).
|
||||
|
||||
Атрибуты:
|
||||
vocab (List[str]): Список токенов в порядке их получения (сначала символы, затем новые пары).
|
||||
token2id (Dict[str, int]): Словарь преобразования токена в его индекс.
|
||||
id2token (Dict[int, str]): Обратный словарь преобразования индекса в токен.
|
||||
"""
|
||||
def __init__(self, vocab_size: int):
|
||||
"""
|
||||
Инициализация BPE токенизатора.
|
||||
|
||||
Args:
|
||||
vocab_size (int): Размер словаря, к которому будет расширяться BPE.
|
||||
"""
|
||||
self.vocab_size = vocab_size
|
||||
self.vocab: List[str] = []
|
||||
self.token2id: Dict[str, int] = {}
|
||||
self.id2token: Dict[int, str] = {}
|
||||
|
||||
@abstractmethod
|
||||
def fit(self, text: str):
|
||||
pass
|
||||
|
||||
def encode(self, text: str):
|
||||
raise NotImplementedError("Implement in subclass if needed.")
|
||||
|
||||
def decode(self, ids: list[int]):
|
||||
raise NotImplementedError("Implement in subclass if needed.")
|
||||
147
simple_llm/tokenizer/optimize_bpe.py
Normal file
147
simple_llm/tokenizer/optimize_bpe.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from .bpe_interface import BPE
|
||||
|
||||
from collections import Counter
|
||||
from typing import List, Tuple, Dict
|
||||
|
||||
class OptimizeBPE(BPE):
|
||||
|
||||
def fit(self, text: str) -> None:
|
||||
"""
|
||||
Обучает BPE-модель на предоставленном тексте.
|
||||
|
||||
Последовательно расширяет словарь за счёт объединения наиболее частых пар токенов до достижения vocab_size.
|
||||
|
||||
Args:
|
||||
text (str): Исходная строка для обучения токенизатора.
|
||||
"""
|
||||
sequence = list(text)
|
||||
self._init_vocab(sequence)
|
||||
pair_freq, pair_first_occurrence = self._get_pair_stats(sequence)
|
||||
|
||||
while len(self.vocab) < self.vocab_size and pair_freq:
|
||||
pair_to_merge = self._select_pair_to_merge(pair_freq, pair_first_occurrence)
|
||||
new_token = pair_to_merge[0] + pair_to_merge[1]
|
||||
|
||||
if new_token in self.vocab:
|
||||
# Защита от зацикливания: пара уже была добавлена как новый токен.
|
||||
del pair_freq[pair_to_merge]
|
||||
continue
|
||||
|
||||
self.vocab.append(new_token)
|
||||
sequence, pair_freq, pair_first_occurrence = self._merge_pair(
|
||||
sequence, pair_to_merge, new_token, pair_freq
|
||||
)
|
||||
|
||||
self._build_token_dicts()
|
||||
|
||||
def _init_vocab(self, sequence: List[str]) -> None:
|
||||
"""
|
||||
Формирует стартовый словарь уникальных символов из последовательности, отсортированный по символам.
|
||||
|
||||
Args:
|
||||
sequence (List[str]): Исходная последовательность символов.
|
||||
"""
|
||||
self.vocab = sorted(set(sequence))
|
||||
|
||||
def _get_pair_stats(self, sequence: List[str]) -> Tuple[Counter, Dict[Tuple[str, str], int]]:
|
||||
"""
|
||||
Вычисляет частоты появления и индексы первого появления всех пар соседних токенов в последовательности.
|
||||
|
||||
Args:
|
||||
sequence (List[str]): Текущая последовательность токенов.
|
||||
|
||||
Returns:
|
||||
Tuple[Counter, Dict[Tuple[str, str], int]]:
|
||||
- Counter по всем парам (их частоты),
|
||||
- Словарь первых индексов появления каждой пары.
|
||||
"""
|
||||
pair_freq = Counter()
|
||||
pair_first_occurrence = {}
|
||||
for i in range(len(sequence) - 1):
|
||||
pair = (sequence[i], sequence[i + 1])
|
||||
pair_freq[pair] += 1
|
||||
if pair not in pair_first_occurrence:
|
||||
pair_first_occurrence[pair] = i
|
||||
return pair_freq, pair_first_occurrence
|
||||
|
||||
def _select_pair_to_merge(self, pair_freq: Counter, pair_first_occurrence: Dict[Tuple[str, str], int]) -> Tuple[str, str]:
|
||||
"""
|
||||
Выбирает следующую пару для слияния:
|
||||
приоритет — самая частая; если таких несколько — та, которая встречается раньше других (наименьший индекс появления).
|
||||
|
||||
Args:
|
||||
pair_freq (Counter): Частоты всех пар.
|
||||
pair_first_occurrence (Dict[Tuple[str, str], int]): Индексы первых появлений каждой пары.
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: Пара для слияния (двойка токенов).
|
||||
"""
|
||||
pair_to_merge, _ = max(
|
||||
pair_freq.items(),
|
||||
key=lambda x: (x[1], -pair_first_occurrence.get(x[0], float('inf')))
|
||||
)
|
||||
return pair_to_merge
|
||||
|
||||
def _merge_pair(
|
||||
self,
|
||||
sequence: List[str],
|
||||
pair_to_merge: Tuple[str, str],
|
||||
new_token: str,
|
||||
pair_freq: Counter
|
||||
) -> Tuple[List[str], Counter, Dict[Tuple[str, str], int]]:
|
||||
"""
|
||||
Выполняет слияние заданной пары токенов в новой последовательности, корректирует частоты пар и индексы первых появлений.
|
||||
|
||||
Args:
|
||||
sequence (List[str]): Текущая последовательность токенов.
|
||||
pair_to_merge (Tuple[str, str]): Пара для слияния.
|
||||
new_token (str): Новый токен (результат слияния).
|
||||
pair_freq (Counter): Частоты текущих пар.
|
||||
|
||||
Returns:
|
||||
Tuple[List[str], Counter, Dict[Tuple[str, str], int]]:
|
||||
- Новая последовательность,
|
||||
- Обновлённые частоты пар,
|
||||
- Обновлённые индексы первых появлений пар.
|
||||
"""
|
||||
new_sequence = []
|
||||
i = 0
|
||||
pairs_to_decrement = Counter()
|
||||
pairs_to_increment = Counter()
|
||||
length = len(sequence)
|
||||
while i < length:
|
||||
if i < length - 1 and (sequence[i], sequence[i + 1]) == pair_to_merge:
|
||||
if i > 0:
|
||||
pairs_to_decrement[(sequence[i - 1], sequence[i])] += 1
|
||||
pairs_to_increment[(sequence[i - 1], new_token)] += 1
|
||||
if i + 2 < length:
|
||||
pairs_to_decrement[(sequence[i + 1], sequence[i + 2])] += 1
|
||||
pairs_to_increment[(new_token, sequence[i + 2])] += 1
|
||||
new_sequence.append(new_token)
|
||||
i += 2
|
||||
else:
|
||||
new_sequence.append(sequence[i])
|
||||
i += 1
|
||||
for pair, dec_count in pairs_to_decrement.items():
|
||||
pair_freq[pair] -= dec_count
|
||||
if pair_freq[pair] <= 0:
|
||||
del pair_freq[pair]
|
||||
for pair, inc_count in pairs_to_increment.items():
|
||||
pair_freq[pair] += inc_count
|
||||
# Пересчитываем первый индекс появления пар
|
||||
pair_first_occurrence = {}
|
||||
for idx in range(len(new_sequence) - 1):
|
||||
pair = (new_sequence[idx], new_sequence[idx + 1])
|
||||
if pair not in pair_first_occurrence:
|
||||
pair_first_occurrence[pair] = idx
|
||||
for pair in list(pair_freq.keys()):
|
||||
if pair not in pair_first_occurrence:
|
||||
del pair_freq[pair]
|
||||
return new_sequence, pair_freq, pair_first_occurrence
|
||||
|
||||
def _build_token_dicts(self) -> None:
|
||||
"""
|
||||
Формирует словари вида <токен, id> и <id, токен> по итоговому списку токенов.
|
||||
"""
|
||||
self.token2id = {token: idx for idx, token in enumerate(self.vocab)}
|
||||
self.id2token = {idx: token for idx, token in enumerate(self.vocab)}
|
||||
60
simple_llm/tokenizer/simple_bpe.py
Normal file
60
simple_llm/tokenizer/simple_bpe.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from .bpe_interface import BPE
|
||||
|
||||
class SimpleBPE(BPE):
|
||||
def fit(self, text: str):
|
||||
# 1. Получаем уникальные токены (символы)
|
||||
unique_tokens = sorted(set(text))
|
||||
tokens = unique_tokens.copy()
|
||||
|
||||
# 2. Разбиваем текст на токены-символы
|
||||
sequence = list(text)
|
||||
|
||||
# 3. Объединяем токены до достижения нужного размера словаря
|
||||
while len(tokens) < self.vocab_size:
|
||||
#print(f'len={len(tokens)} < {self.vocab_size}')
|
||||
# Считаем частоты пар
|
||||
pair_freq = {}
|
||||
for i in range(len(sequence) - 1):
|
||||
pair = (sequence[i], sequence[i + 1])
|
||||
#print(f'pair = {pair}')
|
||||
if pair not in pair_freq:
|
||||
pair_freq[pair] = 0
|
||||
pair_freq[pair] += 1
|
||||
|
||||
|
||||
#print(f'pair_freq = {pair_freq}')
|
||||
if not pair_freq:
|
||||
break # нет пар — выходим
|
||||
|
||||
# Находим самую частую пару (в случае равенства — та, что встретилась первой)
|
||||
most_frequent_pair = max(pair_freq.items(), key=lambda x: (x[1], -self._pair_first_index(sequence, x[0])))[0]
|
||||
#print(most_frequent_pair)
|
||||
# Создаем новый токен
|
||||
new_token = most_frequent_pair[0] + most_frequent_pair[1]
|
||||
#print(f"new token={new_token}")
|
||||
tokens.append(new_token)
|
||||
#print(f"tokens={tokens}")
|
||||
|
||||
i = 0
|
||||
new_sequence = []
|
||||
|
||||
while i < len(sequence):
|
||||
if i < len(sequence) - 1 and (sequence[i], sequence[i + 1]) == most_frequent_pair:
|
||||
new_sequence.append(new_token)
|
||||
i += 2 # пропускаем два символа — заменённую пару
|
||||
else:
|
||||
new_sequence.append(sequence[i])
|
||||
i += 1
|
||||
sequence = new_sequence
|
||||
#break
|
||||
|
||||
# 4. Создаем словари
|
||||
self.vocab = tokens.copy()
|
||||
self.token2id = dict(zip(tokens, range(self.vocab_size)))
|
||||
self.id2token = dict(zip(range(self.vocab_size), tokens))
|
||||
|
||||
def _pair_first_index(self, sequence, pair):
|
||||
for i in range(len(sequence) - 1):
|
||||
if (sequence[i], sequence[i + 1]) == pair:
|
||||
return i
|
||||
return float('inf')
|
||||
Reference in New Issue
Block a user