diff --git a/tests/test_positional_embeddings.py b/tests/test_positional_embeddings.py
new file mode 100644
index 0000000..072e14e
--- /dev/null
+++ b/tests/test_positional_embeddings.py
@@ -0,0 +1,24 @@
+import torch
+import pytest
+from simple_llm.embedding.positional_embeddings import PositionalEmbeddings
+
+class TestPositionalEmbeddings:
+    @pytest.fixture
+    def pos_encoder(self):
+        return PositionalEmbeddings(max_seq_len=100, emb_size=64)
+
+    def test_output_shape(self, pos_encoder):
+        output = pos_encoder(10)
+        assert output.shape == (10, 64)
+
+    def test_embedding_layer(self, pos_encoder):
+        assert isinstance(pos_encoder.embedding, torch.nn.Embedding)
+        assert pos_encoder.embedding.num_embeddings == 100
+        assert pos_encoder.embedding.embedding_dim == 64
+
+    def test_out_of_range(self, pos_encoder):
+        with pytest.raises(IndexError):
+            pos_encoder(101)
+
+if __name__ == "__main__":
+    pytest.main(["-v", __file__])
diff --git a/tests/test_token_embeddings.py b/tests/test_token_embeddings.py
index e28b747..07c23f9 100644
--- a/tests/test_token_embeddings.py
+++ b/tests/test_token_embeddings.py
@@ -1,6 +1,6 @@
 import torch
 import pytest
-from simple_llm.embedding.token_embedings import TokenEmbeddings
+from simple_llm.embedding.token_embeddings import TokenEmbeddings
 
 class TestTokenEmbeddings:
     """Unit tests for TokenEmbeddings class"""