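"""Visual sanity tests for attention and positional-encoding components.

Each test renders a seaborn heatmap and saves it to the outputs/
directory. The tests run under pytest or directly via the __main__
block; the Agg backend is forced so no display is required.
"""
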
import os

import matplotlib
import torch

matplotlib.use("Agg")  # use non-interactive backend
import matplotlib.pyplot as plt
import seaborn as sns

from src.models.attention import MultiHeadAttention, ScaledDotProductAttention
from src.models.positional_encoding import PositionalEncoding

OUTPUTS_DIR = "outputs"


def ensure_outputs_dir():
    os.makedirs(OUTPUTS_DIR, exist_ok=True)


def test_attention_visualization():
    """Visual test to understand attention patterns."""
    ensure_outputs_dir()
    attention = ScaledDotProductAttention()
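    # Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
    # The sqrt(d_k) scaling keeps the dot-product logits at roughly unit
    # variance so the softmax does not saturate as d_k grows.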

    # Simple case: 5 tokens with random Q/K and an identity-like V, so
    # the resulting attention weights are easy to inspect
    batch_size = 1
    seq_len = 5
    d_k = 64

    # Create Q, K, V
    torch.manual_seed(42)
    Q = torch.randn(batch_size, seq_len, d_k)
    K = torch.randn(batch_size, seq_len, d_k)
    V = torch.eye(seq_len, d_k).unsqueeze(0)  # Identity-like
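    # Because V's rows are one-hot, output = weights @ V copies each
    # query's attention weights into the output's first seq_len columns,
    # making the weights directly readable from the output as well.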

    # Compute attention
    output, weights = attention(Q, K, V, return_attn_weights=True)

    # Plot attention weights
    plt.figure(figsize=(8, 6))
    sns.heatmap(
        weights[0].detach().numpy(),
        annot=True,
        fmt=".2f",
        cmap="viridis",
        xticklabels=[f"Key {i}" for i in range(seq_len)],
        yticklabels=[f"Query {i}" for i in range(seq_len)],
    )
    plt.title("Attention Weights Heatmap")
    plt.xlabel("Keys (What we attend TO)")
    plt.ylabel("Queries (What is attending)")
    plt.tight_layout()
    save_path = os.path.join(OUTPUTS_DIR, "attention_visualization.png")
    plt.savefig(save_path)
    print(f"✅ Saved visualization to {save_path}")
    plt.close()


def test_visualize_multihead_attention():
    """
    Visual test to see what different attention heads learn.
    Creates a heatmap showing attention patterns for each head.
    """
    ensure_outputs_dir()
    # Setup
    torch.manual_seed(42)
    d_model, num_heads = 512, 8
    batch_size, seq_len = 1, 10

    mha = MultiHeadAttention(d_model, num_heads, dropout=0.0)
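    # Standard multi-head split: the 512-dim model width is partitioned
    # into 8 subspaces of d_model / num_heads = 64 dims, attention runs
    # independently in each, and the head outputs are concatenated.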
    mha.eval()  # No dropout for visualization

    # Create input with some structure
    # Let's make tokens attend to nearby tokens
    X = torch.randn(batch_size, seq_len, d_model)

    # Add positional bias (tokens become more similar to nearby tokens).
    # Mix in neighbours of the original embeddings; cloning first avoids
    # reading rows that earlier iterations have already modified.
    X_base = X[0].clone()
    for i in range(seq_len):
        for j in range(seq_len):
            distance = abs(i - j)
            X[0, i] += 0.5 * X_base[j] / (distance + 1)

    # Forward pass
    output, attn_weights = mha(X, X, X, return_attn_weights=True)

    # attn_weights shape: (1, 8, 10, 10) = batch, heads, query_pos, key_pos
    attn_weights = attn_weights[0].detach().numpy()  # Remove batch dim: (8, 10, 10)

    # Create visualization
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    fig.suptitle("Multi-Head Attention: What Each Head Learns", fontsize=16, y=1.02)

    for head_idx in range(num_heads):
        row = head_idx // 4
        col = head_idx % 4
        ax = axes[row, col]

        # Plot attention heatmap for this head
        sns.heatmap(
            attn_weights[head_idx],
            annot=True,
            fmt=".2f",
            cmap="viridis",
            cbar=True,
            square=True,
            ax=ax,
            vmin=0,
            vmax=attn_weights[head_idx].max(),
            xticklabels=[f"K{i}" for i in range(seq_len)],
            yticklabels=[f"Q{i}" for i in range(seq_len)],
        )
        ax.set_title(f"Head {head_idx}", fontweight="bold")
        ax.set_xlabel("Keys (attend TO)")
        ax.set_ylabel("Queries (attending FROM)")

    plt.tight_layout()
    save_path = os.path.join(OUTPUTS_DIR, "multihead_attention_visualization.png")
    plt.savefig(save_path, dpi=150, bbox_inches="tight")
    print(f"✅ Saved visualization to {save_path}")
    plt.close()


def test_compare_single_vs_multihead():
    """
    Compare single-head vs multi-head attention capacity.
    """
    ensure_outputs_dir()
    torch.manual_seed(42)
    seq_len, d_model = 8, 512

    X = torch.randn(1, seq_len, d_model)

    # Test with 1 head vs 8 heads
    mha_1head = MultiHeadAttention(d_model, num_heads=1, dropout=0.0)
    mha_8heads = MultiHeadAttention(d_model, num_heads=8, dropout=0.0)
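    # Note: with the usual Q/K/V/output projections, both modules have
    # the same parameter count; heads partition d_model rather than add
    # width, so any difference in the plots reflects pattern diversity.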

    mha_1head.eval()
    mha_8heads.eval()

    _, attn_1head = mha_1head(X, X, X, return_attn_weights=True)
    _, attn_8heads = mha_8heads(X, X, X, return_attn_weights=True)

    # Plot comparison
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Single head
    sns.heatmap(
        attn_1head[0, 0].detach().numpy(),
        annot=True,
        fmt=".2f",
        cmap="viridis",
        cbar=True,
        square=True,
        ax=axes[0],
    )
    axes[0].set_title("Single-Head Attention\n(Limited expressiveness)", fontweight="bold")
    axes[0].set_xlabel("Keys")
    axes[0].set_ylabel("Queries")

    # Multi-head average
    avg_attn = attn_8heads[0].mean(dim=0).detach().numpy()
    sns.heatmap(avg_attn, annot=True, fmt=".2f", cmap="viridis", cbar=True, square=True, ax=axes[1])
    axes[1].set_title("8-Head Attention (Average)\n(Richer patterns)", fontweight="bold")
    axes[1].set_xlabel("Keys")
    axes[1].set_ylabel("Queries")

    plt.tight_layout()
    save_path = os.path.join(OUTPUTS_DIR, "single_vs_multihead.png")
    plt.savefig(save_path, dpi=150, bbox_inches="tight")
    print(f"✅ Saved comparison to {save_path}")
    plt.close()


def test_visualize_positional_encoding():
    """
    Visualize the positional encoding pattern.
    Creates heatmap showing encoding values.
    """
    ensure_outputs_dir()
    pos_enc = PositionalEncoding(d_model=128, max_len=100, dropout=0.0)
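    # Standard sinusoidal encoding (Vaswani et al., 2017):
    #   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    #   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
    # Low dimensions oscillate quickly and high dimensions slowly, which
    # produces the banded wave pattern in the heatmap below.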

    # Get encoding matrix
    pe = pos_enc.pe.squeeze(0).numpy()  # (max_len, d_model)

    # Plot first 50 positions and 64 dimensions
    plt.figure(figsize=(12, 8))
    sns.heatmap(
        pe[:50, :64].T,
        cmap="RdBu_r",
        center=0,
        xticklabels=5,
        yticklabels=8,
        cbar_kws={"label": "Encoding Value"},
    )
    plt.xlabel("Position in Sequence")
    plt.ylabel("Embedding Dimension")
    plt.title("Positional Encoding Pattern\n(Notice the wave patterns with different frequencies)")
    plt.tight_layout()
    save_path = os.path.join(OUTPUTS_DIR, "positional_encoding_heatmap.png")
    plt.savefig(save_path, dpi=150)
    print(f"✅ Saved to {save_path}")
    plt.close()


if __name__ == "__main__":
    test_attention_visualization()
    test_visualize_multihead_attention()
    test_compare_single_vs_multihead()
    test_visualize_positional_encoding()