OliverPerrin committed
Commit 204fb3c · 1 Parent(s): ba4cb76

Implemented ScaledDotProductAttention and MultiHeadAttention

src/models/attention.py CHANGED
@@ -5,6 +5,8 @@ This module implements the core attention mechanisms used in the Transformer mod
 - ScaledDotProductAttention: Fundamental attention operation
 - MultiHeadAttention: Parallel attention with learned projections
 
+Doing this first for a bottom-up implementation of the Transformer.
+
 Author: Oliver Perrin
 Date: 2025-10-23
 """
@@ -48,7 +50,7 @@ class ScaledDotProductAttention(nn.Module):
 
     def __init__(self):
        super().__init__()
-        # TODO: Do you need any parameters here?
+        # No learnable parameters are needed here.
        pass
 
    def forward(
@@ -69,7 +71,115 @@ class ScaledDotProductAttention(nn.Module):
        5. Compute output: output = attention_weights @ value
        6. Return both output and attention_weights
        """
-        pass
+        # Get the key dimension for scaling
+        d_k = query.size(-1)
+
+        # Compute attention scores
+        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
 
+        # Apply the mask if provided
+        if mask is not None:
+            scores = scores.masked_fill(mask == 0, float('-inf'))
+        # Apply softmax to get the attention weights
+        attention_weights = F.softmax(scores, dim=-1)
+
+        return torch.matmul(attention_weights, value), attention_weights
+
+
+# --------------- Multi-Head Attention ---------------
 
-# TODO: After you implement ScaledDotProductAttention, we'll add MultiHeadAttention
+class MultiHeadAttention(nn.Module):
+    """
+    Multi-Head Attention mechanism.
+
+    Allows the model to jointly attend to information from different
+    representation subspaces at different positions by projecting the
+    input into per-head query, key, and value representations.
+
+    Args:
+        d_model: Dimension of the model (default: 512)
+        num_heads: Number of attention heads (default: 8)
+        dropout: Dropout probability (default: 0.1)
+    """
+
+    def __init__(self, d_model: int = 512, num_heads: int = 8, dropout: float = 0.1):
+        super().__init__()
+
+        # d_model must be divisible by num_heads,
+        # because d_k = d_model // num_heads has to be an integer
+        assert d_model % num_heads == 0
+
+        # Assume d_v always equals d_k
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.d_k = d_model // num_heads
+
+        # Four linear layers (W_Q, W_K, W_V, W_O),
+        # each of them nn.Linear(d_model, d_model)
+        self.W_Q = nn.Linear(d_model, d_model)
+        self.W_K = nn.Linear(d_model, d_model)
+        self.W_V = nn.Linear(d_model, d_model)
+        self.W_O = nn.Linear(d_model, d_model)
+        # Shared ScaledDotProductAttention instance
+        self.attention = ScaledDotProductAttention()
+        # Dropout layer
+        self.dropout = nn.Dropout(p=dropout)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        mask: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            query: (batch, seq_len, d_model)
+            key: (batch, seq_len, d_model)
+            value: (batch, seq_len, d_model)
+            mask: Optional (batch, seq_len, seq_len) or (batch, 1, seq_len, seq_len)
+
+        Returns:
+            output: (batch, seq_len, d_model)
+            attention_weights: (batch, num_heads, seq_len, seq_len)
+        """
+        batch_size = query.size(0)
+
+        # Linear projections
+        Q = self.W_Q(query)  # (batch, seq_len, d_model)
+        K = self.W_K(key)
+        V = self.W_V(value)
+
+        # Split into heads:
+        # reshape Q, K, V from (batch, seq_len, d_model) to (batch, num_heads, seq_len, d_k)
+        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
+        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
+        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
+        # Now: (batch, num_heads, seq_len, d_k)
+        # e.g. (batch=2, num_heads=8, seq_len=10, d_k=64)
+
+        # Handle mask broadcasting for multi-head attention
+        if mask is not None:
+            # If the mask is 3D (batch, seq, seq), add a head dimension
+            if mask.dim() == 3:
+                mask = mask.unsqueeze(1)  # (batch, 1, seq, seq)
+            # The mask now broadcasts across all heads: (batch, 1, seq, seq) → (batch, num_heads, seq, seq)
+
+        # Apply attention
+        output, attn_weights = self.attention(Q, K, V, mask)
+        # output: (batch, num_heads, seq_len, d_k)
+        # attn_weights: (batch, num_heads, seq_len, seq_len)
+
+        # Concatenate heads:
+        # (batch, num_heads, seq_len, d_k) → (batch, seq_len, num_heads, d_k) → (batch, seq_len, d_model)
+        output = output.transpose(1, 2).contiguous()
+        output = output.view(batch_size, -1, self.d_model)  # -1 in view means "infer this dimension"
+        # After the transpose the tensor's memory layout is non-contiguous;
+        # contiguous() reorganizes it in memory so view() can be applied
+
+        # Final linear projection
+        output = self.W_O(output)
+        # Apply dropout
+        output = self.dropout(output)
+
+        return output, attn_weights
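For orientation, here is a minimal usage sketch of the module added above (not part of the commit; it assumes the import path src.models.attention used by the tests). ScaledDotProductAttention computes softmax(QKᵀ/√d_k)·V, and MultiHeadAttention wraps it with the four learned projections:

import torch
from src.models.attention import MultiHeadAttention

mha = MultiHeadAttention(d_model=512, num_heads=8, dropout=0.1)
x = torch.randn(2, 10, 512)  # (batch, seq_len, d_model)

# Optional causal mask: position i may only attend to positions <= i
causal = torch.tril(torch.ones(10, 10, dtype=torch.bool))
mask = causal.unsqueeze(0).expand(2, -1, -1)  # (batch, seq_len, seq_len)

out, attn = mha(x, x, x, mask)
print(out.shape)   # torch.Size([2, 10, 512])
print(attn.shape)  # torch.Size([2, 8, 10, 10])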
src/models/decoder.py ADDED
File without changes
src/models/encoder.py ADDED
File without changes
src/models/heads.py ADDED
File without changes
src/models/multitask.py ADDED
File without changes
src/models/positional_encoding.py ADDED
File without changes
tests/test_models/test_attention.py CHANGED
@@ -6,8 +6,7 @@ Run with: pytest tests/test_models/test_attention.py -v
 
 import pytest
 import torch
-from src.models.attention import ScaledDotProductAttention
-
+from src.models.attention import ScaledDotProductAttention, MultiHeadAttention
 
 class TestScaledDotProductAttention:
     """Test suite for ScaledDotProductAttention."""
@@ -55,6 +54,94 @@ class TestScaledDotProductAttention:
        assert torch.allclose(weights[:, :, 3:], torch.zeros(batch_size, seq_len, 2), atol=1e-6)
 
    # TODO: Add more tests as you understand the mechanism better
+class TestMultiHeadAttention:
+    """Test suite for MultiHeadAttention."""
+
+    def test_output_shape(self):
+        """Test that output shapes are correct."""
+        d_model, num_heads = 512, 8
+        batch_size, seq_len = 2, 10
+
+        mha = MultiHeadAttention(d_model, num_heads)
+
+        Q = K = V = torch.randn(batch_size, seq_len, d_model)
+        output, attn_weights = mha(Q, K, V)
+
+        assert output.shape == (batch_size, seq_len, d_model)
+        assert attn_weights.shape == (batch_size, num_heads, seq_len, seq_len)
+
+    def test_different_qkv(self):
+        """Test with different Q, K, V (cross-attention scenario)."""
+        d_model, num_heads = 512, 8
+        batch_size = 2
+        seq_len_q, seq_len_kv = 10, 20
+
+        mha = MultiHeadAttention(d_model, num_heads)
+
+        Q = torch.randn(batch_size, seq_len_q, d_model)
+        K = torch.randn(batch_size, seq_len_kv, d_model)
+        V = torch.randn(batch_size, seq_len_kv, d_model)
+
+        output, attn_weights = mha(Q, K, V)
+
+        # Output has same length as query
+        assert output.shape == (batch_size, seq_len_q, d_model)
+        # Attention is query_len x key_len
+        assert attn_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_kv)
+
+    def test_masking(self):
+        """Test that masking works correctly."""
+        d_model, num_heads = 512, 8
+        batch_size, seq_len = 2, 5
+
+        mha = MultiHeadAttention(d_model, num_heads)
+        Q = K = V = torch.randn(batch_size, seq_len, d_model)
+
+        # Mask out last 2 positions
+        mask = torch.ones(batch_size, seq_len, seq_len, dtype=torch.bool)
+        mask[:, :, -2:] = False
+
+        _, attn_weights = mha(Q, K, V, mask)
+
+        # Last 2 positions should have near-zero attention
+        assert torch.allclose(
+            attn_weights[:, :, :, -2:],
+            torch.zeros(batch_size, num_heads, seq_len, 2),
+            atol=1e-6
+        )
+
+    def test_parameters_exist(self):
+        """Test that learnable parameters are created."""
+        mha = MultiHeadAttention(512, 8)
+
+        # Should have 4 linear layers worth of parameters
+        param_names = [name for name, _ in mha.named_parameters()]
+
+        assert any('W_Q' in name or 'q_linear' in name.lower() for name in param_names)
+        assert any('W_K' in name or 'k_linear' in name.lower() for name in param_names)
+        assert any('W_V' in name or 'v_linear' in name.lower() for name in param_names)
+        assert any('W_O' in name or 'out' in name.lower() for name in param_names)
+
+    def test_dropout_changes_output(self):
+        """Test that dropout is actually applied during training."""
+        torch.manual_seed(42)
+        mha = MultiHeadAttention(512, 8, dropout=0.5)
+        mha.train()  # Enable training mode
+
+        Q = K = V = torch.randn(2, 10, 512)
+
+        # Run twice with same input - should get different outputs due to dropout
+        output1, _ = mha(Q, K, V)
+        output2, _ = mha(Q, K, V)
+
+        assert not torch.allclose(output1, output2)
+
+        # In eval mode, should be deterministic
+        mha.eval()
+        output3, _ = mha(Q, K, V)
+        output4, _ = mha(Q, K, V)
+
+        assert torch.allclose(output3, output4)
 
 
 if __name__ == "__main__":
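The TODO above still invites more tests; one natural follow-up would be a causal-mask case. A sketch of such a test (an illustration, not part of this commit; it reuses the MultiHeadAttention import and the mask convention from test_masking):

def test_causal_mask_blocks_future():
    """Hypothetical extra test: a lower-triangular mask should zero attention to future positions."""
    d_model, num_heads = 512, 8
    batch_size, seq_len = 2, 6

    mha = MultiHeadAttention(d_model, num_heads)
    x = torch.randn(batch_size, seq_len, d_model)

    # Causal (lower-triangular) mask: query i may attend only to keys 0..i
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    mask = causal.unsqueeze(0).expand(batch_size, -1, -1)

    _, attn = mha(x, x, x, mask)

    # Entries strictly above the diagonal should be (near) zero
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    assert torch.allclose(attn[..., future], torch.zeros_like(attn[..., future]), atol=1e-6)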
tests/test_models/test_attention_visual.py ADDED
@@ -0,0 +1,53 @@
+# Create a file: tests/test_models/test_attention_visual.py
+
+import torch
+import matplotlib.pyplot as plt
+import seaborn as sns
+from src.models.attention import ScaledDotProductAttention
+
+def test_attention_visualization():
+    """Visual test to understand attention patterns."""
+    attention = ScaledDotProductAttention()
+
+    # Create a simple case: 5 tokens, each token attends most to itself
+    batch_size = 1
+    seq_len = 5
+    d_k = 64
+
+    # Create Q, K, V
+    torch.manual_seed(42)
+    Q = torch.randn(batch_size, seq_len, d_k)
+    K = torch.randn(batch_size, seq_len, d_k)
+    V = torch.eye(seq_len, d_k).unsqueeze(0)  # Identity-like
+
+    # Compute attention
+    output, weights = attention(Q, K, V)
+
+    # Plot attention weights
+    plt.figure(figsize=(8, 6))
+    sns.heatmap(
+        weights[0].detach().numpy(),
+        annot=True,
+        fmt='.2f',
+        cmap='viridis',
+        xticklabels=[f'Key {i}' for i in range(seq_len)],
+        yticklabels=[f'Query {i}' for i in range(seq_len)]
+    )
+    plt.title('Attention Weights Heatmap')
+    plt.xlabel('Keys (What we attend TO)')
+    plt.ylabel('Queries (What is attending)')
+    plt.tight_layout()
+    plt.savefig('outputs/attention_visualization.png')
+    print("✅ Saved visualization to outputs/attention_visualization.png")
+
+    # Print some analysis
+    print("\n" + "="*50)
+    print("Attention Analysis")
+    print("="*50)
+    for i in range(seq_len):
+        max_attn_idx = weights[0, i].argmax().item()
+        max_attn_val = weights[0, i, max_attn_idx].item()
+        print(f"Query {i} attends most to Key {max_attn_idx} (weight: {max_attn_val:.3f})")
+
+if __name__ == "__main__":
+    test_attention_visualization()
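One caveat: this test saves to outputs/ without creating the directory, so running it under pytest in a fresh checkout would presumably fail at plt.savefig. A small guard like the following before saving (a suggestion, not in the commit) avoids that:

import os
os.makedirs('outputs', exist_ok=True)  # ensure the target directory exists before plt.savefig(...)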
tests/test_models/test_multihead_visual.py ADDED
@@ -0,0 +1,162 @@
+# tests/test_models/test_multihead_visual.py
+
+import torch
+import matplotlib.pyplot as plt
+import seaborn as sns
+import numpy as np
+from src.models.attention import MultiHeadAttention
+
+def visualize_multihead_attention():
+    """
+    Visual test to see what different attention heads learn.
+    Creates a heatmap showing attention patterns for each head.
+    """
+    # Setup
+    torch.manual_seed(42)
+    d_model, num_heads = 512, 8
+    batch_size, seq_len = 1, 10
+
+    mha = MultiHeadAttention(d_model, num_heads, dropout=0.0)
+    mha.eval()  # No dropout for visualization
+
+    # Create input with some structure
+    # Let's make tokens attend to nearby tokens
+    X = torch.randn(batch_size, seq_len, d_model)
+
+    # Add positional bias (tokens are more similar to nearby tokens)
+    for i in range(seq_len):
+        for j in range(seq_len):
+            distance = abs(i - j)
+            X[0, i] += 0.5 * X[0, j] / (distance + 1)
+
+    # Forward pass
+    output, attn_weights = mha(X, X, X)
+
+    # attn_weights shape: (1, 8, 10, 10) = batch, heads, query_pos, key_pos
+    attn_weights = attn_weights[0].detach().numpy()  # Remove batch dim: (8, 10, 10)
+
+    # Create visualization
+    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
+    fig.suptitle('Multi-Head Attention: What Each Head Learns', fontsize=16, y=1.02)
+
+    for head_idx in range(num_heads):
+        row = head_idx // 4
+        col = head_idx % 4
+        ax = axes[row, col]
+
+        # Plot attention heatmap for this head
+        sns.heatmap(
+            attn_weights[head_idx],
+            annot=True,
+            fmt='.2f',
+            cmap='viridis',
+            cbar=True,
+            square=True,
+            ax=ax,
+            vmin=0,
+            vmax=attn_weights[head_idx].max(),
+            xticklabels=[f'K{i}' for i in range(seq_len)],
+            yticklabels=[f'Q{i}' for i in range(seq_len)]
+        )
+        ax.set_title(f'Head {head_idx}', fontweight='bold')
+        ax.set_xlabel('Keys (attend TO)')
+        ax.set_ylabel('Queries (attending FROM)')
+
+    plt.tight_layout()
+    plt.savefig('outputs/multihead_attention_visualization.png', dpi=150, bbox_inches='tight')
+    print("✅ Saved visualization to outputs/multihead_attention_visualization.png")
+
+    # Print statistics
+    print("\n" + "="*60)
+    print("Multi-Head Attention Analysis")
+    print("="*60)
+
+    for head_idx in range(num_heads):
+        head_attn = attn_weights[head_idx]
+
+        # Find dominant pattern
+        diagonal_strength = np.trace(head_attn) / seq_len
+        off_diagonal = (head_attn.sum() - np.trace(head_attn)) / (seq_len * (seq_len - 1))
+
+        print(f"\nHead {head_idx}:")
+        print(f"  Self-attention strength: {diagonal_strength:.3f}")
+        print(f"  Cross-attention strength: {off_diagonal:.3f}")
+
+        # Find which position each query attends to most
+        max_attentions = head_attn.argmax(axis=1)
+        print(f"  Attention pattern: {max_attentions.tolist()}")
+
+
+def compare_single_vs_multihead():
+    """
+    Compare single-head vs multi-head attention capacity.
+    """
+    torch.manual_seed(42)
+    seq_len, d_model = 8, 512
+
+    # Create data with two different patterns
+    # Pattern 1: Sequential (token i attends to i+1)
+    # Pattern 2: Pairwise (tokens 0-1, 2-3, 4-5, 6-7 attend to each other)
+
+    X = torch.randn(1, seq_len, d_model)
+
+    # Test with 1 head vs 8 heads
+    mha_1head = MultiHeadAttention(d_model, num_heads=1, dropout=0.0)
+    mha_8heads = MultiHeadAttention(d_model, num_heads=8, dropout=0.0)
+
+    mha_1head.eval()
+    mha_8heads.eval()
+
+    _, attn_1head = mha_1head(X, X, X)
+    _, attn_8heads = mha_8heads(X, X, X)
+
+    # Plot comparison
+    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
+
+    # Single head
+    sns.heatmap(
+        attn_1head[0, 0].detach().numpy(),
+        annot=True,
+        fmt='.2f',
+        cmap='viridis',
+        cbar=True,
+        square=True,
+        ax=axes[0]
+    )
+    axes[0].set_title('Single-Head Attention\n(Limited expressiveness)', fontweight='bold')
+    axes[0].set_xlabel('Keys')
+    axes[0].set_ylabel('Queries')
+
+    # Multi-head average
+    avg_attn = attn_8heads[0].mean(dim=0).detach().numpy()
+    sns.heatmap(
+        avg_attn,
+        annot=True,
+        fmt='.2f',
+        cmap='viridis',
+        cbar=True,
+        square=True,
+        ax=axes[1]
+    )
+    axes[1].set_title('8-Head Attention (Average)\n(Richer patterns)', fontweight='bold')
+    axes[1].set_xlabel('Keys')
+    axes[1].set_ylabel('Queries')
+
+    plt.tight_layout()
+    plt.savefig('outputs/single_vs_multihead.png', dpi=150, bbox_inches='tight')
+    print("✅ Saved comparison to outputs/single_vs_multihead.png")
+
+
+if __name__ == "__main__":
+    import os
+    os.makedirs('outputs', exist_ok=True)
+
+    print("Visualizing multi-head attention patterns...")
+    visualize_multihead_attention()
+
+    print("\nComparing single-head vs multi-head...")
+    compare_single_vs_multihead()
+
+    print("\n" + "="*60)
+    print("✅ All visualizations complete!")
+    print("="*60)