From b43e6a85f451e17b1435f7a757e4702fdf9d311c Mon Sep 17 00:00:00 2001 From: Sergey Penkovsky Date: Fri, 18 Jul 2025 00:24:00 +0300 Subject: [PATCH] =?UTF-8?q?=D0=9E=D0=B1=D0=BD=D0=BE=D0=B2=D0=BB=D0=B5?= =?UTF-8?q?=D0=BD=D0=B8=D0=B5=20=D1=82=D0=B5=D1=81=D1=82=D0=BE=D0=B2=20?= =?UTF-8?q?=D0=BF=D0=BE=D1=81=D0=BB=D0=B5=20=D0=BF=D0=B5=D1=80=D0=B5=D0=B8?= =?UTF-8?q?=D0=BC=D0=B5=D0=BD=D0=BE=D0=B2=D0=B0=D0=BD=D0=B8=D1=8F=20=D0=BC?= =?UTF-8?q?=D0=BE=D0=B4=D1=83=D0=BB=D1=8F=20=D1=8D=D0=BC=D0=B1=D0=B5=D0=B4?= =?UTF-8?q?=D0=B4=D0=B8=D0=BD=D0=B3=D0=BE=D0=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Исправлены импорты в test_token_embeddings.py - Проверена работоспособность всех тестов - Добавлены комментарии для будущих интеграционных тестов --- tests/test_positional_embeddings.py | 24 ++++++++++++++++++++++++ tests/test_token_embeddings.py | 2 +- 2 files changed, 25 insertions(+), 1 deletion(-) create mode 100644 tests/test_positional_embeddings.py diff --git a/tests/test_positional_embeddings.py b/tests/test_positional_embeddings.py new file mode 100644 index 0000000..072e14e --- /dev/null +++ b/tests/test_positional_embeddings.py @@ -0,0 +1,24 @@ +import torch +import pytest +from simple_llm.embedding.positional_embeddings import PositionalEmbeddings + +class TestPositionalEmbeddings: + @pytest.fixture + def pos_encoder(self): + return PositionalEmbeddings(max_seq_len=100, emb_size=64) + + def test_output_shape(self, pos_encoder): + output = pos_encoder(10) + assert output.shape == (10, 64) + + def test_embedding_layer(self, pos_encoder): + assert isinstance(pos_encoder.embedding, torch.nn.Embedding) + assert pos_encoder.embedding.num_embeddings == 100 + assert pos_encoder.embedding.embedding_dim == 64 + + def test_out_of_range(self, pos_encoder): + with pytest.raises(IndexError): + pos_encoder(101) + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/tests/test_token_embeddings.py b/tests/test_token_embeddings.py index e28b747..07c23f9 100644 --- a/tests/test_token_embeddings.py +++ b/tests/test_token_embeddings.py @@ -1,6 +1,6 @@ import torch import pytest -from simple_llm.embedding.token_embedings import TokenEmbeddings +from simple_llm.embedding.token_embeddings import TokenEmbeddings class TestTokenEmbeddings: """Unit tests for TokenEmbeddings class"""