Mirror of https://github.com/pese-git/llm-arch-research.git, synced 2026-01-23 21:10:54 +00:00
test(core): fix FeedForward and MultiHeadAttention tests for unified interface and tuple outputs
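This commit only touches the tests, so the production modules themselves do not appear in the diff. As a rough guide, here is a minimal Python sketch of the interface the updated tests appear to assume; the class names, constructor arguments, the `_layer1`/`_activation`/`_layer2`/`_dropout` attributes, and the `(output, attention_weights)` tuple are taken from the test code below, while the internals are assumptions rather than the repository's actual implementation.

# Hypothetical sketch of the interface the tests target (not the repo's real code).
import torch
import torch.nn as nn


class FeedForward(nn.Module):
    """Position-wise feed-forward block; the activation now lives in `_activation`."""

    def __init__(self, embed_dim: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self._layer1 = nn.Linear(embed_dim, hidden_dim)
        self._activation = nn.ReLU()   # previously exposed as `_relu`
        self._layer2 = nn.Linear(hidden_dim, embed_dim)
        self._dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._dropout(self._layer2(self._activation(self._layer1(x))))


class MultiHeadAttention(nn.Module):
    """Self-attention block whose forward() now returns an (output, weights) tuple."""

    def __init__(self, num_heads: int, embed_dim: int, head_size: int, max_seq_len: int = 1024):
        super().__init__()
        # head_size and max_seq_len are accepted for signature compatibility only;
        # this simplified sketch delegates the actual attention to PyTorch.
        self._attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor, mask=None):
        attn_mask = None
        if mask is not None:
            # The tests pass a {0, 1} lower-triangular mask; convert it into the
            # boolean "positions that may NOT attend" mask PyTorch expects.
            attn_mask = mask == 0
        output, attn_weights = self._attn(x, x, x, attn_mask=attn_mask)
        return output, attn_weights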
@@ -19,7 +19,7 @@ class TestFeedForward:
         # Check internal layers
         assert hasattr(ff, '_layer1')
         assert hasattr(ff, '_layer2')
-        assert hasattr(ff, '_relu')
+        assert hasattr(ff, '_activation')
         assert hasattr(ff, '_dropout')
 
         # Check layer dimensions
@@ -78,7 +78,7 @@ class TestFeedForward:
 
         # Manually compute expected output without dropout for deterministic comparison
         hidden = ff._layer1(random_float_inputs)
-        activated = ff._relu(hidden)
+        activated = ff._activation(hidden)
         expected_output = ff._layer2(activated)
 
         # Compare with forward pass in eval mode (no dropout)
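Building on the interface sketch above, a hypothetical completion of this dropout-free comparison could look like the following; the names `ff` and `random_float_inputs` come from the test, while the constructor arguments and the closing assertions are assumptions.

# Hypothetical continuation of the dropout-free comparison (uses the FeedForward
# sketch above; constructor arguments and final assertion are assumed).
import torch

ff = FeedForward(embed_dim=64, hidden_dim=256, dropout=0.1)
random_float_inputs = torch.randn(2, 16, 64)

hidden = ff._layer1(random_float_inputs)
activated = ff._activation(hidden)           # unified attribute, formerly `_relu`
expected_output = ff._layer2(activated)

ff.eval()                                    # eval mode disables dropout
with torch.no_grad():
    actual_output = ff(random_float_inputs)

assert torch.allclose(actual_output, expected_output, rtol=1e-5)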
@@ -27,7 +27,7 @@ class TestMultiHeadAttention:
         attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
 
         # Forward pass
-        output = attention(random_embeddings)
+        output, _ = attention(random_embeddings)
 
         # Check output shape
         assert output.shape == random_embeddings.shape
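With the unified tuple output, tests that only need the hidden states unpack and discard the second element, while the attention weights remain available when needed. A small usage sketch, again building on the classes above; the weight shape asserted here is an assumption based on standard averaged multi-head attention, not something checked in this diff.

# Hypothetical usage of the new return convention (builds on the sketch above).
import torch

num_heads, embed_dim = 4, 64
attention = MultiHeadAttention(num_heads, embed_dim, embed_dim // num_heads, max_seq_len=1024)
random_embeddings = torch.randn(2, 16, embed_dim)

output, attn_weights = attention(random_embeddings)

assert output.shape == random_embeddings.shape
# Assumed shape for head-averaged attention weights: (batch, query_len, key_len).
assert attn_weights.shape == (2, 16, 16)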
@@ -43,7 +43,7 @@ class TestMultiHeadAttention:
         mask = torch.tril(torch.ones(seq_len, seq_len))  # Causal mask
 
         # Forward pass with mask
-        output = attention(random_embeddings, mask=mask)
+        output, _ = attention(random_embeddings, mask=mask)
 
         # Check output shape
         assert output.shape == random_embeddings.shape
@@ -58,7 +58,7 @@ class TestMultiHeadAttention:
         causal_mask = torch.tril(torch.ones(seq_len, seq_len))
 
         # Forward pass with causal mask
-        output = attention(random_embeddings, mask=causal_mask)
+        output, _ = attention(random_embeddings, mask=causal_mask)
 
         # Check output shape
         assert output.shape == random_embeddings.shape
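The two mask tests above construct the causal mask the same way. A short sketch of that pattern follows, with an extra (assumed, not part of this diff) check that future positions receive no attention weight.

# Hypothetical causal-mask check (mask construction mirrors the tests; the
# weight assertion at the end is an assumed extra check, not from the diff).
import torch

seq_len, embed_dim, num_heads = 16, 64, 4
attention = MultiHeadAttention(num_heads, embed_dim, embed_dim // num_heads, max_seq_len=1024)
random_embeddings = torch.randn(2, seq_len, embed_dim)

causal_mask = torch.tril(torch.ones(seq_len, seq_len))  # 1 = may attend, 0 = masked

output, attn_weights = attention(random_embeddings, mask=causal_mask)

assert output.shape == random_embeddings.shape
# Positions above the diagonal are masked out, so their weights should be ~0.
future = attn_weights.masked_select(torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool())
assert torch.all(future.abs() < 1e-6)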
@@ -69,7 +69,7 @@ class TestMultiHeadAttention:
         attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
 
         # Forward pass
-        output = attention(random_embeddings)
+        output, _ = attention(random_embeddings)
 
         # Check output shape
         assert output.shape == random_embeddings.shape
@@ -80,7 +80,7 @@ class TestMultiHeadAttention:
         attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
 
         # Forward pass
-        output = attention(random_embeddings)
+        output, _ = attention(random_embeddings)
 
         # Create a dummy loss and backward pass
         loss = output.sum()
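The remainder of this gradient test lies outside the hunk; a plausible completion, under the same assumptions as the sketches above, is shown below. Only the loss construction appears in the diff; the gradient assertions are assumed.

# Hypothetical gradient-flow check (builds on the sketch above; the gradient
# assertions are assumed, only the loss construction appears in the diff).
import torch

attention = MultiHeadAttention(4, 64, 16, max_seq_len=1024)
random_embeddings = torch.randn(2, 16, 64, requires_grad=True)

output, _ = attention(random_embeddings)

loss = output.sum()       # dummy scalar loss
loss.backward()

assert random_embeddings.grad is not None
assert all(p.grad is not None for p in attention.parameters() if p.requires_grad)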
@@ -98,7 +98,7 @@ class TestMultiHeadAttention:
         inputs = random_embeddings.to(device)
 
         # Forward pass
-        output = attention(inputs)
+        output, _ = attention(inputs)
 
         # Check device consistency
         assert output.device == device
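A self-contained version of this device-consistency check might look as follows; the CUDA fallback and the weight-device assertion are assumptions, since the diff only shows the forward call and the output-device assertion.

# Hypothetical device-consistency check (CUDA fallback and the weight-device
# assertion are assumptions; builds on the sketch above).
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

attention = MultiHeadAttention(4, 64, 16, max_seq_len=1024).to(device)
inputs = torch.randn(2, 16, 64).to(device)

output, attn_weights = attention(inputs)

assert output.device == inputs.device
assert attn_weights.device == inputs.device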
@@ -119,7 +119,7 @@ class TestMultiHeadAttention:
         batch_size, seq_len = 2, 16
         inputs = torch.randn(batch_size, seq_len, embed_dim)
 
-        output = attention(inputs)
+        output, _ = attention(inputs)
 
         assert output.shape == inputs.shape
 
@@ -128,7 +128,7 @@ class TestMultiHeadAttention:
         head_size = embed_dim // num_heads
         attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
 
-        output = attention(random_embeddings)
+        output, _ = attention(random_embeddings)
 
         # Output shouldn't have extreme values
         assert output.abs().max() < 100  # Reasonable upper bound
@@ -140,7 +140,7 @@ class TestMultiHeadAttention:
         attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
 
         inputs = torch.randn(batch_size, seq_len, embed_dim)
-        output = attention(inputs)
+        output, _ = attention(inputs)
 
         assert output.shape == (batch_size, seq_len, embed_dim)
 
@@ -158,8 +158,8 @@ class TestMultiHeadAttention:
         attention.eval()
 
         with torch.no_grad():
-            output1 = attention(base_sequence)
-            output2 = attention(identical_sequence)
+            output1, _ = attention(base_sequence)
+            output2, _ = attention(identical_sequence)
 
         # With identical inputs and same parameters, outputs should be identical
         assert torch.allclose(output1, output2, rtol=1e-5)
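Since forward() now also returns the attention weights, the same determinism check could cover both tuple elements. The following sketch is an assumed extension of the test above, not part of this commit.

# Hypothetical determinism check over both tuple elements (the weight
# comparison is an assumed extension of the test shown above).
import torch

attention = MultiHeadAttention(4, 64, 16, max_seq_len=1024)
attention.eval()

base_sequence = torch.randn(1, 16, 64)
identical_sequence = base_sequence.clone()

with torch.no_grad():
    output1, weights1 = attention(base_sequence)
    output2, weights2 = attention(identical_sequence)

assert torch.allclose(output1, output2, rtol=1e-5)
assert torch.allclose(weights1, weights2, rtol=1e-5)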
@@ -98,7 +98,7 @@ def test_multi_head_attention():
     batch_size, seq_len = 2, 16
     inputs = torch.randn(batch_size, seq_len, emb_size)
 
-    output = attention(inputs)
+    output, _ = attention(inputs)
 
     assert output.shape == inputs.shape
     print("✅ Multi-head attention test passed")
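Assembled into a self-contained smoke test, the fragment above would look roughly like this; the `emb_size`/`num_heads` values and the constructor call are assumptions, and only the lines shown in the hunk come from the file.

# Hypothetical self-contained version of the smoke test above; the constructor
# arguments are assumed, the asserted shapes mirror the diff.
import torch


def test_multi_head_attention():
    emb_size, num_heads = 64, 4
    attention = MultiHeadAttention(num_heads, emb_size, emb_size // num_heads, max_seq_len=1024)

    batch_size, seq_len = 2, 16
    inputs = torch.randn(batch_size, seq_len, emb_size)

    output, _ = attention(inputs)

    assert output.shape == inputs.shape
    print("✅ Multi-head attention test passed")


if __name__ == "__main__":
    test_multi_head_attention()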