# Mirror of https://github.com/pese-git/llm-arch-research.git
# Synced 2026-01-24 13:32:08 +00:00 (50 lines, 1.8 KiB, Python)
import torch
|
||
|
|
import pytest
|
||
|
|
from llm.datasets.streaming_text_dataset import StreamingTextDataset
|
||
|
|
|
||
|
|
class DummyTokenizer:
    """Minimal stand-in tokenizer used by the dataset tests.

    Each whitespace-separated word becomes one token id equal to the word's
    length modulo ``vocab_size``, so every id lies in ``[0, vocab_size)``.
    """

    def __init__(self, vocab_size=100):
        """Store the modulus that bounds the produced token ids.

        Args:
            vocab_size: Positive number of distinct token ids.

        Raises:
            ValueError: If ``vocab_size`` is not positive. A zero modulus
                would otherwise only surface later as ``ZeroDivisionError``
                inside ``encode``, and a negative one would yield
                non-positive ids.
        """
        if vocab_size <= 0:
            raise ValueError(f"vocab_size must be positive, got {vocab_size}")
        self.vocab_size = vocab_size

    def encode(self, text, **kwargs):
        """Return one token id per whitespace-separated word of ``text``.

        Args:
            text: String to tokenize.
            **kwargs: Accepted and ignored, so callers may pass extra
                tokenizer options without error.

        Returns:
            A list of ints, one per word; empty when ``text`` is empty or
            contains only whitespace.
        """
        # split() with no separator already drops leading/trailing whitespace,
        # so no explicit strip() is needed.
        return [len(word) % self.vocab_size for word in text.split()]
|
||
|
|
|
||
|
|
def test_streaming_textdataset_basic_shape():
    """Each sample is a dict whose ``input_ids`` and ``labels`` have shape ``(block_size,)``."""
    block_size = 7
    texts = ["hello world", "big transformers are fun", "LLM test string"]
    ds = StreamingTextDataset(texts, DummyTokenizer(50), block_size)

    # Three input texts are expected to produce three samples.
    assert len(ds) == 3

    for idx in range(len(ds)):
        sample = ds[idx]
        assert isinstance(sample, dict)
        # Both fields must be present and have exactly block_size entries.
        for key in ("input_ids", "labels"):
            assert key in sample
            assert sample[key].shape == (block_size,)
|
||
|
|
|
||
|
|
def test_streaming_textdataset_padding_and_truncation():
    """A too-short and a too-long text both yield ``input_ids`` of length ``block_size``."""
    block_size = 4
    short_text = "short"
    long_text = "one two three four five six seven eight nine ten"
    ds = StreamingTextDataset([short_text, long_text], DummyTokenizer(40), block_size)

    # The short sentence is expected to be padded up to block_size.
    short_len = ds[0]["input_ids"].shape[0]
    assert short_len == block_size

    # The long sentence is expected to be truncated down to block_size.
    long_len = ds[1]["input_ids"].shape[0]
    assert long_len == block_size
|
||
|
|
|
||
|
|
def test_streaming_textdataset_index_error():
    """Indexing past the only sample raises ``IndexError``."""
    ds = StreamingTextDataset(["sample"], DummyTokenizer(10), block_size=5)
    out_of_range = 1  # the dataset was built from a single text
    with pytest.raises(IndexError):
        _ = ds[out_of_range]
|
||
|
|
|
||
|
|
def test_streaming_textdataset_content_matching():
    """``input_ids`` and ``labels`` of every sample hold identical values."""
    block_size = 5
    ds = StreamingTextDataset(
        ["foo bar baz", "abc def"],
        DummyTokenizer(99),
        block_size,
    )

    # Check that input_ids and labels match exactly.
    for idx in range(len(ds)):
        input_ids = ds[idx]["input_ids"]
        labels = ds[idx]["labels"]
        assert torch.equal(input_ids, labels)
|