test(core): fix FeedForward and MultiHeadAttention tests for unified interface and tuple outputs

Author: Sergey Penkovsky
Date: 2025-10-05 19:26:18 +03:00
parent c39e68d71a
commit 3843e64098
3 changed files with 14 additions and 14 deletions

@@ -19,7 +19,7 @@ class TestFeedForward:
         # Check internal layers
         assert hasattr(ff, '_layer1')
         assert hasattr(ff, '_layer2')
-        assert hasattr(ff, '_relu')
+        assert hasattr(ff, '_activation')
         assert hasattr(ff, '_dropout')
         # Check layer dimensions
@@ -78,7 +78,7 @@ class TestFeedForward:
         # Manually compute expected output without dropout for deterministic comparison
         hidden = ff._layer1(random_float_inputs)
-        activated = ff._relu(hidden)
+        activated = ff._activation(hidden)
         expected_output = ff._layer2(activated)
         # Compare with forward pass in eval mode (no dropout)

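The renamed attribute implies the feed-forward block now takes a pluggable activation instead of a hard-coded ReLU. For orientation, below is a minimal sketch of a FeedForward module consistent with the assertions above; only the attribute names (`_layer1`, `_layer2`, `_activation`, `_dropout`) come from the tests, while the constructor arguments and layer wiring are assumptions, not the repository's actual code.

import torch.nn as nn

class FeedForward(nn.Module):
    """Position-wise feed-forward block; a sketch consistent with the test assertions."""

    def __init__(self, embed_dim, hidden_dim, activation=None, dropout=0.1):
        super().__init__()
        self._layer1 = nn.Linear(embed_dim, hidden_dim)
        self._layer2 = nn.Linear(hidden_dim, embed_dim)
        # Unified interface: any activation module can be injected; ReLU is only a default.
        self._activation = activation if activation is not None else nn.ReLU()
        self._dropout = nn.Dropout(dropout)

    def forward(self, x):
        # In eval mode dropout is a no-op, so this matches the manual
        # layer1 -> activation -> layer2 computation in the test above.
        return self._layer2(self._dropout(self._activation(self._layer1(x))))

Under this shape, `ff._activation(ff._layer1(x))` in the test exercises the same path the forward pass takes, which is why the deterministic comparison holds in eval mode.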
@@ -27,7 +27,7 @@ class TestMultiHeadAttention:
         attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
         # Forward pass
-        output = attention(random_embeddings)
+        output, _ = attention(random_embeddings)
         # Check output shape
         assert output.shape == random_embeddings.shape
@@ -43,7 +43,7 @@ class TestMultiHeadAttention:
         mask = torch.tril(torch.ones(seq_len, seq_len)) # Causal mask
         # Forward pass with mask
-        output = attention(random_embeddings, mask=mask)
+        output, _ = attention(random_embeddings, mask=mask)
         # Check output shape
         assert output.shape == random_embeddings.shape
@@ -58,7 +58,7 @@ class TestMultiHeadAttention:
         causal_mask = torch.tril(torch.ones(seq_len, seq_len))
         # Forward pass with causal mask
-        output = attention(random_embeddings, mask=causal_mask)
+        output, _ = attention(random_embeddings, mask=causal_mask)
         # Check output shape
         assert output.shape == random_embeddings.shape
@@ -69,7 +69,7 @@ class TestMultiHeadAttention:
         attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
         # Forward pass
-        output = attention(random_embeddings)
+        output, _ = attention(random_embeddings)
         # Check output shape
         assert output.shape == random_embeddings.shape
@@ -80,7 +80,7 @@ class TestMultiHeadAttention:
         attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
         # Forward pass
-        output = attention(random_embeddings)
+        output, _ = attention(random_embeddings)
         # Create a dummy loss and backward pass
         loss = output.sum()
@@ -98,7 +98,7 @@ class TestMultiHeadAttention:
         inputs = random_embeddings.to(device)
         # Forward pass
-        output = attention(inputs)
+        output, _ = attention(inputs)
         # Check device consistency
         assert output.device == device
@@ -119,7 +119,7 @@ class TestMultiHeadAttention:
         batch_size, seq_len = 2, 16
         inputs = torch.randn(batch_size, seq_len, embed_dim)
-        output = attention(inputs)
+        output, _ = attention(inputs)
         assert output.shape == inputs.shape
@@ -128,7 +128,7 @@ class TestMultiHeadAttention:
         head_size = embed_dim // num_heads
         attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
-        output = attention(random_embeddings)
+        output, _ = attention(random_embeddings)
         # Output shouldn't have extreme values
         assert output.abs().max() < 100 # Reasonable upper bound
@@ -140,7 +140,7 @@ class TestMultiHeadAttention:
         attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
         inputs = torch.randn(batch_size, seq_len, embed_dim)
-        output = attention(inputs)
+        output, _ = attention(inputs)
         assert output.shape == (batch_size, seq_len, embed_dim)
@@ -158,8 +158,8 @@ class TestMultiHeadAttention:
         attention.eval()
         with torch.no_grad():
-            output1 = attention(base_sequence)
-            output2 = attention(identical_sequence)
+            output1, _ = attention(base_sequence)
+            output2, _ = attention(identical_sequence)
         # With identical inputs and same parameters, outputs should be identical
         assert torch.allclose(output1, output2, rtol=1e-5)

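Every call site above now unpacks a pair, so forward is expected to return the attention weights alongside the projected output. Below is a minimal sketch of a MultiHeadAttention matching that contract; the constructor signature follows the tests (num_heads, embed_dim, head_size, max_seq_len), while the internals are illustrative assumptions rather than the repository's actual implementation.

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention; a sketch of the assumed (output, weights) contract."""

    def __init__(self, num_heads, embed_dim, head_size, max_seq_len=1024):
        super().__init__()
        self._num_heads = num_heads
        self._head_size = head_size
        self._qkv = nn.Linear(embed_dim, 3 * num_heads * head_size)
        self._proj = nn.Linear(num_heads * head_size, embed_dim)
        self._max_seq_len = max_seq_len  # the real module may use this for a cached causal mask

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        q, k, v = self._qkv(x).chunk(3, dim=-1)
        # (batch, seq, heads * head_size) -> (batch, heads, seq, head_size)
        q, k, v = (
            t.view(batch_size, seq_len, self._num_heads, self._head_size).transpose(1, 2)
            for t in (q, k, v)
        )
        scores = q @ k.transpose(-2, -1) / self._head_size ** 0.5
        if mask is not None:
            # mask is (seq_len, seq_len) with zeros where attention is disallowed
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn_weights = torch.softmax(scores, dim=-1)
        out = (attn_weights @ v).transpose(1, 2).reshape(batch_size, seq_len, -1)
        # Unified interface: return the projected output together with the attention weights.
        return self._proj(out), attn_weights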
@@ -98,7 +98,7 @@ def test_multi_head_attention():
     batch_size, seq_len = 2, 16
     inputs = torch.randn(batch_size, seq_len, emb_size)
-    output = attention(inputs)
+    output, _ = attention(inputs)
     assert output.shape == inputs.shape
     print("✅ Multi-head attention test passed")