doc(datasets): update docstrings and tests

This commit is contained in:
Sergey Penkovsky
2025-10-17 10:49:45 +03:00
parent 38c271ca3c
commit 613d784565
10 changed files with 563 additions and 177 deletions

View File

@@ -0,0 +1,49 @@
import torch
import pytest
from llm.datasets.text_dataset import TextDataset
class DummyTokenizer:
def __init__(self, vocab_size=100):
self.vocab_size = vocab_size
def encode(self, text, **kwargs):
return [len(w) % self.vocab_size for w in text.strip().split()]
def test_textdataset_shape_and_basic():
texts = ["hello world", "this is a test", "Transformer model"]
tokenizer = DummyTokenizer(50)
block_size = 6
dataset = TextDataset(texts, tokenizer, block_size=block_size)
for i in range(len(dataset)):
x = dataset[i]
assert isinstance(x, dict)
assert "input_ids" in x
assert isinstance(x["input_ids"], torch.Tensor)
assert x["input_ids"].shape == (block_size,)
def test_textdataset_truncation_and_padding():
texts = ["one two three four five six seven", "short"]
tokenizer = DummyTokenizer(100)
block_size = 5
dataset = TextDataset(texts, tokenizer, block_size=block_size)
assert isinstance(dataset[0], dict)
assert dataset[0]["input_ids"].shape[0] == 5
assert dataset[1]["input_ids"].shape[0] == 5
def test_textdataset_index_error():
texts = ["a", "b"]
tokenizer = DummyTokenizer(10)
dataset = TextDataset(texts, tokenizer, block_size=3)
with pytest.raises(IndexError):
_ = dataset[2]
def test_textdataset_encoding():
texts = ["привет", "мир"]
tokenizer = DummyTokenizer(20)
block_size = 4
dataset = TextDataset(texts, tokenizer, block_size=block_size)
assert len(dataset) == 2
x = dataset[0]
assert isinstance(x, dict)
assert "input_ids" in x
assert isinstance(x["input_ids"], torch.Tensor)
assert x["input_ids"].shape == (block_size,)