From ea9d63da3ae2cdd3d909a15866f6128a20a8a8d5 Mon Sep 17 00:00:00 2001
From: Sergey Penkovsky
Date: Fri, 18 Jul 2025 00:19:43 +0300
Subject: [PATCH] Update imports after renaming token_embedings.py →
 token_embeddings.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 example/example_token_embeddings.py |  2 +-
 tests/test_token_embeddings.py      | 47 +++++++++++++++++++++++++++++
 2 files changed, 48 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_token_embeddings.py

diff --git a/example/example_token_embeddings.py b/example/example_token_embeddings.py
index 0b80448..4539a27 100644
--- a/example/example_token_embeddings.py
+++ b/example/example_token_embeddings.py
@@ -10,7 +10,7 @@
 import torch
 import matplotlib.pyplot as plt
 from sklearn.decomposition import PCA
-from simple_llm.embedding.token_embedings import TokenEmbeddings
+from simple_llm.embedding.token_embeddings import TokenEmbeddings
 
 def basic_example():
     """Basic example of using TokenEmbeddings"""
diff --git a/tests/test_token_embeddings.py b/tests/test_token_embeddings.py
new file mode 100644
index 0000000..e28b747
--- /dev/null
+++ b/tests/test_token_embeddings.py
@@ -0,0 +1,47 @@
+import torch
+import pytest
+from simple_llm.embedding.token_embeddings import TokenEmbeddings
+
+class TestTokenEmbeddings:
+    """Unit tests for the TokenEmbeddings class"""
+
+    @pytest.fixture
+    def embedding_layer(self):
+        return TokenEmbeddings(vocab_size=100, emb_size=32)
+
+    def test_initialization(self, embedding_layer):
+        """Test layer initialization"""
+        assert isinstance(embedding_layer, torch.nn.Module)
+        assert embedding_layer._embedding.num_embeddings == 100
+        assert embedding_layer._embedding.embedding_dim == 32
+
+    def test_forward_shape(self, embedding_layer):
+        """Test output shape of the forward pass"""
+        test_input = torch.tensor([
+            [1, 2, 3],
+            [4, 5, 6]
+        ])
+        output = embedding_layer(test_input)
+        assert output.shape == (2, 3, 32)  # batch_size=2, seq_len=3, emb_size=32
+
+    def test_embedding_values(self, embedding_layer):
+        """Test that embeddings are trainable"""
+        input_tensor = torch.tensor([[1]])
+        before = embedding_layer(input_tensor).clone()
+
+        # Simulate a training step
+        optimizer = torch.optim.SGD(embedding_layer.parameters(), lr=0.1)
+        loss = embedding_layer(input_tensor).sum()
+        loss.backward()
+        optimizer.step()
+
+        after = embedding_layer(input_tensor)
+        assert not torch.allclose(before, after), "Embeddings should change after training"
+
+    def test_out_of_vocab(self, embedding_layer):
+        """Test handling of out-of-vocabulary indices"""
+        with pytest.raises(IndexError):
+            embedding_layer(torch.tensor([[100]]))  # vocab_size=100
+
+if __name__ == "__main__":
+    pytest.main(["-v", __file__])
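
Note: the renamed module simple_llm/embedding/token_embeddings.py itself is
not part of this diff. For reference, below is a minimal sketch of an
implementation that would satisfy the tests above. Only the module path, the
(vocab_size, emb_size) constructor signature, and the private _embedding
attribute are taken from the patch; the rest is an assumption about the
class, not the project's actual code.

    import torch
    import torch.nn as nn

    class TokenEmbeddings(nn.Module):
        """Maps integer token ids to dense, trainable vectors."""

        def __init__(self, vocab_size: int, emb_size: int):
            super().__init__()
            # The tests read `_embedding` directly, so the lookup table is
            # assumed to live under this attribute name.
            self._embedding = nn.Embedding(num_embeddings=vocab_size,
                                           embedding_dim=emb_size)

        def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
            # (batch_size, seq_len) integer ids -> (batch_size, seq_len, emb_size)
            return self._embedding(token_ids)

Usage mirrors the tests: TokenEmbeddings(vocab_size=100, emb_size=32) applied
to a (2, 3) tensor of ids returns a (2, 3, 32) tensor, and any index >=
vocab_size raises IndexError from the underlying nn.Embedding lookup.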