feat: implement encode/decode methods

This commit is contained in:
Sergey Penkovsky
2025-07-12 11:48:34 +03:00
parent 3de32b2856
commit 6d746a960e

View File

@@ -33,7 +33,55 @@ class BPE(ABC):
pass pass
def encode(self, text: str): def encode(self, text: str):
raise NotImplementedError("Implement in subclass if needed.") # 1. Разбиваем текст на токены-символы
sequence = list(text)
# 2. Инициализация пустого списка токенов
tokens = []
# 3. Установить i = 0
i = 0
while i < len(text):
# 3.1 Найти все токены в словаре, начинающиеся с text[i]
start_char = text[i]
result = [token for token in self.vocab if token.startswith(start_char)]
# 3.2 Выбрать самый длинный подходящий токен
find_token = self._find_max_matching_token(text[i:], result)
if find_token is None:
# Обработка неизвестного символа
tokens.append(text[i]) # Добавляем сам символ как токен
i += 1
else:
# 3.3 Добавить токен в результат
tokens.append(find_token)
# 3.4 Увеличить i на длину токена
i += len(find_token)
# 4. Заменить токены на их ID
return self._tokens_to_ids(tokens)
def _find_max_matching_token(self, text: str, tokens: list):
"""Находит самый длинный токен из списка, с которого начинается текст"""
matching = [token for token in tokens if text.startswith(token)]
return max(matching, key=len) if matching else None
def _tokens_to_ids(self, tokens):
"""Конвертирует список токенов в их ID с обработкой неизвестных токенов"""
ids = []
for token in tokens:
if token in self.token2id:
ids.append(self.token2id[token])
else:
ids.append(-1) # Специальное значение
return ids
def decode(self, ids: list[int]): def decode(self, ids: list[int]):
raise NotImplementedError("Implement in subclass if needed.") return ''.join(self._ids_to_tokens(ids))
def _ids_to_tokens(self, ids: list) -> list:
"""Конвертирует список Ids в их tokens"""
tokens = []
for id in ids:
if id in self.id2token:
tokens.append(self.id2token[id])
else:
tokens.append('') # Специальное значение
return tokens