mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-24 05:21:16 +00:00
49 lines
1.7 KiB
Python
49 lines
1.7 KiB
Python
|
|
import torch
|
||
|
|
import pytest
|
||
|
|
from llm.datasets.text_dataset import TextDataset
|
||
|
|
|
||
|
|
class DummyTokenizer:
|
||
|
|
def __init__(self, vocab_size=100):
|
||
|
|
self.vocab_size = vocab_size
|
||
|
|
def encode(self, text, **kwargs):
|
||
|
|
return [len(w) % self.vocab_size for w in text.strip().split()]
|
||
|
|
|
||
|
|
def test_textdataset_shape_and_basic():
|
||
|
|
texts = ["hello world", "this is a test", "Transformer model"]
|
||
|
|
tokenizer = DummyTokenizer(50)
|
||
|
|
block_size = 6
|
||
|
|
dataset = TextDataset(texts, tokenizer, block_size=block_size)
|
||
|
|
for i in range(len(dataset)):
|
||
|
|
x = dataset[i]
|
||
|
|
assert isinstance(x, dict)
|
||
|
|
assert "input_ids" in x
|
||
|
|
assert isinstance(x["input_ids"], torch.Tensor)
|
||
|
|
assert x["input_ids"].shape == (block_size,)
|
||
|
|
|
||
|
|
def test_textdataset_truncation_and_padding():
|
||
|
|
texts = ["one two three four five six seven", "short"]
|
||
|
|
tokenizer = DummyTokenizer(100)
|
||
|
|
block_size = 5
|
||
|
|
dataset = TextDataset(texts, tokenizer, block_size=block_size)
|
||
|
|
assert isinstance(dataset[0], dict)
|
||
|
|
assert dataset[0]["input_ids"].shape[0] == 5
|
||
|
|
assert dataset[1]["input_ids"].shape[0] == 5
|
||
|
|
|
||
|
|
def test_textdataset_index_error():
|
||
|
|
texts = ["a", "b"]
|
||
|
|
tokenizer = DummyTokenizer(10)
|
||
|
|
dataset = TextDataset(texts, tokenizer, block_size=3)
|
||
|
|
with pytest.raises(IndexError):
|
||
|
|
_ = dataset[2]
|
||
|
|
|
||
|
|
def test_textdataset_encoding():
|
||
|
|
texts = ["привет", "мир"]
|
||
|
|
tokenizer = DummyTokenizer(20)
|
||
|
|
block_size = 4
|
||
|
|
dataset = TextDataset(texts, tokenizer, block_size=block_size)
|
||
|
|
assert len(dataset) == 2
|
||
|
|
x = dataset[0]
|
||
|
|
assert isinstance(x, dict)
|
||
|
|
assert "input_ids" in x
|
||
|
|
assert isinstance(x["input_ids"], torch.Tensor)
|
||
|
|
assert x["input_ids"].shape == (block_size,)
|