test(core): fix FeedForward and MultiHeadAttention tests for unified interface and tuple outputs
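For context, the updated tests assume a unified layer interface: FeedForward exposes a generic `_activation` submodule instead of a hard-coded `_relu`, and MultiHeadAttention.forward returns an (output, attention_weights) tuple instead of a bare tensor. The following is a minimal sketch of what such an interface could look like; the constructor signatures and internals are assumptions inferred from these tests, not the actual modules in this repository.

import torch
import torch.nn as nn

# Illustrative sketch only; the real FeedForward/MultiHeadAttention in
# llm-arch-research may be implemented differently.
class FeedForward(nn.Module):
    def __init__(self, emb_size, hidden_size, dropout=0.1, activation=None):
        super().__init__()
        self._layer1 = nn.Linear(emb_size, hidden_size)
        self._layer2 = nn.Linear(hidden_size, emb_size)
        # Unified interface: any activation module, defaulting to ReLU.
        self._activation = activation if activation is not None else nn.ReLU()
        self._dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self._dropout(self._layer2(self._activation(self._layer1(x))))

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, emb_size, head_size, max_seq_len=1024):
        super().__init__()
        self._num_heads = num_heads
        self._head_size = head_size
        self._max_seq_len = max_seq_len  # accepted for parity with the tests; unused in this sketch
        self._qkv = nn.Linear(emb_size, 3 * num_heads * head_size)
        self._out = nn.Linear(num_heads * head_size, emb_size)

    def forward(self, x, mask=None):
        b, t, _ = x.shape
        q, k, v = self._qkv(x).chunk(3, dim=-1)
        # Reshape each projection to (batch, heads, seq, head_size).
        q, k, v = (z.view(b, t, self._num_heads, self._head_size).transpose(1, 2)
                   for z in (q, k, v))
        scores = q @ k.transpose(-2, -1) / self._head_size ** 0.5
        if mask is not None:
            scores = scores.masked_fill(mask[:t, :t] == 0, float("-inf"))
        weights = torch.softmax(scores, dim=-1)
        out = (weights @ v).transpose(1, 2).reshape(b, t, -1)
        # Tuple output: call sites unpack `output, _ = attention(x)`.
        return self._out(out), weights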
@@ -19,7 +19,7 @@ class TestFeedForward:
         # Check internal layers
         assert hasattr(ff, '_layer1')
         assert hasattr(ff, '_layer2')
-        assert hasattr(ff, '_relu')
+        assert hasattr(ff, '_activation')
         assert hasattr(ff, '_dropout')

         # Check layer dimensions
@@ -78,7 +78,7 @@ class TestFeedForward:

         # Manually compute expected output without dropout for deterministic comparison
         hidden = ff._layer1(random_float_inputs)
-        activated = ff._relu(hidden)
+        activated = ff._activation(hidden)
         expected_output = ff._layer2(activated)

         # Compare with forward pass in eval mode (no dropout)
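Under the same assumptions as the sketch above, the updated deterministic check can be exercised roughly as follows; the test name, dimensions, and tolerance here are hypothetical, not taken from the repository.

def test_feed_forward_matches_manual_computation():
    torch.manual_seed(0)
    ff = FeedForward(emb_size=32, hidden_size=128, dropout=0.1)
    ff.eval()  # dropout is a no-op in eval mode, so the manual path matches forward()

    random_float_inputs = torch.randn(2, 16, 32)

    # Manual path through the sub-layers, mirroring the hunk above.
    hidden = ff._layer1(random_float_inputs)
    activated = ff._activation(hidden)
    expected_output = ff._layer2(activated)

    assert torch.allclose(ff(random_float_inputs), expected_output, rtol=1e-5)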
@@ -27,7 +27,7 @@ class TestMultiHeadAttention:
         attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)

         # Forward pass
-        output = attention(random_embeddings)
+        output, _ = attention(random_embeddings)

         # Check output shape
         assert output.shape == random_embeddings.shape
@@ -43,7 +43,7 @@ class TestMultiHeadAttention:
         mask = torch.tril(torch.ones(seq_len, seq_len))  # Causal mask

         # Forward pass with mask
-        output = attention(random_embeddings, mask=mask)
+        output, _ = attention(random_embeddings, mask=mask)

         # Check output shape
         assert output.shape == random_embeddings.shape
@@ -58,7 +58,7 @@ class TestMultiHeadAttention:
         causal_mask = torch.tril(torch.ones(seq_len, seq_len))

         # Forward pass with causal mask
-        output = attention(random_embeddings, mask=causal_mask)
+        output, _ = attention(random_embeddings, mask=causal_mask)

         # Check output shape
         assert output.shape == random_embeddings.shape
@@ -69,7 +69,7 @@ class TestMultiHeadAttention:
         attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)

         # Forward pass
-        output = attention(random_embeddings)
+        output, _ = attention(random_embeddings)

         # Check output shape
         assert output.shape == random_embeddings.shape
@@ -80,7 +80,7 @@ class TestMultiHeadAttention:
         attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)

         # Forward pass
-        output = attention(random_embeddings)
+        output, _ = attention(random_embeddings)

         # Create a dummy loss and backward pass
         loss = output.sum()
@@ -98,7 +98,7 @@ class TestMultiHeadAttention:
         inputs = random_embeddings.to(device)

         # Forward pass
-        output = attention(inputs)
+        output, _ = attention(inputs)

         # Check device consistency
         assert output.device == device
@@ -119,7 +119,7 @@ class TestMultiHeadAttention:
         batch_size, seq_len = 2, 16
         inputs = torch.randn(batch_size, seq_len, embed_dim)

-        output = attention(inputs)
+        output, _ = attention(inputs)

         assert output.shape == inputs.shape

@@ -128,7 +128,7 @@ class TestMultiHeadAttention:
         head_size = embed_dim // num_heads
         attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)

-        output = attention(random_embeddings)
+        output, _ = attention(random_embeddings)

         # Output shouldn't have extreme values
         assert output.abs().max() < 100  # Reasonable upper bound
@@ -140,7 +140,7 @@ class TestMultiHeadAttention:
         attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)

         inputs = torch.randn(batch_size, seq_len, embed_dim)
-        output = attention(inputs)
+        output, _ = attention(inputs)

         assert output.shape == (batch_size, seq_len, embed_dim)

@@ -158,8 +158,8 @@ class TestMultiHeadAttention:
         attention.eval()

         with torch.no_grad():
-            output1 = attention(base_sequence)
-            output2 = attention(identical_sequence)
+            output1, _ = attention(base_sequence)
+            output2, _ = attention(identical_sequence)

         # With identical inputs and same parameters, outputs should be identical
         assert torch.allclose(output1, output2, rtol=1e-5)
@@ -98,7 +98,7 @@ def test_multi_head_attention():
     batch_size, seq_len = 2, 16
     inputs = torch.randn(batch_size, seq_len, emb_size)

-    output = attention(inputs)
+    output, _ = attention(inputs)

     assert output.shape == inputs.shape
     print("✅ Multi-head attention test passed")
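In short, every call site that previously consumed a bare tensor now unpacks a tuple. A small usage sketch under the same assumed signatures as above (parameter names follow the sketch, not necessarily the repository):

attention = MultiHeadAttention(num_heads=4, emb_size=64, head_size=16, max_seq_len=1024)
x = torch.randn(2, 16, 64)

output, attn_weights = attention(x)          # new tuple interface
assert output.shape == x.shape
assert attn_weights.shape == (2, 4, 16, 16)  # (batch, heads, seq, seq)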