mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
Обновление тестов после переименования модуля эмбеддингов
- Исправлены импорты в test_token_embeddings.py - Проверена работоспособность всех тестов - Добавлены комментарии для будущих интеграционных тестов
This commit is contained in:
24
tests/test_positional_embeddings.py
Normal file
24
tests/test_positional_embeddings.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
import pytest
|
||||
from simple_llm.embedding.positional_embeddings import PositionalEmbeddings
|
||||
|
||||
class TestPositionalEmbeddings:
|
||||
@pytest.fixture
|
||||
def pos_encoder(self):
|
||||
return PositionalEmbeddings(max_seq_len=100, emb_size=64)
|
||||
|
||||
def test_output_shape(self, pos_encoder):
|
||||
output = pos_encoder(10)
|
||||
assert output.shape == (10, 64)
|
||||
|
||||
def test_embedding_layer(self, pos_encoder):
|
||||
assert isinstance(pos_encoder.embedding, torch.nn.Embedding)
|
||||
assert pos_encoder.embedding.num_embeddings == 100
|
||||
assert pos_encoder.embedding.embedding_dim == 64
|
||||
|
||||
def test_out_of_range(self, pos_encoder):
|
||||
with pytest.raises(IndexError):
|
||||
pos_encoder(101)
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main(["-v", __file__])
|
||||
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
import pytest
|
||||
from simple_llm.embedding.token_embedings import TokenEmbeddings
|
||||
from simple_llm.embedding.token_embeddings import TokenEmbeddings
|
||||
|
||||
class TestTokenEmbeddings:
|
||||
"""Unit tests for TokenEmbeddings class"""
|
||||
|
||||
Reference in New Issue
Block a user