Mirror of https://github.com/pese-git/llm-arch-research.git, synced 2026-01-23 21:10:54 +00:00
doc(datasets): update docstrings and tests
llm/tests/datasets/test_streaming_text_dataset.py (new file, 49 lines)
@@ -0,0 +1,49 @@
import torch
import pytest
from llm.datasets.streaming_text_dataset import StreamingTextDataset

class DummyTokenizer:
    def __init__(self, vocab_size=100):
        self.vocab_size = vocab_size
    def encode(self, text, **kwargs):
        return [len(w) % self.vocab_size for w in text.strip().split()]

def test_streaming_textdataset_basic_shape():
    texts = ["hello world", "big transformers are fun", "LLM test string"]
    tokenizer = DummyTokenizer(50)
    block_size = 7
    ds = StreamingTextDataset(texts, tokenizer, block_size)
    assert len(ds) == 3
    for i in range(len(ds)):
        item = ds[i]
        assert isinstance(item, dict)
        assert "input_ids" in item
        assert item["input_ids"].shape == (block_size,)
        assert "labels" in item
        assert item["labels"].shape == (block_size,)

def test_streaming_textdataset_padding_and_truncation():
    texts = ["short", "one two three four five six seven eight nine ten"]
    tokenizer = DummyTokenizer(40)
    block_size = 4
    ds = StreamingTextDataset(texts, tokenizer, block_size)
    # the short text is padded up to block_size
    assert ds[0]["input_ids"].shape[0] == block_size
    # the long text is truncated down to block_size
    assert ds[1]["input_ids"].shape[0] == block_size

def test_streaming_textdataset_index_error():
    texts = ["sample"]
    tokenizer = DummyTokenizer(10)
    ds = StreamingTextDataset(texts, tokenizer, block_size=5)
    with pytest.raises(IndexError):
        _ = ds[1]

def test_streaming_textdataset_content_matching():
    texts = ["foo bar baz", "abc def"]
    tokenizer = DummyTokenizer(99)
    block_size = 5
    ds = StreamingTextDataset(texts, tokenizer, block_size)
    # input_ids and labels must match exactly
    for i in range(len(ds)):
        assert torch.equal(ds[i]["input_ids"], ds[i]["labels"])
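For context, a minimal consumption sketch follows. It assumes only the contract these tests assert (each item is a dict with fixed-length "input_ids" and "labels" tensors); the WordLenTokenizer, the example texts, and the batch size are illustrative stand-ins, not part of this commit.

# Sketch: feeding a StreamingTextDataset into a torch DataLoader.
# Assumes only the behavior asserted in the tests above.
from torch.utils.data import DataLoader
from llm.datasets.streaming_text_dataset import StreamingTextDataset

class WordLenTokenizer:
    """Toy whitespace tokenizer mirroring DummyTokenizer from the tests."""
    def __init__(self, vocab_size=50):
        self.vocab_size = vocab_size
    def encode(self, text, **kwargs):
        return [len(w) % self.vocab_size for w in text.strip().split()]

texts = ["hello world", "big transformers are fun", "LLM test string"]
ds = StreamingTextDataset(texts, WordLenTokenizer(50), block_size=7)

loader = DataLoader(ds, batch_size=2)  # batch_size is an arbitrary illustrative choice
for batch in loader:
    # default collation stacks per-item tensors to (batch_size, block_size); the last batch may be smaller
    print(batch["input_ids"].shape, batch["labels"].shape)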
llm/tests/datasets/test_text_dataset.py (new file, 49 lines)
@@ -0,0 +1,49 @@
import torch
import pytest
from llm.datasets.text_dataset import TextDataset

class DummyTokenizer:
    def __init__(self, vocab_size=100):
        self.vocab_size = vocab_size
    def encode(self, text, **kwargs):
        return [len(w) % self.vocab_size for w in text.strip().split()]

def test_textdataset_shape_and_basic():
    texts = ["hello world", "this is a test", "Transformer model"]
    tokenizer = DummyTokenizer(50)
    block_size = 6
    dataset = TextDataset(texts, tokenizer, block_size=block_size)
    for i in range(len(dataset)):
        x = dataset[i]
        assert isinstance(x, dict)
        assert "input_ids" in x
        assert isinstance(x["input_ids"], torch.Tensor)
        assert x["input_ids"].shape == (block_size,)

def test_textdataset_truncation_and_padding():
    texts = ["one two three four five six seven", "short"]
    tokenizer = DummyTokenizer(100)
    block_size = 5
    dataset = TextDataset(texts, tokenizer, block_size=block_size)
    assert isinstance(dataset[0], dict)
    assert dataset[0]["input_ids"].shape[0] == 5
    assert dataset[1]["input_ids"].shape[0] == 5

def test_textdataset_index_error():
    texts = ["a", "b"]
    tokenizer = DummyTokenizer(10)
    dataset = TextDataset(texts, tokenizer, block_size=3)
    with pytest.raises(IndexError):
        _ = dataset[2]

def test_textdataset_encoding():
    texts = ["привет", "мир"]  # Cyrillic input exercises non-ASCII text handling
    tokenizer = DummyTokenizer(20)
    block_size = 4
    dataset = TextDataset(texts, tokenizer, block_size=block_size)
    assert len(dataset) == 2
    x = dataset[0]
    assert isinstance(x, dict)
    assert "input_ids" in x
    assert isinstance(x["input_ids"], torch.Tensor)
    assert x["input_ids"].shape == (block_size,)
llm/tests/datasets/test_text_with_special_tokens_dataset.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import torch
import pytest
from llm.datasets.text_with_special_tokens_dataset import TextWithSpecialTokensDataset

class DummyTokenizer:
    def __init__(self):
        self.bos_token_id = 101
        self.eos_token_id = 102
        self.pad_token_id = 0
    def encode(self, text, add_special_tokens=False, add_bos_token=False, add_eos_token=False):
        ids = [ord(c) % 50 + 1 for c in text.strip()]  # +1 keeps character ids clear of pad_token_id (0)
        if add_bos_token:
            ids = [self.bos_token_id] + ids
        if add_eos_token:
            ids = ids + [self.eos_token_id]
        return ids

def test_specialtokens_basic_bos_eos():
    texts = ["abc", "d"]
    tokenizer = DummyTokenizer()
    block_size = 6
    ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=block_size, add_bos=True, add_eos=True)
    for i in range(len(ds)):
        item = ds[i]
        assert isinstance(item, dict)
        assert "input_ids" in item
        assert item["input_ids"].shape == (block_size,)
        assert item["input_ids"][0] == tokenizer.bos_token_id
        assert item["input_ids"][item["input_ids"].ne(tokenizer.pad_token_id).sum() - 1] == tokenizer.eos_token_id

def test_specialtokens_padding_and_truncation():
    texts = ["qwertyuiop", "z"]
    tokenizer = DummyTokenizer()
    block_size = 5
    ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=block_size, add_bos=True)
    assert ds[0]["input_ids"].shape[0] == block_size
    assert ds[1]["input_ids"][-1] == tokenizer.pad_token_id

def test_specialtokens_no_bos_eos():
    texts = ["xyz"]
    tokenizer = DummyTokenizer()
    block_size = 6
    ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=block_size, add_bos=False, add_eos=False)
    item = ds[0]["input_ids"]
    assert tokenizer.bos_token_id not in item
    assert tokenizer.eos_token_id not in item
    assert item.shape == (block_size,)

def test_specialtokens_index_error():
    texts = ["sample"]
    tokenizer = DummyTokenizer()
    ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=8)
    with pytest.raises(IndexError):
        _ = ds[1]

def test_specialtokens_labels():
    texts = ["abcd"]
    tokenizer = DummyTokenizer()
    block_size = 7
    ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=block_size, add_bos=True, add_eos=True)
    item = ds[0]
    assert torch.equal(item["input_ids"], item["labels"])
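As a quick recap of the contract pinned down above, the sketch below constructs the dataset with both flags enabled and checks where the special tokens land. The CharTokenizer and the sample string are illustrative assumptions that mirror the DummyTokenizer in these tests, not a real tokenizer from this repo.

# Sketch: BOS at position 0, EOS at the last non-pad slot, fixed-length output.
# Assumes only the behavior asserted in the tests above.
from llm.datasets.text_with_special_tokens_dataset import TextWithSpecialTokensDataset

class CharTokenizer:
    """Toy character-level tokenizer mirroring DummyTokenizer from the tests."""
    def __init__(self):
        self.bos_token_id = 101
        self.eos_token_id = 102
        self.pad_token_id = 0
    def encode(self, text, add_special_tokens=False, add_bos_token=False, add_eos_token=False):
        ids = [ord(c) % 50 + 1 for c in text.strip()]  # +1 keeps char ids clear of pad_token_id (0)
        if add_bos_token:
            ids = [self.bos_token_id] + ids
        if add_eos_token:
            ids = ids + [self.eos_token_id]
        return ids

tokenizer = CharTokenizer()
ds = TextWithSpecialTokensDataset(["hello"], tokenizer, block_size=10, add_bos=True, add_eos=True)

item = ds[0]["input_ids"]
n_real = item.ne(tokenizer.pad_token_id).sum()     # number of non-pad tokens
assert item[0] == tokenizer.bos_token_id           # BOS opens the block
assert item[n_real - 1] == tokenizer.eos_token_id  # EOS is the last real token; the rest is padding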