diff --git a/example/example_token_embeddings.py b/example/example_token_embeddings.py
index 0b80448..4539a27 100644
--- a/example/example_token_embeddings.py
+++ b/example/example_token_embeddings.py
@@ -10,7 +10,7 @@
 import torch
 import matplotlib.pyplot as plt
 from sklearn.decomposition import PCA
-from simple_llm.embedding.token_embedings import TokenEmbeddings
+from simple_llm.embedding.token_embeddings import TokenEmbeddings
 
 def basic_example():
     """Basic usage example for TokenEmbeddings"""
diff --git a/tests/test_token_embeddings.py b/tests/test_token_embeddings.py
new file mode 100644
index 0000000..e28b747
--- /dev/null
+++ b/tests/test_token_embeddings.py
@@ -0,0 +1,47 @@
+import torch
+import pytest
+from simple_llm.embedding.token_embeddings import TokenEmbeddings
+
+class TestTokenEmbeddings:
+    """Unit tests for TokenEmbeddings class"""
+
+    @pytest.fixture
+    def embedding_layer(self):
+        return TokenEmbeddings(vocab_size=100, emb_size=32)
+
+    def test_initialization(self, embedding_layer):
+        """Test layer initialization"""
+        assert isinstance(embedding_layer, torch.nn.Module)
+        assert embedding_layer._embedding.num_embeddings == 100
+        assert embedding_layer._embedding.embedding_dim == 32
+
+    def test_forward_shape(self, embedding_layer):
+        """Test output shape of forward pass"""
+        test_input = torch.tensor([
+            [1, 2, 3],
+            [4, 5, 6]
+        ])
+        output = embedding_layer(test_input)
+        assert output.shape == (2, 3, 32)  # batch_size=2, seq_len=3, emb_size=32
+
+    def test_embedding_values(self, embedding_layer):
+        """Test that embeddings are trainable"""
+        input_tensor = torch.tensor([[1]])
+        before = embedding_layer(input_tensor).detach().clone()
+
+        # Simulate a single training step
+        optimizer = torch.optim.SGD(embedding_layer.parameters(), lr=0.1)
+        loss = embedding_layer(input_tensor).sum()
+        loss.backward()
+        optimizer.step()
+
+        after = embedding_layer(input_tensor)
+        assert not torch.allclose(before, after), "Embeddings should change after training"
+
+    def test_out_of_vocab(self, embedding_layer):
+        """Test handling of out-of-vocabulary indices"""
+        with pytest.raises(IndexError):
+            embedding_layer(torch.tensor([[100]]))  # index 100 is out of range for vocab_size=100
+
+if __name__ == "__main__":
+    pytest.main(["-v", __file__])
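
Note: the tests above assume that TokenEmbeddings exposes its lookup table through a private `_embedding` attribute (see `test_initialization`) and raises IndexError for out-of-range token ids. A minimal sketch of a class satisfying that interface, assuming it simply wraps `torch.nn.Embedding` (the actual simple_llm implementation may differ):

import torch

class TokenEmbeddings(torch.nn.Module):
    """Hypothetical sketch: maps integer token ids to dense embedding vectors."""

    def __init__(self, vocab_size: int, emb_size: int):
        super().__init__()
        # Trainable lookup table of shape (vocab_size, emb_size); indices
        # >= vocab_size raise IndexError, which test_out_of_vocab relies on.
        self._embedding = torch.nn.Embedding(vocab_size, emb_size)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        # (batch_size, seq_len) -> (batch_size, seq_len, emb_size)
        return self._embedding(token_ids)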