doc(datasets): update docstrings and tests

This commit is contained in:
Sergey Penkovsky
2025-10-17 10:49:45 +03:00
parent 38c271ca3c
commit 613d784565
10 changed files with 563 additions and 177 deletions

View File

@@ -16,7 +16,7 @@ import torch
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from llm.tokenizers import BPETokenizer
from llm.training.dataset import TextDataset
from llm.datasets.text_dataset import TextDataset
from llm.training.trainer import Trainer
from shared.data import (

View File

View File

@@ -0,0 +1,120 @@
import torch
from torch.utils.data import Dataset
from typing import List, Any
class StreamingTextDataset(Dataset):
"""
StreamingTextDataset — потоковый датасет для LLM/NLP на базе списка строк.
Назначение:
-----------
- Позволяет эффективно обрабатывать большие текстовые выборки, итерируя по заранее подготовленному списку строк.
- При итерации строки токенизируются на лету, превращаются в примеры фиксированной длины block_size (padding/truncation внутри класса).
- Поддерживает стандартный DataLoader PyTorch.
Ключевые особенности:
---------------------
- Не требует загрузки всей коллекции токенов в RAM: поддерживает работу с любым размером датасета, если список строк заранее подготовлен.
- Каждый пример (sample) формируется при обращении; не хранит массив батчей, не использует файлы внутри.
- Поддерживает любой токенизатор с методом encode (например, BPE, SentencePiece, HF Tokenizer).
- batch_size и параллелизм (num_workers) контролируются через DataLoader.
Аргументы конструктора:
-----------------------
texts: List[str] — список строк (предварительно загруженных обучающих примеров).
tokenizer: BaseTokenizer/Any — объект с методом encode(str, **kwargs) -> List[int].
block_size: int — длина одного выходного примера в токенах (padding/truncation если нужно).
Пример использования:
---------------------
>>> texts = open("wiki_sample.txt", encoding="utf-8").read().splitlines()
>>> ds = StreamingTextDataset(texts, tokenizer=tokenizer, block_size=512)
>>> loader = torch.utils.data.DataLoader(ds, batch_size=8)
>>> for batch in loader:
... print(batch['input_ids'].shape) # torch.Size([8, 512])
Особенности:
------------
- Проектирован для бесконечного стриминга текстовых данных из больших коллекций.
- При batch_size > 1 каждый batch формируется DataLoader-ом из yield'ов этого датасета.
- Не работает с файлами напрямую, только со строками (списком).
- Подходит для обучения LLM, тестирования, дообучения, оценки на больших потоковых данных.
References:
-----------
- PyTorch IterableDataset: https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset
- HuggingFace streaming datasets: https://huggingface.co/docs/datasets/stream
- Практика масштабного обучения LLM: https://github.com/karpathy/nanoGPT/issues/182
"""
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128):
"""
Инициализация StreamingTextDataset из списка строк.
Аргументы:
texts (List[str]): Список строк — текстовые обучающие примеры; весь датасет должен помещаться в этот список.
tokenizer (Any): Токенизатор с методом encode(text, **kwargs) -> List[int].
block_size (int, по умолчанию 128): Желаемая длина токенизированного примера (padding/truncation внутри класса).
Особенности:
- Поддерживает итеративную загрузку, эффективен для больших текстовых выборок.
- Каждый пример автоматически дополняется или усекается до block_size.
- Не читает данные из файла/буфера, а только из заранее подготовленного списка строк.
Пример:
>>> ds = StreamingTextDataset(texts=all_lines, tokenizer=tokenizer, block_size=256)
>>> for ex in ds:
... print(ex['input_ids'].shape) # torch.Size([256])
"""
self.texts = texts
self.tokenizer = tokenizer
self.block_size = block_size
# Получаем pad_token_id из токенизатора
self.pad_token_id = getattr(tokenizer, "pad_token_id", 0)
def __len__(self):
"""
Возвращает количество доступных примеров в датасете.
Returns:
int: Число примеров (равно длине исходного списка строк).
"""
return len(self.texts)
def __getitem__(self, idx):
"""
Получить обработанный пример по индексу из потокового датасета.
Аргументы:
idx (int): Индекс примера в исходном списке строк.
Возвращает:
dict: Словарь с тензорами для обучения LLM:
- 'input_ids': torch.Tensor формы [block_size] — индексы токенов (padding/truncation выполнены)
- 'labels': torch.Tensor формы [block_size] — целевые метки (обычно совпадают с input_ids)
Пример:
>>> item = dataset[10]
>>> assert isinstance(item, dict)
>>> assert item['input_ids'].shape == (block_size,)
>>> assert 'labels' in item
"""
text = self.texts[idx]
# Токенизация на лету
input_ids = self.tokenizer.encode(text, add_special_tokens=False)
# Обрезаем или дополняем до нужной длины
if len(input_ids) > self.block_size:
input_ids = input_ids[: self.block_size]
else:
input_ids = input_ids + [self.pad_token_id] * (
self.block_size - len(input_ids)
)
input_ids = torch.tensor(input_ids, dtype=torch.long)
labels = input_ids.clone()
return {"input_ids": input_ids, "labels": labels}

View File

@@ -0,0 +1,112 @@
import torch
from torch.utils.data import Dataset
from typing import List, Any
class TextDataset(Dataset):
"""
TextDataset — простой датасет для подачи обучающих токенов LLM (batch-режим или по одному примеру).
Назначение:
-----------
- Хранит последовательности текста (каждую строку или пример) в виде списка строк.
- При обращении сам токенизирует строку в последовательность индексов с помощью заданного токенизатора.
- Каждый пример автоматически усекётся или будет дополнен до фиксированной длины block_size (padding — zeros).
Формат и аргументы конструктора:
-------------------------------
texts: List[str]
Список строк, каждая из которых рассматривается как отдельный обучающий пример.
tokenizer: любой объект с методом encode(str, **kwargs) → List[int]
Обеспечивает сопоставление строки списку токенов (например, BPE, HuggingFace, SentencePiece и др.).
block_size: int, по умолчанию 128
Желаемая длина выходной последовательности (padding/truncation внутри класса).
Особенности:
------------
- Класс не работает с файлами напрямую: данные передаются готовым списком строк.
- При недостаточной длине пример дополняется паддингом (нулём или другим токеном, зависит от реализации).
- Может возвращать dict с input_ids, labels и прочими ключами (см. реализацию в функции __getitem__).
Пример использования:
---------------------
>>> with open("dataset.txt", encoding="utf-8") as f:
... texts = f.read().splitlines()
>>> dataset = TextDataset(texts, tokenizer, block_size=256)
>>> from torch.utils.data import DataLoader
>>> loader = DataLoader(dataset, batch_size=4)
>>> for item in loader:
... # item['input_ids'] для обучения LLM
References:
-----------
- Torch Dataset: https://pytorch.org/docs/stable/data.html
- Примеры LLM датасетов в open-source: https://github.com/karpathy/nanoGPT/blob/master/data/shakespeare_char/tokenize.py
"""
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128):
"""
Инициализация датасета из списка строк.
Аргументы:
texts (List[str]): Список строк — каждый элемент отдельный обучающий пример.
tokenizer (Any): Токенизатор с методом encode(str, **kwargs) -> List[int].
block_size (int, по умолчанию 128): Желаемая длина результата —
длинные последовательности будут усечены, короткие — дополнены паддингом (pad_token_id или 0).
Особенности:
- Строки не фильтруются и не изменяются внутри датасета.
- Для PAD используется pad_token_id из токенизатора (если есть) либо 0.
- Dict, возвращаемый __getitem__, содержит 'input_ids' и 'labels'.
Пример:
>>> dataset = TextDataset(["hello world", "test string"], tokenizer, block_size=16)
"""
self.examples = []
self.tokenizer = tokenizer
self.block_size = block_size
for text in texts:
# Кодируем текст в токены
input_ids = tokenizer.encode(text, add_special_tokens=False)
# Обрезаем или дополняем до нужной длины
if len(input_ids) > block_size:
input_ids = input_ids[:block_size]
else:
# Дополняем pad_token_id
pad_token_id = getattr(tokenizer, "pad_token_id", 0)
input_ids = input_ids + [pad_token_id] * (block_size - len(input_ids))
self.examples.append(input_ids)
def __len__(self):
"""
Возвращает количество примеров в датасете (длина списка текстов).
Returns:
int: Число примеров в датасете.
"""
return len(self.examples)
def __getitem__(self, idx):
"""
Получить пример из датасета по индексу.
Аргументы:
idx (int): Индекс примера.
Возвращает:
dict: Словарь с тензорами токенов для модели:
- 'input_ids': torch.Tensor shape [block_size], индексы токенов для входа.
- 'labels': torch.Tensor shape [block_size], метки для LM задачи (обычно совпадают с input_ids).
Пример:
>>> item = dataset[7]
>>> assert isinstance(item, dict)
>>> assert item['input_ids'].shape == (block_size,)
>>> assert 'labels' in item
"""
input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
labels = input_ids.clone()
return {"input_ids": input_ids, "labels": labels}

View File

@@ -0,0 +1,124 @@
import torch
from torch.utils.data import Dataset
from typing import List, Any
from llm.datasets.text_dataset import TextDataset
class TextWithSpecialTokensDataset(TextDataset):
"""
TextWithSpecialTokensDataset — датасет для языковых моделей с поддержкой специальных токенов (BOS, EOS, PAD).
Назначение:
-----------
- Работает с уже готовым списком строк (не с файлом!).
- Токенизирует строки с помощью заданного токенизатора, вручную вставляет специальные токены (BOS/ EOS/ PAD).
- Обрезает или дополняет каждую последовательность до длины block_size.
Аргументы конструктора:
-----------------------
texts (List[str]): Список обучающих строк (примеров).
tokenizer (Any): Любой токенизатор с методом encode(text, **kwargs).
block_size (int, default=128): Желаемая длина примера (padding/truncation).
add_bos (bool, default=False): Если True, добавляет BOS-токен в начало каждой последовательности.
add_eos (bool, default=False): Если True, добавляет EOS-токен в конец.
Особенности:
------------
- Если pad_token_id не задан — по умолчанию паддит нулями.
- Все returned примеры — dict с 'input_ids' и 'labels' (shape == block_size).
- Обрезание/дополнение учётное: BOS/EOS не "выдавливаются" обрезкой.
- Пример вызова:
>>> texts = ["пример текста", "ещё текст"]
>>> ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=16, add_bos=True, add_eos=True)
>>> out = ds[0]
>>> assert out['input_ids'].shape == (16,)
References:
-----------
- OpenAI GPT-2 data loader: https://github.com/openai/gpt-2/blob/master/src/encode.py
- HuggingFace data docs: https://huggingface.co/docs/transformers/pad_truncation
"""
def __init__(
self,
texts: List[str],
tokenizer: Any,
block_size: int = 128,
add_bos: bool = False,
add_eos: bool = False,
):
"""
Инициализация датасета с поддержкой специальных токенов.
Args:
texts (List[str]): Список строк (все ваши обучающие примеры).
tokenizer (Any): Токенизатор с методом encode(text, **kwargs).
block_size (int): Длина выходного примера.
add_bos (bool): Добавлять ли BOS токен в начало.
add_eos (bool): Добавлять ли EOS токен в конец.
"""
self.examples = []
self.tokenizer = tokenizer
self.block_size = block_size
self.add_bos = add_bos
self.add_eos = add_eos
for text in texts:
# Кодируем с специальными токенами
input_ids = tokenizer.encode(
text, add_special_tokens=True, add_bos_token=add_bos, add_eos_token=add_eos
)
# Учитываем специальные токены при обрезке/дополнении
effective_block_size = block_size
if add_bos:
effective_block_size -= 1
if add_eos:
effective_block_size -= 1
if len(input_ids) > effective_block_size:
input_ids = input_ids[:effective_block_size]
# Добавляем специальные токены если нужно
if (
add_bos
and hasattr(tokenizer, "bos_token_id")
and tokenizer.bos_token_id is not None
):
input_ids = [tokenizer.bos_token_id] + input_ids
if (
add_eos
and hasattr(tokenizer, "eos_token_id")
and tokenizer.eos_token_id is not None
):
input_ids = input_ids + [tokenizer.eos_token_id]
# Дополняем до полной длины
pad_token_id = getattr(tokenizer, "pad_token_id", 0)
if len(input_ids) < block_size:
input_ids = input_ids + [pad_token_id] * (block_size - len(input_ids))
self.examples.append(input_ids)
def __len__(self):
"""
Возвращает количество примеров в датасете.
Returns:
int: Размер (len(self.examples)).
"""
return len(self.examples)
def __getitem__(self, idx):
"""
Получить пример с учётом специальных токенов и паддинга.
Args:
idx (int): Индекс в dataset.
Returns:
dict: {'input_ids': torch.Tensor [block_size], 'labels': torch.Tensor [block_size]}
"""
input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
labels = input_ids.clone()
return {"input_ids": input_ids, "labels": labels}

View File

@@ -107,12 +107,12 @@ class Llama(BaseModel):
) -> tuple:
"""
Прямой проход: возвращает logits (и возможно обновлённый cache) по входным токенам.
Args:
x (torch.Tensor): [batch, seq_len] — индексы токенов, shape [batch, seq_len]
use_cache (bool): использовать механизм KV cache (ускоряет autoregressive generation)
cache (list or None): предыдущий кэш, если нужен
Returns:
logits: torch.Tensor [batch, seq_len, vocab_size]
new_cache: новый кэш attention (или None)
@@ -178,25 +178,50 @@ class Llama(BaseModel):
use_cache: bool = True,
) -> torch.Tensor:
"""
Генерация текста c помощью LLaMA (autoregressive Transformer).
Поддерживается:
- greedy и вероятностное сэмплирование (top-k, top-p, temperature)
- кэш attention для ускорения генерации длинных последовательностей
Args:
x (Tensor[int]): начальная последовательность [batch, seq_len]
max_new_tokens (int): сколько новых токенов сгенерировать
do_sample (bool): использовать стохастику (True) или жадный выбор (False)
temperature (float): масштаб для softmax (важно для sampling)
top_k (int|None): ограничение на количество кандидатов (top-k sampling)
top_p (float|None): nucleus sampling
use_cache (bool): ускоряет autoregressive при длинной генерации
Returns:
output (Tensor[int]): [batch, seq_len + max_new_tokens]
Пример:
>>> prompt = tokenizer.encode('Meta AI', return_tensors="pt")
>>> generated = model.generate(prompt, max_new_tokens=30, do_sample=True)
>>> print(tokenizer.decode(generated[0]))
Авторегрессивная генерация последовательностей на основе LLaMA (greedy, temperature, top-k, top-p/nucleus, поддержка KV-кэша).
Аргументы:
x (torch.Tensor): Входной тензор с токенами shape [batch_size, seq_len].
max_new_tokens (int): Максимальное количество новых токенов для генерации.
do_sample (bool): Использовать вероятностное сэмплирование (True) или жадный режим (False, argmax).
temperature (float): Температура (сглаживание распределения вероятностей, >0; по умолчанию 1.0).
>1.0 — менее предсказуемые, более разнообразные выборки.
<1.0 — более строгие, консервативные выборки.
top_k (int, опционально): Top-k сэмплирование (ограничение выбора k самыми вероятными токенами).
top_p (float, опционально): Nucleus (top-p) sampling (срез по кумулятивной вероятности ≤ top_p, см. Holtzman et al., 2019).
use_cache (bool, по умолчанию True): Использовать KV-кэш для ускорения генерации.
Возвращает:
torch.Tensor: Последовательность токенов shape [batch_size, seq_len + max_new_tokens].
Исключения:
ValueError: Если x длиннее максимально допустимой длины (max_seq_len модели).
ValueError: Если temperature ≤ 0.
ValueError: Если одновременно заданы top_k и top_p.
ValueError: Если top_k ≤ 0.
ValueError: Если top_p не в диапазоне (0, 1].
Примеры:
>>> # Строго жадная генерация
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=False)
>>> # Вероятностная генерация с температурой
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=0.7)
>>> # Top-k sampling
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_k=50)
>>> # Top-p (nucleus)
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_p=0.92)
>>> # Комбинация температуры и top-k
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=1.0, top_k=100)
Примечания:
- temperature, top_k, top_p применяются только если do_sample=True.
- Одновременное использование top_k и top_p запрещено.
- Для воспроизводимых результатов зафиксируйте seed через torch.manual_seed.
- Возвращается только индексы токенов; для получения вероятностей используйте forward.
Ссылки:
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus/top-p): https://arxiv.org/abs/1904.09751
- LLaMA: https://arxiv.org/abs/2302.13971
"""
cache = None

View File

@@ -1,155 +0,0 @@
import torch
from torch.utils.data import Dataset
from typing import List, Any
class TextDataset(Dataset):
"""
Простой датасет для языкового моделирования (LLM).
Работает с любым токенизатором, реализующим интерфейс BaseTokenizer.
"""
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128):
"""
Инициализация датасета.
Args:
texts: Список текстов для обучения
tokenizer: Токенизатор с методами encode/decode
block_size: Максимальная длина последовательности
"""
self.examples = []
self.tokenizer = tokenizer
self.block_size = block_size
for text in texts:
# Кодируем текст в токены
input_ids = tokenizer.encode(text, add_special_tokens=False)
# Обрезаем или дополняем до нужной длины
if len(input_ids) > block_size:
input_ids = input_ids[:block_size]
else:
# Дополняем pad_token_id
pad_token_id = getattr(tokenizer, "pad_token_id", 0)
input_ids = input_ids + [pad_token_id] * (block_size - len(input_ids))
self.examples.append(input_ids)
def __len__(self):
return len(self.examples)
def __getitem__(self, idx):
input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
labels = input_ids.clone()
return {"input_ids": input_ids, "labels": labels}
class StreamingTextDataset(Dataset):
"""
Датасет для потоковой обработки больших текстов.
Токенизация происходит на лету, что экономит память.
"""
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128):
self.texts = texts
self.tokenizer = tokenizer
self.block_size = block_size
# Получаем pad_token_id из токенизатора
self.pad_token_id = getattr(tokenizer, "pad_token_id", 0)
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
text = self.texts[idx]
# Токенизация на лету
input_ids = self.tokenizer.encode(text, add_special_tokens=False)
# Обрезаем или дополняем до нужной длины
if len(input_ids) > self.block_size:
input_ids = input_ids[: self.block_size]
else:
input_ids = input_ids + [self.pad_token_id] * (
self.block_size - len(input_ids)
)
input_ids = torch.tensor(input_ids, dtype=torch.long)
labels = input_ids.clone()
return {"input_ids": input_ids, "labels": labels}
class TextDatasetWithSpecialTokens(TextDataset):
"""
Расширенная версия TextDataset с поддержкой специальных токенов.
"""
def __init__(
self,
texts: List[str],
tokenizer: Any,
block_size: int = 128,
add_bos: bool = False,
add_eos: bool = False,
):
"""
Args:
texts: Список текстов
tokenizer: Токенизатор
block_size: Максимальная длина
add_bos: Добавлять токен начала последовательности
add_eos: Добавлять токен конца последовательности
"""
self.examples = []
self.tokenizer = tokenizer
self.block_size = block_size
self.add_bos = add_bos
self.add_eos = add_eos
for text in texts:
# Кодируем с специальными токенами
input_ids = tokenizer.encode(
text, add_special_tokens=True, add_bos_token=add_bos, add_eos_token=eos
)
# Учитываем специальные токены при обрезке/дополнении
effective_block_size = block_size
if add_bos:
effective_block_size -= 1
if add_eos:
effective_block_size -= 1
if len(input_ids) > effective_block_size:
input_ids = input_ids[:effective_block_size]
# Добавляем специальные токены если нужно
if (
add_bos
and hasattr(tokenizer, "bos_token_id")
and tokenizer.bos_token_id is not None
):
input_ids = [tokenizer.bos_token_id] + input_ids
if (
add_eos
and hasattr(tokenizer, "eos_token_id")
and tokenizer.eos_token_id is not None
):
input_ids = input_ids + [tokenizer.eos_token_id]
# Дополняем до полной длины
pad_token_id = getattr(tokenizer, "pad_token_id", 0)
if len(input_ids) < block_size:
input_ids = input_ids + [pad_token_id] * (block_size - len(input_ids))
self.examples.append(input_ids)
def __len__(self):
return len(self.examples)
def __getitem__(self, idx):
input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
labels = input_ids.clone()
return {"input_ids": input_ids, "labels": labels}

View File

@@ -0,0 +1,49 @@
import torch
import pytest
from llm.datasets.streaming_text_dataset import StreamingTextDataset
class DummyTokenizer:
def __init__(self, vocab_size=100):
self.vocab_size = vocab_size
def encode(self, text, **kwargs):
return [len(w) % self.vocab_size for w in text.strip().split()]
def test_streaming_textdataset_basic_shape():
texts = ["hello world", "big transformers are fun", "LLM test string"]
tokenizer = DummyTokenizer(50)
block_size = 7
ds = StreamingTextDataset(texts, tokenizer, block_size)
assert len(ds) == 3
for i in range(len(ds)):
item = ds[i]
assert isinstance(item, dict)
assert "input_ids" in item
assert item["input_ids"].shape == (block_size,)
assert "labels" in item
assert item["labels"].shape == (block_size,)
def test_streaming_textdataset_padding_and_truncation():
texts = ["short", "one two three four five six seven eight nine ten"]
tokenizer = DummyTokenizer(40)
block_size = 4
ds = StreamingTextDataset(texts, tokenizer, block_size)
# короткое предложение padded
assert (ds[0]["input_ids"].shape[0] == block_size)
# длинное предложение truncated
assert (ds[1]["input_ids"].shape[0] == block_size)
def test_streaming_textdataset_index_error():
texts = ["sample"]
tokenizer = DummyTokenizer(10)
ds = StreamingTextDataset(texts, tokenizer, block_size=5)
with pytest.raises(IndexError):
_ = ds[1]
def test_streaming_textdataset_content_matching():
texts = ["foo bar baz", "abc def"]
tokenizer = DummyTokenizer(99)
block_size = 5
ds = StreamingTextDataset(texts, tokenizer, block_size)
# Проверка, что input_ids и labels совпадают точно
for i in range(len(ds)):
assert torch.equal(ds[i]["input_ids"], ds[i]["labels"])

View File

@@ -0,0 +1,49 @@
import torch
import pytest
from llm.datasets.text_dataset import TextDataset
class DummyTokenizer:
def __init__(self, vocab_size=100):
self.vocab_size = vocab_size
def encode(self, text, **kwargs):
return [len(w) % self.vocab_size for w in text.strip().split()]
def test_textdataset_shape_and_basic():
texts = ["hello world", "this is a test", "Transformer model"]
tokenizer = DummyTokenizer(50)
block_size = 6
dataset = TextDataset(texts, tokenizer, block_size=block_size)
for i in range(len(dataset)):
x = dataset[i]
assert isinstance(x, dict)
assert "input_ids" in x
assert isinstance(x["input_ids"], torch.Tensor)
assert x["input_ids"].shape == (block_size,)
def test_textdataset_truncation_and_padding():
texts = ["one two three four five six seven", "short"]
tokenizer = DummyTokenizer(100)
block_size = 5
dataset = TextDataset(texts, tokenizer, block_size=block_size)
assert isinstance(dataset[0], dict)
assert dataset[0]["input_ids"].shape[0] == 5
assert dataset[1]["input_ids"].shape[0] == 5
def test_textdataset_index_error():
texts = ["a", "b"]
tokenizer = DummyTokenizer(10)
dataset = TextDataset(texts, tokenizer, block_size=3)
with pytest.raises(IndexError):
_ = dataset[2]
def test_textdataset_encoding():
texts = ["привет", "мир"]
tokenizer = DummyTokenizer(20)
block_size = 4
dataset = TextDataset(texts, tokenizer, block_size=block_size)
assert len(dataset) == 2
x = dataset[0]
assert isinstance(x, dict)
assert "input_ids" in x
assert isinstance(x["input_ids"], torch.Tensor)
assert x["input_ids"].shape == (block_size,)

View File

@@ -0,0 +1,62 @@
import torch
import pytest
from llm.datasets.text_with_special_tokens_dataset import TextWithSpecialTokensDataset
class DummyTokenizer:
def __init__(self):
self.bos_token_id = 101
self.eos_token_id = 102
self.pad_token_id = 0
def encode(self, text, add_special_tokens=False, add_bos_token=False, add_eos_token=False):
ids = [ord(c) % 50 for c in text.strip()]
if add_bos_token:
ids = [self.bos_token_id] + ids
if add_eos_token:
ids = ids + [self.eos_token_id]
return ids
def test_specialtokens_basic_bos_eos():
texts = ["abc", "d"]
tokenizer = DummyTokenizer()
block_size = 6
ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=block_size, add_bos=True, add_eos=True)
for i in range(len(ds)):
item = ds[i]
assert isinstance(item, dict)
assert "input_ids" in item
assert item["input_ids"].shape == (block_size,)
assert item["input_ids"][0] == tokenizer.bos_token_id
assert item["input_ids"][item["input_ids"].ne(tokenizer.pad_token_id).sum() - 1] == tokenizer.eos_token_id
def test_specialtokens_padding_and_truncation():
texts = ["qwertyuiop", "z"]
tokenizer = DummyTokenizer()
block_size = 5
ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=block_size, add_bos=True)
assert ds[0]["input_ids"].shape[0] == block_size
assert ds[1]["input_ids"][-1] == tokenizer.pad_token_id
def test_specialtokens_no_bos_eos():
texts = ["xyz"]
tokenizer = DummyTokenizer()
block_size = 6
ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=block_size, add_bos=False, add_eos=False)
item = ds[0]["input_ids"]
assert tokenizer.bos_token_id not in item
assert tokenizer.eos_token_id not in item
assert item.shape == (block_size,)
def test_specialtokens_index_error():
texts = ["sample"]
tokenizer = DummyTokenizer()
ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=8)
with pytest.raises(IndexError):
_ = ds[1]
def test_specialtokens_labels():
texts = ["abcd"]
tokenizer = DummyTokenizer()
block_size = 7
ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=block_size, add_bos=True, add_eos=True)
item = ds[0]
assert torch.equal(item["input_ids"], item["labels"])