Commit 204fb3c
Parent(s): ba4cb76
Implemented ScaledDotProduct Attention and Multi-Head Attention

Files changed:
- src/models/attention.py +113 -3
- src/models/decoder.py +0 -0
- src/models/encoder.py +0 -0
- src/models/heads.py +0 -0
- src/models/multitask.py +0 -0
- src/models/positional_encoding.py +0 -0
- tests/test_models/test_attention.py +89 -2
- tests/test_models/test_attention_visual.py +53 -0
- tests/test_models/test_multihead_visual.py +162 -0
src/models/attention.py
CHANGED

@@ -5,6 +5,8 @@ This module implements the core attention mechanisms used in the Transformer model.
 - ScaledDotProductAttention: Fundamental attention operation
 - MultiHeadAttention: Parallel attention with learned projections
 
+Doing this first for a bottom-up implementation of the Transformer.
+
 Author: Oliver Perrin
 Date: 2025-10-23
 """

@@ -48,7 +50,7 @@ class ScaledDotProductAttention(nn.Module):
 
     def __init__(self):
         super().__init__()
-        #
+        # No parameters needed here.
        pass
 
     def forward(

@@ -69,7 +71,115 @@ class ScaledDotProductAttention(nn.Module):
         5. Compute output: output = attention_weights @ value
         6. Return both output and attention_weights
         """
+        # Get the key dimension for scaling
+        d_k = query.size(-1)
+
+        # Compute attention scores
+        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
+
+        # Apply the mask if provided
+        if mask is not None:
+            scores = scores.masked_fill(mask == 0, float('-inf'))
+        # Apply softmax to get the attention weights
+        attention_weights = F.softmax(scores, dim=-1)
+
+        return torch.matmul(attention_weights, value), attention_weights
+
+
+# --------------- Multi-Head Attention ---------------
+
+
+class MultiHeadAttention(nn.Module):
+    """
+    Multi-Head Attention mechanism.
+
+    Allows the model to jointly attend to information from different
+    representation subspaces at different positions.
+
+    Transforms the input into query, key, and value representations.
+
+    Args:
+        d_model: Dimension of the model (default: 512)
+        num_heads: Number of attention heads (default: 8)
+        dropout: Dropout probability (default: 0.1)
+    """
+
+    def __init__(self, d_model: int = 512, num_heads: int = 8, dropout: float = 0.1):
+        super().__init__()
+
+        # d_model must be divisible by num_heads,
+        # because d_k = d_model // num_heads has to be an integer.
+        assert d_model % num_heads == 0
+
+        # Assume d_v always equals d_k
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.d_k = d_model // num_heads
+
+        # Four linear layers (W_Q, W_K, W_V, W_O),
+        # all nn.Linear(d_model, d_model)
+        self.W_Q = nn.Linear(d_model, d_model)
+        self.W_K = nn.Linear(d_model, d_model)
+        self.W_V = nn.Linear(d_model, d_model)
+        self.W_O = nn.Linear(d_model, d_model)
+        # ScaledDotProductAttention instance
+        self.attention = ScaledDotProductAttention()
+        # Dropout layer
+        self.dropout = nn.Dropout(p=dropout)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        mask: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            query: (batch, seq_len, d_model)
+            key: (batch, seq_len, d_model)
+            value: (batch, seq_len, d_model)
+            mask: Optional (batch, seq_len, seq_len) or (batch, 1, seq_len, seq_len)
+
+        Returns:
+            output: (batch, seq_len, d_model)
+            attention_weights: (batch, num_heads, seq_len, seq_len)
+        """
+        batch_size = query.size(0)
+
+        # Linear projections
+        Q = self.W_Q(query)  # (batch, seq_len, d_model)
+        K = self.W_K(key)
+        V = self.W_V(value)
+
+        # Split into heads:
+        # reshape (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k) for Q, K, V
+        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
+        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
+        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
+        # Now all are (batch, num_heads, seq_len, d_k), e.g. (2, 8, 10, 64)
+
+        # Handle mask broadcasting for multi-head attention
+        if mask is not None:
+            # If the mask is 3D (batch, seq, seq), add a head dimension
+            if mask.dim() == 3:
+                mask = mask.unsqueeze(1)  # (batch, 1, seq, seq)
+            # The mask now broadcasts across all heads: (batch, 1, seq, seq) -> (batch, num_heads, seq, seq)
+
+        # Apply attention
+        output, attn_weights = self.attention(Q, K, V, mask)
+        # output: (batch, num_heads, seq_len, d_k)
+        # attn_weights: (batch, num_heads, seq_len, seq_len)
+
+        # Concatenate heads:
+        # (batch, num_heads, seq_len, d_k) -> (batch, seq_len, num_heads, d_k) -> (batch, seq_len, d_model)
+        # After the transpose the memory layout is scattered; contiguous() reorganizes it
+        # so that view() can flatten the head dimension.
+        output = output.transpose(1, 2).contiguous()
+        output = output.view(batch_size, -1, self.d_model)  # -1 in view means "infer this dimension"
+
+        # Final linear projection
+        output = self.W_O(output)
+        # Apply dropout
+        output = self.dropout(output)
+
+        return output, attn_weights
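The forward pass above implements attention(Q, K, V) = softmax(Q·K^T / sqrt(d_k))·V, split across num_heads subspaces. A minimal usage sketch of the new module (illustrative only, not part of the commit; it assumes the repository root is on PYTHONPATH so that src.models.attention imports, and the causal-mask construction is my own example):

import torch
from src.models.attention import MultiHeadAttention

batch_size, seq_len, d_model, num_heads = 2, 10, 512, 8
mha = MultiHeadAttention(d_model, num_heads, dropout=0.0)
mha.eval()  # deterministic: dropout disabled

x = torch.randn(batch_size, seq_len, d_model)

# Boolean causal mask: position i may only attend to positions <= i.
# True = keep, False = masked (the module fills positions where mask == 0 with -inf).
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
causal_mask = causal_mask.unsqueeze(0).expand(batch_size, -1, -1)  # (batch, seq, seq)

output, attn = mha(x, x, x, causal_mask)
print(output.shape)  # torch.Size([2, 10, 512])
print(attn.shape)    # torch.Size([2, 8, 10, 10])

# Each row of attention weights sums to 1, and masked (future) positions get zero weight.
assert torch.allclose(attn.sum(dim=-1), torch.ones(batch_size, num_heads, seq_len))
assert torch.all(attn[..., 0, 1:] == 0)  # the first query can only attend to itself

Note that dropout is applied to the final output projection, so the attention weights themselves are unaffected by it; eval() just makes the projected output deterministic from run to run.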
src/models/decoder.py
ADDED (empty file)

src/models/encoder.py
ADDED (empty file)

src/models/heads.py
ADDED (empty file)

src/models/multitask.py
ADDED (empty file)

src/models/positional_encoding.py
ADDED (empty file)
tests/test_models/test_attention.py
CHANGED

@@ -6,8 +6,7 @@ Run with: pytest tests/test_models/test_attention.py -v
 
 import pytest
 import torch
-from src.models.attention import ScaledDotProductAttention
-
+from src.models.attention import ScaledDotProductAttention, MultiHeadAttention
 
 class TestScaledDotProductAttention:
     """Test suite for ScaledDotProductAttention."""

@@ -55,6 +54,94 @@ class TestScaledDotProductAttention:
         assert torch.allclose(weights[:, :, 3:], torch.zeros(batch_size, seq_len, 2), atol=1e-6)
 
     # TODO: Add more tests as you understand the mechanism better
+class TestMultiHeadAttention:
+    """Test suite for MultiHeadAttention."""
+
+    def test_output_shape(self):
+        """Test that output shapes are correct."""
+        d_model, num_heads = 512, 8
+        batch_size, seq_len = 2, 10
+
+        mha = MultiHeadAttention(d_model, num_heads)
+
+        Q = K = V = torch.randn(batch_size, seq_len, d_model)
+        output, attn_weights = mha(Q, K, V)
+
+        assert output.shape == (batch_size, seq_len, d_model)
+        assert attn_weights.shape == (batch_size, num_heads, seq_len, seq_len)
+
+    def test_different_qkv(self):
+        """Test with different Q, K, V (cross-attention scenario)."""
+        d_model, num_heads = 512, 8
+        batch_size = 2
+        seq_len_q, seq_len_kv = 10, 20
+
+        mha = MultiHeadAttention(d_model, num_heads)
+
+        Q = torch.randn(batch_size, seq_len_q, d_model)
+        K = torch.randn(batch_size, seq_len_kv, d_model)
+        V = torch.randn(batch_size, seq_len_kv, d_model)
+
+        output, attn_weights = mha(Q, K, V)
+
+        # Output has the same length as the query
+        assert output.shape == (batch_size, seq_len_q, d_model)
+        # Attention is query_len x key_len
+        assert attn_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_kv)
+
+    def test_masking(self):
+        """Test that masking works correctly."""
+        d_model, num_heads = 512, 8
+        batch_size, seq_len = 2, 5
+
+        mha = MultiHeadAttention(d_model, num_heads)
+        Q = K = V = torch.randn(batch_size, seq_len, d_model)
+
+        # Mask out the last 2 positions
+        mask = torch.ones(batch_size, seq_len, seq_len, dtype=torch.bool)
+        mask[:, :, -2:] = False
+
+        _, attn_weights = mha(Q, K, V, mask)
+
+        # The last 2 positions should get near-zero attention
+        assert torch.allclose(
+            attn_weights[:, :, :, -2:],
+            torch.zeros(batch_size, num_heads, seq_len, 2),
+            atol=1e-6
+        )
+
+    def test_parameters_exist(self):
+        """Test that learnable parameters are created."""
+        mha = MultiHeadAttention(512, 8)
+
+        # Should have 4 linear layers' worth of parameters
+        param_names = [name for name, _ in mha.named_parameters()]
+
+        assert any('W_Q' in name or 'q_linear' in name.lower() for name in param_names)
+        assert any('W_K' in name or 'k_linear' in name.lower() for name in param_names)
+        assert any('W_V' in name or 'v_linear' in name.lower() for name in param_names)
+        assert any('W_O' in name or 'out' in name.lower() for name in param_names)
+
+    def test_dropout_changes_output(self):
+        """Test that dropout is actually applied during training."""
+        torch.manual_seed(42)
+        mha = MultiHeadAttention(512, 8, dropout=0.5)
+        mha.train()  # Enable training mode
+
+        Q = K = V = torch.randn(2, 10, 512)
+
+        # Run twice with the same input - should get different outputs due to dropout
+        output1, _ = mha(Q, K, V)
+        output2, _ = mha(Q, K, V)
+
+        assert not torch.allclose(output1, output2)
+
+        # In eval mode, should be deterministic
+        mha.eval()
+        output3, _ = mha(Q, K, V)
+        output4, _ = mha(Q, K, V)
+
+        assert torch.allclose(output3, output4)
 
 
 if __name__ == "__main__":
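One candidate for the TODO above (a sketch, not part of the commit; the test name is my own): softmax guarantees that, for every query position and every head, the attention weights over the keys sum to 1.

import torch
from src.models.attention import MultiHeadAttention

def test_attention_weights_sum_to_one():
    """Attention weights along the key dimension should sum to 1."""
    torch.manual_seed(0)
    mha = MultiHeadAttention(d_model=512, num_heads=8, dropout=0.0)
    mha.eval()

    x = torch.randn(2, 7, 512)
    _, attn_weights = mha(x, x, x)

    # Row sums over the key dimension: shape (batch, num_heads, seq_len), each entry ~1.0
    row_sums = attn_weights.sum(dim=-1)
    assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-5)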
tests/test_models/test_attention_visual.py
ADDED

@@ -0,0 +1,53 @@
# Create a file: tests/test_models/test_attention_visual.py

import torch
import matplotlib.pyplot as plt
import seaborn as sns
from src.models.attention import ScaledDotProductAttention

def test_attention_visualization():
    """Visual test to understand attention patterns."""
    attention = ScaledDotProductAttention()

    # Create a simple case: 5 tokens, each token attends most to itself
    batch_size = 1
    seq_len = 5
    d_k = 64

    # Create Q, K, V
    torch.manual_seed(42)
    Q = torch.randn(batch_size, seq_len, d_k)
    K = torch.randn(batch_size, seq_len, d_k)
    V = torch.eye(seq_len, d_k).unsqueeze(0)  # Identity-like

    # Compute attention
    output, weights = attention(Q, K, V)

    # Plot attention weights
    plt.figure(figsize=(8, 6))
    sns.heatmap(
        weights[0].detach().numpy(),
        annot=True,
        fmt='.2f',
        cmap='viridis',
        xticklabels=[f'Key {i}' for i in range(seq_len)],
        yticklabels=[f'Query {i}' for i in range(seq_len)]
    )
    plt.title('Attention Weights Heatmap')
    plt.xlabel('Keys (What we attend TO)')
    plt.ylabel('Queries (What is attending)')
    plt.tight_layout()
    plt.savefig('outputs/attention_visualization.png')
    print("✅ Saved visualization to outputs/attention_visualization.png")

    # Print some analysis
    print("\n" + "="*50)
    print("Attention Analysis")
    print("="*50)
    for i in range(seq_len):
        max_attn_idx = weights[0, i].argmax().item()
        max_attn_val = weights[0, i, max_attn_idx].item()
        print(f"Query {i} attends most to Key {max_attn_idx} (weight: {max_attn_val:.3f})")

if __name__ == "__main__":
    test_attention_visualization()
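A possible companion to the visual test above (a sketch, not part of the commit; the function and output filename are my own): pass a causal mask to ScaledDotProductAttention so the zeroed-out upper triangle shows up directly in the heatmap.

import os
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from src.models.attention import ScaledDotProductAttention

def visualize_causal_attention():
    """Heatmap of attention weights under a causal (lower-triangular) mask."""
    attention = ScaledDotProductAttention()
    torch.manual_seed(42)
    seq_len, d_k = 5, 64

    Q = torch.randn(1, seq_len, d_k)
    K = torch.randn(1, seq_len, d_k)
    V = torch.randn(1, seq_len, d_k)

    # Ones on and below the diagonal; masked entries receive zero attention weight.
    mask = torch.tril(torch.ones(1, seq_len, seq_len))
    _, weights = attention(Q, K, V, mask)

    os.makedirs('outputs', exist_ok=True)
    plt.figure(figsize=(6, 5))
    sns.heatmap(weights[0].detach().numpy(), annot=True, fmt='.2f', cmap='viridis')
    plt.title('Causal Attention Weights (upper triangle masked)')
    plt.tight_layout()
    plt.savefig('outputs/causal_attention_visualization.png')

if __name__ == "__main__":
    visualize_causal_attention()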
tests/test_models/test_multihead_visual.py
ADDED

@@ -0,0 +1,162 @@
# tests/test_models/test_multihead_visual.py

import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from src.models.attention import MultiHeadAttention

def visualize_multihead_attention():
    """
    Visual test to see what different attention heads learn.
    Creates a heatmap showing attention patterns for each head.
    """
    # Setup
    torch.manual_seed(42)
    d_model, num_heads = 512, 8
    batch_size, seq_len = 1, 10

    mha = MultiHeadAttention(d_model, num_heads, dropout=0.0)
    mha.eval()  # No dropout for visualization

    # Create input with some structure
    # Let's make tokens attend to nearby tokens
    X = torch.randn(batch_size, seq_len, d_model)

    # Add positional bias (tokens are more similar to nearby tokens)
    for i in range(seq_len):
        for j in range(seq_len):
            distance = abs(i - j)
            X[0, i] += 0.5 * X[0, j] / (distance + 1)

    # Forward pass
    output, attn_weights = mha(X, X, X)

    # attn_weights shape: (1, 8, 10, 10) = batch, heads, query_pos, key_pos
    attn_weights = attn_weights[0].detach().numpy()  # Remove batch dim: (8, 10, 10)

    # Create visualization
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    fig.suptitle('Multi-Head Attention: What Each Head Learns', fontsize=16, y=1.02)

    for head_idx in range(num_heads):
        row = head_idx // 4
        col = head_idx % 4
        ax = axes[row, col]

        # Plot attention heatmap for this head
        sns.heatmap(
            attn_weights[head_idx],
            annot=True,
            fmt='.2f',
            cmap='viridis',
            cbar=True,
            square=True,
            ax=ax,
            vmin=0,
            vmax=attn_weights[head_idx].max(),
            xticklabels=[f'K{i}' for i in range(seq_len)],
            yticklabels=[f'Q{i}' for i in range(seq_len)]
        )
        ax.set_title(f'Head {head_idx}', fontweight='bold')
        ax.set_xlabel('Keys (attend TO)')
        ax.set_ylabel('Queries (attending FROM)')

    plt.tight_layout()
    plt.savefig('outputs/multihead_attention_visualization.png', dpi=150, bbox_inches='tight')
    print("✅ Saved visualization to outputs/multihead_attention_visualization.png")

    # Print statistics
    print("\n" + "="*60)
    print("Multi-Head Attention Analysis")
    print("="*60)

    for head_idx in range(num_heads):
        head_attn = attn_weights[head_idx]

        # Find dominant pattern
        diagonal_strength = np.trace(head_attn) / seq_len
        off_diagonal = (head_attn.sum() - np.trace(head_attn)) / (seq_len * (seq_len - 1))

        print(f"\nHead {head_idx}:")
        print(f"  Self-attention strength: {diagonal_strength:.3f}")
        print(f"  Cross-attention strength: {off_diagonal:.3f}")

        # Find which position each query attends to most
        max_attentions = head_attn.argmax(axis=1)
        print(f"  Attention pattern: {max_attentions.tolist()}")


def compare_single_vs_multihead():
    """
    Compare single-head vs multi-head attention capacity.
    """
    torch.manual_seed(42)
    seq_len, d_model = 8, 512

    # Create data with two different patterns
    # Pattern 1: Sequential (token i attends to i+1)
    # Pattern 2: Pairwise (tokens 0-1, 2-3, 4-5, 6-7 attend to each other)

    X = torch.randn(1, seq_len, d_model)

    # Test with 1 head vs 8 heads
    mha_1head = MultiHeadAttention(d_model, num_heads=1, dropout=0.0)
    mha_8heads = MultiHeadAttention(d_model, num_heads=8, dropout=0.0)

    mha_1head.eval()
    mha_8heads.eval()

    _, attn_1head = mha_1head(X, X, X)
    _, attn_8heads = mha_8heads(X, X, X)

    # Plot comparison
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Single head
    sns.heatmap(
        attn_1head[0, 0].detach().numpy(),
        annot=True,
        fmt='.2f',
        cmap='viridis',
        cbar=True,
        square=True,
        ax=axes[0]
    )
    axes[0].set_title('Single-Head Attention\n(Limited expressiveness)', fontweight='bold')
    axes[0].set_xlabel('Keys')
    axes[0].set_ylabel('Queries')

    # Multi-head average
    avg_attn = attn_8heads[0].mean(dim=0).detach().numpy()
    sns.heatmap(
        avg_attn,
        annot=True,
        fmt='.2f',
        cmap='viridis',
        cbar=True,
        square=True,
        ax=axes[1]
    )
    axes[1].set_title('8-Head Attention (Average)\n(Richer patterns)', fontweight='bold')
    axes[1].set_xlabel('Keys')
    axes[1].set_ylabel('Queries')

    plt.tight_layout()
    plt.savefig('outputs/single_vs_multihead.png', dpi=150, bbox_inches='tight')
    print("✅ Saved comparison to outputs/single_vs_multihead.png")


if __name__ == "__main__":
    import os
    os.makedirs('outputs', exist_ok=True)

    print("Visualizing multi-head attention patterns...")
    visualize_multihead_attention()

    print("\nComparing single-head vs multi-head...")
    compare_single_vs_multihead()

    print("\n" + "="*60)
    print("✅ All visualizations complete!")
    print("="*60)