mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
Обновление тестов после переименования модуля эмбеддингов
- Исправлены импорты в test_token_embeddings.py - Проверена работоспособность всех тестов - Добавлены комментарии для будущих интеграционных тестов
This commit is contained in:
24
tests/test_positional_embeddings.py
Normal file
24
tests/test_positional_embeddings.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from simple_llm.embedding.positional_embeddings import PositionalEmbeddings
|
||||||
|
|
||||||
|
class TestPositionalEmbeddings:
|
||||||
|
@pytest.fixture
|
||||||
|
def pos_encoder(self):
|
||||||
|
return PositionalEmbeddings(max_seq_len=100, emb_size=64)
|
||||||
|
|
||||||
|
def test_output_shape(self, pos_encoder):
|
||||||
|
output = pos_encoder(10)
|
||||||
|
assert output.shape == (10, 64)
|
||||||
|
|
||||||
|
def test_embedding_layer(self, pos_encoder):
|
||||||
|
assert isinstance(pos_encoder.embedding, torch.nn.Embedding)
|
||||||
|
assert pos_encoder.embedding.num_embeddings == 100
|
||||||
|
assert pos_encoder.embedding.embedding_dim == 64
|
||||||
|
|
||||||
|
def test_out_of_range(self, pos_encoder):
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
pos_encoder(101)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main(["-v", __file__])
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import pytest
|
import pytest
|
||||||
from simple_llm.embedding.token_embedings import TokenEmbeddings
|
from simple_llm.embedding.token_embeddings import TokenEmbeddings
|
||||||
|
|
||||||
class TestTokenEmbeddings:
|
class TestTokenEmbeddings:
|
||||||
"""Unit tests for TokenEmbeddings class"""
|
"""Unit tests for TokenEmbeddings class"""
|
||||||
|
|||||||
Reference in New Issue
Block a user