mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
Обновление тестов FeedForward: упрощение проверок инициализации и dropout (Update FeedForward tests: simplify the initialization and dropout checks)
This commit is contained in:
@@ -8,17 +8,15 @@ class TestFeedForward:
|
|||||||
return FeedForward(emb_size=512)
|
return FeedForward(emb_size=512)
|
||||||
|
|
||||||
def test_initialization(self, ff_layer):
|
def test_initialization(self, ff_layer):
|
||||||
assert isinstance(ff_layer.net, torch.nn.Sequential)
|
assert isinstance(ff_layer._layer1, torch.nn.Linear)
|
||||||
assert len(ff_layer.net) == 4
|
assert isinstance(ff_layer._layer2, torch.nn.Linear)
|
||||||
assert isinstance(ff_layer.net[0], torch.nn.Linear)
|
assert isinstance(ff_layer._relu, torch.nn.ReLU)
|
||||||
assert isinstance(ff_layer.net[1], torch.nn.ReLU)
|
assert isinstance(ff_layer._dropout, torch.nn.Dropout)
|
||||||
assert isinstance(ff_layer.net[2], torch.nn.Linear)
|
|
||||||
assert isinstance(ff_layer.net[3], torch.nn.Dropout)
|
|
||||||
|
|
||||||
assert ff_layer.net[0].in_features == 512
|
assert ff_layer._layer1.in_features == 512
|
||||||
assert ff_layer.net[0].out_features == 2048
|
assert ff_layer._layer1.out_features == 2048
|
||||||
assert ff_layer.net[2].in_features == 2048
|
assert ff_layer._layer2.in_features == 2048
|
||||||
assert ff_layer.net[2].out_features == 512
|
assert ff_layer._layer2.out_features == 512
|
||||||
|
|
||||||
def test_forward_pass_shape(self, ff_layer):
|
def test_forward_pass_shape(self, ff_layer):
|
||||||
batch_size = 4
|
batch_size = 4
|
||||||
@@ -35,9 +33,7 @@ class TestFeedForward:
|
|||||||
output = ff_layer(x)
|
output = ff_layer(x)
|
||||||
|
|
||||||
# Проверяем, что dropout действительно работает в режиме обучения
|
# Проверяем, что dropout действительно работает в режиме обучения
|
||||||
layers = ff_layer.net
|
assert not torch.allclose(output, ff_layer._layer2(ff_layer._relu(ff_layer._layer1(x))))
|
||||||
no_dropout = layers[2](layers[1](layers[0](x)))
|
|
||||||
assert not torch.allclose(output, no_dropout)
|
|
||||||
|
|
||||||
def test_dropout_eval(self):
|
def test_dropout_eval(self):
|
||||||
ff_layer = FeedForward(512, dropout=0.5)
|
ff_layer = FeedForward(512, dropout=0.5)
|
||||||
@@ -46,8 +42,7 @@ class TestFeedForward:
|
|||||||
output = ff_layer(x)
|
output = ff_layer(x)
|
||||||
|
|
||||||
# В eval режиме dropout не должен работать
|
# В eval режиме dropout не должен работать
|
||||||
layers = ff_layer.net
|
expected = ff_layer._layer2(ff_layer._relu(ff_layer._layer1(x)))
|
||||||
expected = layers[2](layers[1](layers[0](x)))
|
|
||||||
assert torch.allclose(output, expected)
|
assert torch.allclose(output, expected)
|
||||||
|
|
||||||
def test_dtype_preservation(self, ff_layer):
|
def test_dtype_preservation(self, ff_layer):
|
||||||
|
|||||||
Reference in New Issue
Block a user