import pytest
import torch

from src.models.decoder import (
    TransformerDecoder,
    TransformerDecoderLayer,
    create_causal_mask,
)


def test_create_causal_mask_properties():
    mask = create_causal_mask(5)
    assert mask.shape == (5, 5)
    # diagonal and below should be True
    for i in range(5):
        for j in range(5):
            if j <= i:
                assert mask[i, j].item() is True
            else:
                assert mask[i, j].item() is False
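

# A compact vectorized sketch of the same property: create_causal_mask should
# match torch.tril on a ones matrix. This assumes the mask is (or can be cast to)
# a boolean tensor, as the element-wise test above already implies.
def test_create_causal_mask_matches_tril():
    size = 8
    mask = create_causal_mask(size)
    expected = torch.tril(torch.ones(size, size)).bool()
    assert torch.equal(mask.bool(), expected)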


def test_decoder_layer_shapes_and_grad():
    torch.manual_seed(0)
    d_model, num_heads, d_ff = 32, 4, 64
    batch_size, tgt_len, src_len = 2, 6, 7

    layer = TransformerDecoderLayer(d_model=d_model, num_heads=num_heads, d_ff=d_ff, dropout=0.0)
    tgt = torch.randn(batch_size, tgt_len, d_model, requires_grad=True)
    memory = torch.randn(batch_size, src_len, d_model)

    # No masks
    out, attn = layer(tgt, memory, tgt_mask=None, memory_mask=None, collect_attn=True)
    assert out.shape == (batch_size, tgt_len, d_model)
    assert isinstance(attn, dict)
    assert "self" in attn and "cross" in attn
    assert attn["self"].shape == (batch_size, num_heads, tgt_len, tgt_len)
    assert attn["cross"].shape == (batch_size, num_heads, tgt_len, src_len)

    # Backprop works
    loss = out.sum()
    loss.backward()
    grads = [p.grad for p in layer.parameters() if p.requires_grad]
    assert any(g is not None for g in grads)


def test_decoder_layer_causal_mask_blocks_future():
    torch.manual_seed(1)
    d_model, num_heads, d_ff = 48, 6, 128
    batch_size, tgt_len, src_len = 1, 5, 5

    layer = TransformerDecoderLayer(d_model=d_model, num_heads=num_heads, d_ff=d_ff, dropout=0.0)
    # seeded random embeddings so the attention patterns are reproducible across runs
    tgt = torch.randn(batch_size, tgt_len, d_model)
    memory = torch.randn(batch_size, src_len, d_model)

    causal = create_causal_mask(tgt_len, device=tgt.device)  # (T, T)
    tgt_mask = causal.unsqueeze(0)  # (1, T, T); the layer is expected to broadcast this across heads

    out, attn = layer(tgt, memory, tgt_mask=tgt_mask, memory_mask=None, collect_attn=True)
    self_attn = attn["self"].detach()
    # Ensure upper triangle of attention weights is zero (no future attention)
    # For each head and query i, keys j>i should be zero
    B, H, Tq, Tk = self_attn.shape
    for i in range(Tq):
        for j in range(i + 1, Tk):
            assert torch.allclose(self_attn[:, :, i, j], torch.zeros(B, H)), (
                f"Found nonzero attention to future position {j} from query {i}"
            )
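

# Vectorized restatement of the future-masking check above, kept as a sketch:
# every strictly upper-triangular entry of the self-attention map should be
# (numerically) zero when the causal mask is applied. Same layer config as above.
def test_decoder_layer_causal_mask_blocks_future_vectorized():
    torch.manual_seed(1)
    layer = TransformerDecoderLayer(d_model=48, num_heads=6, d_ff=128, dropout=0.0)
    tgt = torch.randn(1, 5, 48)
    memory = torch.randn(1, 5, 48)
    tgt_mask = create_causal_mask(5, device=tgt.device).unsqueeze(0)  # (1, T, T)

    _, attn = layer(tgt, memory, tgt_mask=tgt_mask, memory_mask=None, collect_attn=True)
    self_attn = attn["self"].detach()  # (B, H, T, T)

    # Boolean mask selecting every strictly-future (j > i) key position in one shot.
    T = self_attn.shape[-1]
    future = torch.triu(torch.ones(T, T), diagonal=1).bool()
    selected = self_attn[:, :, future]  # (B, H, num_future_pairs)
    assert torch.allclose(selected, torch.zeros_like(selected))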


def test_decoder_stack_and_greedy_decode_shapes():
    torch.manual_seed(2)
    vocab_size = 30
    d_model = 32
    num_layers = 2
    num_heads = 4
    d_ff = 128
    batch_size = 2
    src_len = 7
    max_tgt = 6

    decoder = TransformerDecoder(
        vocab_size=vocab_size,
        d_model=d_model,
        num_layers=num_layers,
        num_heads=num_heads,
        d_ff=d_ff,
        dropout=0.0,
        max_len=max_tgt,
        pad_token_id=0,
    )

    # Random memory from encoder
    memory = torch.randn(batch_size, src_len, d_model)

    # Greedy decode: output should have shape (B, L) with L <= max_tgt
    generated = decoder.greedy_decode(memory, max_len=max_tgt, start_token_id=1, end_token_id=None)
    assert generated.shape[0] == batch_size
    assert generated.shape[1] <= max_tgt
    assert (generated[:, 0] == 1).all()  # starts with start token

    # Also test forward with embeddings and collect_attn
    embeddings = torch.randn(batch_size, max_tgt, d_model)
    logits, attn_list = decoder(embeddings, memory, collect_attn=True)
    assert logits.shape == (batch_size, max_tgt, vocab_size)
    assert isinstance(attn_list, list)
    assert len(attn_list) == num_layers
    for attn in attn_list:
        assert "self" in attn and "cross" in attn


def test_decoder_train_eval_dropout_behavior():
    torch.manual_seed(3)
    vocab_size = 40
    d_model = 32
    num_layers = 2
    num_heads = 4
    d_ff = 128
    batch_size = 2
    src_len = 6
    tgt_len = 5

    decoder = TransformerDecoder(
        vocab_size=vocab_size,
        d_model=d_model,
        num_layers=num_layers,
        num_heads=num_heads,
        d_ff=d_ff,
        dropout=0.4,
        max_len=tgt_len,
        pad_token_id=0,
    )

    # token ids in [1, vocab_size), with one position forced to the pad id (0)
    input_ids = torch.randint(1, vocab_size, (batch_size, tgt_len), dtype=torch.long)
    input_ids[0, -1] = 0

    memory = torch.randn(batch_size, src_len, d_model)

    decoder.train()
    out1 = decoder(input_ids, memory)
    out2 = decoder(input_ids, memory)
    # With dropout in train mode, outputs should usually differ
    assert not torch.allclose(out1, out2)

    decoder.eval()
    out3 = decoder(input_ids, memory)
    out4 = decoder(input_ids, memory)
    assert torch.allclose(out3, out4)
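

# Sketch of a weight round-trip check: a second decoder built with identical
# hyperparameters and loaded from the first one's state_dict should be
# functionally identical in eval mode. Relies only on standard nn.Module
# state_dict semantics.
def test_decoder_state_dict_roundtrip():
    torch.manual_seed(4)
    kwargs = dict(
        vocab_size=40,
        d_model=32,
        num_layers=2,
        num_heads=4,
        d_ff=128,
        dropout=0.0,
        max_len=5,
        pad_token_id=0,
    )
    dec_a = TransformerDecoder(**kwargs)
    dec_b = TransformerDecoder(**kwargs)
    dec_b.load_state_dict(dec_a.state_dict())
    dec_a.eval()
    dec_b.eval()

    input_ids = torch.randint(1, 40, (2, 5), dtype=torch.long)
    memory = torch.randn(2, 6, 32)
    assert torch.allclose(dec_a(input_ids, memory), dec_b(input_ids, memory))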


if __name__ == "__main__":
    pytest.main([__file__, "-q"])