OliverPerrin committed on
Commit 5a20c96 · 1 Parent(s): 204fb3c

Implemented the following parts of the Transformer model: positional_encoding, feedforward, encoder, and a skeleton of the decoder, plus test cases for each class and visualizations.

src/models/attention.py CHANGED
@@ -79,11 +79,35 @@ class ScaledDotProductAttention(nn.Module):
 
         # Mask if provided
         if mask is not None:
-            scores = scores.masked_fill(mask == 0, float('-inf'))
+            # Ensure the mask is boolean and on the same device as scores
+            mask_bool = mask.to(dtype=torch.bool, device=scores.device)
+            # masked_fill expects a broadcastable mask: True means keep, False means mask out
+            scores = scores.masked_fill(~mask_bool, float("-1e9"))
 
         # Applying Softmax to get attention weights
         attention_weights = F.softmax(scores, dim=-1)
 
-        return torch.matmul(attention_weights, value), attention_weights
+        # Softmax to get attention probabilities
+        p_attn = F.softmax(scores, dim=-1)
+
+        # If a mask was provided, ensure masked positions are exactly zero (and handle all-masked rows)
+        if mask is not None:
+            # Convert the mask to the same dtype as p_attn for multiplication
+            mask_float = mask.to(dtype=p_attn.dtype, device=p_attn.device)
+            # Broadcast-multiply (zero out masked key positions)
+            p_attn = p_attn * mask_float
+            # Replace any NaNs (which can occur when a row was entirely masked before softmax) with 0.0;
+            # torch.nan_to_num is efficient and handles positive/negative inf as well
+            p_attn = torch.nan_to_num(p_attn, nan=0.0, posinf=0.0, neginf=0.0)
+
+            # Re-normalize rows that still have a non-zero sum. This is not strictly necessary
+            # if the mask is correct, but it guards against tiny numerical issues.
+            row_sums = p_attn.sum(dim=-1, keepdim=True)
+            # Avoid division by zero; only divide where row_sums > 0
+            nonzero_rows = row_sums > 0
+            p_attn = torch.where(nonzero_rows, p_attn / (row_sums + 1e-12), p_attn)
+
+        output = torch.matmul(p_attn, value)
+        return output, p_attn
 
 # --------------- Multi-Head Attention ---------------
 
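A standalone illustration (not part of this commit) of why the all-masked-row handling above matters. With a large negative fill value, softmax over a fully masked row gives a uniform row (with -inf it would give NaN); the zero-out plus nan_to_num step turns that row into exact zeros:

    import torch
    import torch.nn.functional as F

    mask = torch.tensor([[True, True, False],
                         [False, False, False]])   # second query row is fully masked
    scores = torch.randn(2, 3)

    scores = scores.masked_fill(~mask, -1e9)
    p = F.softmax(scores, dim=-1)
    p = p * mask.to(p.dtype)                        # zero out masked key positions
    p = torch.nan_to_num(p, nan=0.0, posinf=0.0, neginf=0.0)

    print(p[1])  # tensor([0., 0., 0.]) -- the fully masked row contributes nothing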
src/models/decoder.py CHANGED
@@ -0,0 +1,300 @@
+"""
+Transformer decoder skeleton (Pre-LN)
+
+Contents:
+- create_causal_mask: utility to build a causal (subsequent) mask
+- TransformerDecoderLayer: one decoder block (masked self-attn, cross-attn, FFN)
+- TransformerDecoder: embedding/pos-encoding + stack of decoder layers + generation helpers
+
+Notes / conventions:
+- Pre-LN (LayerNorm before each sublayer) is assumed for stability (consistent with the encoder).
+- MultiHeadAttention, FeedForward, PositionalEncoding are expected to live in src/models
+  (already implemented there).
+- Masks use boolean semantics: True = allowed, False = masked.
+- The decoder API supports:
+  - inputs: token ids (LongTensor, (B, T)) or embeddings ((B, T, d_model))
+  - memory: encoder outputs (B, S, d_model)
+  - mask arguments: tgt_mask (causal/padding), memory_mask (encoder padding)
+  - collect_attn: return attention maps per layer if requested
+- Generation helpers (greedy) are skeletons that can be extended to beam search or caching.
+
+TODO status keys:
+- [IMPLEMENT] : core implementation required
+- [OPTIONAL]  : useful enhancement (caching, beam search, advanced scheduling)
+"""
+
+from typing import Optional, Tuple, List, Union, Dict
+import math
+import torch
+import torch.nn as nn
+
+from .attention import MultiHeadAttention
+from .feedforward import FeedForward
+from .positional_encoding import PositionalEncoding
+
+
+def create_causal_mask(seq_len: int, device: Optional[torch.device] = None) -> torch.Tensor:
+    """
+    Create a square causal mask of shape (seq_len, seq_len).
+    True indicates allowed positions; False indicates masked (future) positions.
+
+    Returns:
+        mask: torch.BoolTensor of shape (seq_len, seq_len)
+    """
+    # Return a mask with True on and below the diagonal, False above the diagonal.
+    # torch.triu zeroes all values below the given diagonal (keeping the upper triangle);
+    # with diagonal=1 the True entries mark strictly-future positions.
+    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1)
+    # mask has True above the diagonal (positions to be masked). We want True for allowed, so invert:
+    return ~mask  # shape (seq_len, seq_len), i.e. (T, T)
+
+
+class TransformerDecoderLayer(nn.Module):
+    """
+    One decoder layer with:
+    - Masked self-attention (query/key/value = tgt)
+    - Encoder-decoder cross-attention (query = tgt, key/value = memory)
+    - Position-wise FeedForward
+    - Pre-LN + residuals + dropout
+
+    Args:
+        d_model: model hidden size
+        num_heads: number of attention heads
+        d_ff: FFN intermediate size
+        dropout: dropout for residuals / FFN
+    """
+
+    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
+        super().__init__()
+        # NOTE: instantiate internal MHA with dropout=0.0 and manage dropout at the layer level
+        self.self_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout=0.0)
+        self.cross_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout=0.0)
+        self.ffn = FeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout)
+
+        # LayerNorms (Pre-LN)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+
+        # Dropouts applied after sublayers (on sublayer outputs before the residual add)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.dropout3 = nn.Dropout(dropout)
+
+    def forward(
+        self,
+        tgt: torch.Tensor,
+        memory: torch.Tensor,
+        tgt_mask: Optional[torch.Tensor] = None,
+        memory_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+        """
+        Forward pass for one decoder layer.
+
+        Args:
+            tgt: (batch, tgt_len, d_model)
+            memory: (batch, src_len, d_model) -- encoder outputs
+            tgt_mask: optional (batch, tgt_len, tgt_len) or (batch, 1, tgt_len, tgt_len)
+            memory_mask: optional (batch, src_len, src_len) or (batch, 1, tgt_len, src_len)
+
+        Returns:
+            output: (batch, tgt_len, d_model)
+            attn_maps: dict with keys 'self' and 'cross' containing attention tensors of shapes
+                (batch, num_heads, tgt_len, tgt_len) and (batch, num_heads, tgt_len, src_len)
+        """
+        # TODO [IMPLEMENT] Self-attention (Pre-LN)
+        # x_norm = self.norm1(tgt)
+        # self_out, self_attn = self.self_attn(x_norm, x_norm, x_norm, tgt_mask)
+        # tgt = tgt + self.dropout1(self_out)
+
+        # TODO [IMPLEMENT] Cross-attention (Pre-LN)
+        # x_norm = self.norm2(tgt)
+        # cross_out, cross_attn = self.cross_attn(x_norm, memory, memory, memory_mask)
+        # tgt = tgt + self.dropout2(cross_out)
+
+        # TODO [IMPLEMENT] Feed-forward (Pre-LN)
+        # x_norm = self.norm3(tgt)
+        # ffn_out = self.ffn(x_norm)
+        # tgt = tgt + self.dropout3(ffn_out)
+
+        # TODO [RETURN] Return (tgt, {"self": self_attn, "cross": cross_attn})
+        raise NotImplementedError("TransformerDecoderLayer.forward: implement Pre-LN pipeline")
+
+
+class TransformerDecoder(nn.Module):
+    """
+    Full decoder: token embedding + positional encoding + stack of decoder layers.
+    Also supports simple greedy decoding.
+
+    Args:
+        vocab_size: for token embeddings (if using token ids)
+        d_model, num_layers, num_heads, d_ff, dropout, max_len, pad_token_id: same semantics as the encoder
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        d_model: int = 512,
+        num_layers: int = 6,
+        num_heads: int = 8,
+        d_ff: int = 2048,
+        dropout: float = 0.1,
+        max_len: int = 512,
+        pad_token_id: Optional[int] = None,
+    ):
+        super().__init__()
+        self.vocab_size = vocab_size
+        self.d_model = d_model
+        self.pad_token_id = pad_token_id
+
+        # Token embedding (used if inputs are token ids)
+        self.embedding = nn.Embedding(vocab_size, d_model)
+
+        # Positional encoding
+        self.pos_encoder = PositionalEncoding(d_model=d_model, max_len=max_len, dropout=dropout)
+
+        # Decoder layers
+        self.layers = nn.ModuleList(
+            [
+                TransformerDecoderLayer(d_model=d_model, num_heads=num_heads, d_ff=d_ff, dropout=dropout)
+                for _ in range(num_layers)
+            ]
+        )
+
+        # Final layer norm for Pre-LN stacks
+        self.final_norm = nn.LayerNorm(d_model)
+
+        # Output projection to the vocabulary (logits)
+        self.output_projection = nn.Linear(d_model, vocab_size)
+
+        # Input dropout (after pos encoding)
+        self.input_dropout = nn.Dropout(dropout)
+
+    def _build_padding_mask_from_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        """
+        Build a (batch, seq, seq) boolean mask from input ids and pad_token_id.
+        True = allowed, False = masked.
+        """
+        assert self.pad_token_id is not None, "pad_token_id must be set to build mask from ids"
+        pad_mask = (input_ids != self.pad_token_id)  # (B, S)
+        attn_mask = pad_mask.unsqueeze(1) & pad_mask.unsqueeze(2)  # (B, S, S)
+        return attn_mask
+
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        memory: torch.Tensor,
+        tgt_mask: Optional[torch.Tensor] = None,
+        memory_mask: Optional[torch.Tensor] = None,
+        collect_attn: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[Dict[str, torch.Tensor]]]]:
+        """
+        Forward pass for the decoder stack.
+
+        Args:
+            inputs: token ids (B, T) or embeddings (B, T, d_model)
+            memory: encoder outputs (B, S, d_model)
+            tgt_mask: optional mask for decoder self-attention. If None, a causal mask will be created.
+                Mask shapes: (B, T, T) or (B, 1, T, T)
+            memory_mask: optional mask over memory (B, S, S) or (B, 1, T, S)
+            collect_attn: if True, returns (logits/outputs, [per-layer attn dicts])
+
+        Returns:
+            logits: (B, T, vocab_size) or (B, T, d_model) if you prefer returning hidden states,
+            or (logits, attn_list) if collect_attn is True
+        """
+        # Inputs: if token ids, embed and scale; else assume embeddings
+        if inputs.dim() == 2:  # token ids
+            x = self.embedding(inputs) * math.sqrt(self.d_model)
+        elif inputs.dim() == 3:
+            x = inputs
+        else:
+            raise ValueError("inputs must be (B, T) token ids or (B, T, d_model) embeddings")
+
+        # Positional encoding + dropout
+        x = self.pos_encoder(x)
+        x = self.input_dropout(x)
+
+        # Build tgt_mask if not provided: combine the causal mask and padding mask if available
+        seq_len = x.size(1)
+        if tgt_mask is None:
+            # Base causal mask (T, T)
+            causal = create_causal_mask(seq_len, device=x.device)
+            # Expand to the batch dimension later if padding is present
+            if inputs.dim() == 2 and self.pad_token_id is not None:
+                padding_mask = self._build_padding_mask_from_ids(inputs)  # (B, T, T)
+                # Combine: True only where both causal and padding allow attention
+                # TODO: ensure shapes align; broadcast causal to (1, T, T) then & with padding_mask
+                raise NotImplementedError("tgt_mask construction: combine causal + padding_mask")
+            else:
+                # TODO: broadcast causal to (1, T, T) or (B, 1, T, T) depending on downstream expectations
+                raise NotImplementedError("tgt_mask construction: broadcast causal mask for batch")
+
+        # Ensure memory_mask is boolean and on the correct device if provided
+        if memory_mask is not None:
+            memory_mask = memory_mask.to(dtype=torch.bool, device=x.device)
+
+        attn_list: List[Dict[str, torch.Tensor]] = []
+
+        # Pass through layers
+        for layer in self.layers:
+            x, attn = layer(x, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
+            if collect_attn:
+                attn_list.append(attn)
+
+        x = self.final_norm(x)  # Pre-LN final normalization
+
+        logits = self.output_projection(x)  # (B, T, vocab)
+        if collect_attn:
+            return logits, attn_list
+        return logits
+
+    # ---------------------------------------------------------------------
+    # Generation / inference helpers (skeletons)
+    # ---------------------------------------------------------------------
+    def greedy_decode(
+        self,
+        memory: torch.Tensor,
+        max_len: int,
+        start_token_id: int,
+        end_token_id: Optional[int] = None,
+        device: Optional[torch.device] = None,
+    ) -> torch.LongTensor:
+        """
+        Greedy autoregressive decoding using the decoder stack.
+
+        Args:
+            memory: encoder outputs (B, S, d_model)
+            max_len: maximum target length to generate
+            start_token_id: BOS token id
+            end_token_id: optional EOS token id to stop early
+        Returns:
+            generated: (B, T_out) long tensor of token ids
+        """
+        # TODO [IMPLEMENT]:
+        # - Start with a tensor of shape (B, 1) filled with start_token_id
+        # - Repeatedly call decoder.forward in incremental mode (or full forward with a causal mask)
+        # - At each step pick the argmax over the logits and append it to the sequence
+        # - Stop if all sequences produced end_token_id or max_len is reached
+        raise NotImplementedError("greedy_decode: implement autoregressive generation loop")
+
+    # Optional: incremental step method with caching of past keys/values for speed
+    def step(
+        self,
+        last_token_ids: torch.LongTensor,
+        memory: torch.Tensor,
+        cache: Optional[Dict] = None,
+    ) -> Tuple[torch.Tensor, Dict]:
+        """
+        Single-step decoder that returns logits for the next token given last_token_ids.
+
+        Args:
+            last_token_ids: (B, 1) tokens at the current time step
+            memory: encoder outputs
+            cache: optional dict storing per-layer cached keys/values
+
+        Returns:
+            logits: (B, vocab_size)
+            new_cache: updated cache
+        """
+        # TODO [OPTIONAL]: implement fast incremental decoding, caching keys/values per layer
+        raise NotImplementedError("step: incremental decoding (optional optimization)")
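The [IMPLEMENT] blocks above are intentionally left as TODOs in this commit. For reference only, a minimal sketch of how the Pre-LN pipeline and the tgt_mask construction could be completed, assuming MultiHeadAttention(query, key, value, mask) returns (output, attention_weights) as the encoder layer uses it; this is an illustration, not part of the commit:

    # Sketch only: `layer` is a TransformerDecoderLayer instance; in the real implementation
    # this body would replace the raise NotImplementedError in forward().
    def decoder_layer_forward(layer, tgt, memory, tgt_mask=None, memory_mask=None):
        # Masked self-attention (Pre-LN): normalize, attend over tgt, residual add
        x_norm = layer.norm1(tgt)
        self_out, self_attn = layer.self_attn(x_norm, x_norm, x_norm, tgt_mask)
        tgt = tgt + layer.dropout1(self_out)

        # Cross-attention (Pre-LN): queries from tgt, keys/values from encoder memory
        x_norm = layer.norm2(tgt)
        cross_out, cross_attn = layer.cross_attn(x_norm, memory, memory, memory_mask)
        tgt = tgt + layer.dropout2(cross_out)

        # Position-wise feed-forward (Pre-LN)
        x_norm = layer.norm3(tgt)
        tgt = tgt + layer.dropout3(layer.ffn(x_norm))

        return tgt, {"self": self_attn, "cross": cross_attn}


    def build_tgt_mask(causal, padding_mask=None):
        # causal: (T, T) bool from create_causal_mask; padding_mask: optional (B, T, T) bool.
        # A position is attendable only if it is both non-future and non-pad.
        causal = causal.unsqueeze(0)  # (1, T, T), broadcasts over the batch
        return causal if padding_mask is None else (causal & padding_mask)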
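Likewise, a sketch of the greedy loop described in the greedy_decode TODO, assuming the forward/tgt_mask TODOs are completed so that decoder(inputs, memory) returns (B, T, vocab) logits and builds its own causal mask when tgt_mask is None:

    import torch

    @torch.no_grad()
    def greedy_decode_sketch(decoder, memory, max_len, start_token_id, end_token_id=None):
        device = memory.device
        batch_size = memory.size(0)
        # Start every sequence with the BOS token: shape (B, 1)
        generated = torch.full((batch_size, 1), start_token_id, dtype=torch.long, device=device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_len - 1):
            logits = decoder(generated, memory)  # (B, T, vocab); causal mask built internally
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (B, 1) greedy pick
            generated = torch.cat([generated, next_token], dim=1)
            if end_token_id is not None:
                finished |= next_token.squeeze(1) == end_token_id
                if bool(finished.all()):  # stop once every sequence has emitted EOS
                    break
        return generated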
src/models/encoder.py CHANGED
@@ -0,0 +1,203 @@
+"""
+Transformer encoder implementation (Pre-LN).
+
+Contains:
+- TransformerEncoderLayer: one encoder block (self-attention + FFN with residuals + LayerNorm)
+- TransformerEncoder: embedding + positional encoding + stack of encoder layers
+
+Design choices:
+- Pre-LN (LayerNorm before each sublayer) for stable training.
+- The FeedForward module is position-wise and does NOT include residuals or normalization.
+- MultiHeadAttention handles mask broadcasting from (B, S, S) -> (B, 1, S, S) internally.
+- The encoder accepts either token ids (LongTensor) or precomputed embeddings (FloatTensor).
+  If you pass token ids, provide vocab_size when constructing the encoder and optionally pad_token_id.
+- Optionally collect attention weights by passing collect_attn=True to forward().
+"""
+
+from typing import Optional, Tuple, List, Union
+
+import math
+import torch
+import torch.nn as nn
+
+from .attention import MultiHeadAttention
+from .feedforward import FeedForward
+from .positional_encoding import PositionalEncoding
+
+
+class TransformerEncoderLayer(nn.Module):
+    """
+    Single Transformer encoder layer (Pre-LN).
+
+    Args:
+        d_model: model hidden size
+        num_heads: number of attention heads
+        d_ff: hidden dimension of the position-wise feed-forward network
+        dropout: dropout probability applied to sublayer outputs
+    """
+
+    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
+        super().__init__()
+        # Set the MHA-internal dropout to 0.0 and use dropout1/dropout2 in this layer instead
+        self.self_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout=0.0)
+        self.ffn = FeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Forward pass for the encoder layer.
+
+        Args:
+            x: (batch, seq_len, d_model) - input embeddings / representations
+            mask: optional attention mask, shape either (batch, seq_q, seq_k) or (batch, 1, seq_q, seq_k)
+
+        Returns:
+            x: (batch, seq_len, d_model)
+            attn_weights: (batch, num_heads, seq_len, seq_len) -- the encoder stack can collect these
+        """
+        # Self-attention sublayer (Pre-LN)
+        x_norm = self.norm1(x)  # Pre-LN
+        # self_attn expects query, key, value; for the encoder they are all the same tensor
+        attn_out, attn_weights = self.self_attn(x_norm, x_norm, x_norm, mask)
+        x = x + self.dropout1(attn_out)
+
+        # Feed-forward sublayer (Pre-LN)
+        x_norm = self.norm2(x)
+        ffn_out = self.ffn(x_norm)
+        x = x + self.dropout2(ffn_out)
+
+        # Return the output and attention weights (callers may ignore the latter)
+        return x, attn_weights
+
+
+class TransformerEncoder(nn.Module):
+    """
+    Full encoder: token embedding + positional encoding + N encoder layers.
+
+    Args:
+        vocab_size: vocabulary size (ignored if you always pass embeddings)
+        d_model: model hidden size
+        num_layers: number of encoder layers to stack
+        num_heads: number of attention heads
+        d_ff: hidden dimension in the FFN
+        dropout: dropout probability (applied in positional encoding & residuals)
+        max_len: maximum sequence length for positional encoding
+        pad_token_id: optional token id for padding; if provided and the input is token ids,
+            a padding mask will be constructed automatically
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        d_model: int = 512,
+        num_layers: int = 6,
+        num_heads: int = 8,
+        d_ff: int = 2048,
+        dropout: float = 0.1,
+        max_len: int = 512,
+        pad_token_id: Optional[int] = None,
+    ):
+        super().__init__()
+        self.vocab_size = vocab_size
+        self.d_model = d_model
+        self.pad_token_id = pad_token_id
+
+        # Token embedding (only used if forward receives token ids)
+        self.embedding = nn.Embedding(vocab_size, d_model)
+
+        # Positional encoding (adds dropout internally)
+        self.pos_encoder = PositionalEncoding(d_model=d_model, max_len=max_len, dropout=dropout)
+
+        # Encoder layer stack
+        self.layers = nn.ModuleList(
+            [TransformerEncoderLayer(d_model=d_model, num_heads=num_heads, d_ff=d_ff, dropout=dropout)
+             for _ in range(num_layers)]
+        )
+
+        # Final LayerNorm for Pre-LN stacks (recommended)
+        self.final_norm = nn.LayerNorm(d_model)
+
+        # Dropout applied after embedding + positional encoding (as in the original paper)
+        self.input_dropout = nn.Dropout(dropout)
+
+    def _build_padding_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
+        """
+        Build a 3D attention mask (batch, seq, seq) from input_ids and pad_token_id.
+        True indicates valid positions; False indicates masked (pad).
+        """
+        assert self.pad_token_id is not None, "pad_token_id must be set to build padding mask from ids."
+        # pad_mask shape: (batch, seq) where True = token kept (non-pad)
+        pad_mask = (input_ids != self.pad_token_id)
+        # Convert to (batch, seq_q, seq_k) by outer-product broadcasting:
+        # we want positions that are valid as both query and key
+        attn_mask = pad_mask.unsqueeze(1) & pad_mask.unsqueeze(2)
+        # attn_mask dtype is bool
+        return attn_mask
+
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        collect_attn: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """
+        Forward through the encoder.
+
+        Args:
+            inputs: either
+                - token ids: LongTensor of shape (batch, seq)
+                - embeddings: FloatTensor of shape (batch, seq, d_model)
+            mask: optional attention mask. If None and pad_token_id is set and inputs are token ids,
+                a padding mask will be created automatically with shape (batch, seq, seq).
+                The mask should be boolean where True indicates allowed attention.
+            collect_attn: if True, returns (output, [attn_weights_per_layer]) where each entry is
+                (batch, num_heads, seq, seq)
+
+        Returns:
+            output: (batch, seq, d_model)
+            or (output, attn_list) if collect_attn is True
+        """
+        # If inputs are token ids, embed them; otherwise assume they are embeddings
+        if inputs.dim() == 2:  # token ids
+            if self.embedding is None:
+                raise ValueError("Encoder was not constructed with an embedding layer.")
+            x = self.embedding(inputs) * math.sqrt(self.d_model)
+        elif inputs.dim() == 3:  # already embeddings
+            x = inputs
+        else:
+            raise ValueError("inputs must be (batch, seq) token ids or (batch, seq, d_model) embeddings")
+
+        # Positional encoding + dropout
+        x = self.pos_encoder(x)
+        x = self.input_dropout(x)
+
+        # Build mask if needed
+        if mask is None and inputs.dim() == 2 and self.pad_token_id is not None:
+            mask = self._build_padding_mask(inputs)
+
+        # Ensure the mask is boolean and on the same device
+        if mask is not None:
+            mask = mask.to(dtype=torch.bool, device=x.device)
+
+        attn_weights_per_layer: List[torch.Tensor] = []
+
+        # Pass through each encoder layer (optionally collecting attention)
+        for layer in self.layers:
+            x, attn = layer(x, mask=mask)
+            if collect_attn:
+                attn_weights_per_layer.append(attn)
+
+        # Final normalization (Pre-LN stack)
+        x = self.final_norm(x)
+
+        if collect_attn:
+            return x, attn_weights_per_layer
+        return x
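The outer-product broadcasting used in _build_padding_mask can be illustrated standalone (not part of the commit). For a sequence with one trailing pad token, the combined (seq, seq) mask only allows attention where both the query and the key are real tokens:

    import torch

    pad_token_id = 0
    input_ids = torch.tensor([[5, 7, 0]])            # (B=1, S=3), last position is padding
    pad_mask = input_ids != pad_token_id             # (1, 3) -> [[True, True, False]]
    attn_mask = pad_mask.unsqueeze(1) & pad_mask.unsqueeze(2)  # (1, 3, 3)

    # Row and column 2 (the pad position) are entirely False:
    # [[ True,  True, False],
    #  [ True,  True, False],
    #  [False, False, False]]
    print(attn_mask[0])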
src/models/feedforward.py ADDED
@@ -0,0 +1,40 @@
+"""
+Position-wise Feed-Forward Network.
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.init as init
+from typing import Literal
+
+
+class FeedForward(nn.Module):
+    """
+    FFN(x) = max(0, xW₁ + b₁)W₂ + b₂
+
+    Or with GELU: FFN(x) = GELU(xW₁ + b₁)W₂ + b₂
+    """
+
+    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, activation: Literal["gelu", "relu"] = "gelu"):
+        super().__init__()
+        self.linear1 = nn.Linear(d_model, d_ff)   # W₁
+        self.activation = nn.GELU() if activation == 'gelu' else nn.ReLU()
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(d_ff, d_model)   # W₂
+
+        # Weight initialization
+        init.xavier_uniform_(self.linear1.weight)
+        init.zeros_(self.linear1.bias)
+        init.xavier_uniform_(self.linear2.weight)
+        init.zeros_(self.linear2.bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        x: (batch, seq_len, d_model)
+        returns: (batch, seq_len, d_model)
+        """
+        x = self.linear1(x)     # (batch, seq_len, d_ff)
+        x = self.activation(x)  # non-linearity
+        x = self.dropout(x)     # dropout
+        x = self.linear2(x)     # (batch, seq_len, d_model)
+        return x
+
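A quick check (not part of the commit) that the module matches the docstring formula FFN(x) = GELU(xW₁ + b₁)W₂ + b₂, using the module's own parameters with dropout disabled and eval mode so the comparison is exact:

    import torch
    import torch.nn.functional as F
    from src.models.feedforward import FeedForward

    torch.manual_seed(0)
    ffn = FeedForward(d_model=8, d_ff=16, dropout=0.0).eval()
    x = torch.randn(2, 5, 8)

    # Manual GELU(x W1^T + b1) W2^T + b2 using the same weights
    manual = F.linear(F.gelu(F.linear(x, ffn.linear1.weight, ffn.linear1.bias)),
                      ffn.linear2.weight, ffn.linear2.bias)

    assert torch.allclose(ffn(x), manual, atol=1e-6)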
src/models/positional_encoding.py CHANGED
@@ -0,0 +1,79 @@
+# src/models/positional_encoding.py
+
+"""
+Positional Encoding for Transformer models.
+
+Injects information about the position of tokens in a sequence, since
+self-attention has no inherent notion of token order.
+"""
+
+import math
+
+import torch
+import torch.nn as nn
+
+
+class PositionalEncoding(nn.Module):
+    """
+    Implements the sinusoidal positional encoding from "Attention Is All You Need".
+
+    Formula:
+        PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
+        PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
+
+    Where:
+        pos: position in the sequence (0 to max_len-1)
+        i: dimension index (0 to d_model/2)
+
+    Args:
+        d_model: Dimension of the model embeddings
+        max_len: Maximum sequence length to pre-compute
+        dropout: Dropout probability to apply after adding the positional encoding
+
+    Shape:
+        Input: (batch, seq_len, d_model)
+        Output: (batch, seq_len, d_model)
+
+    Example:
+        >>> pos_enc = PositionalEncoding(d_model=512, max_len=5000)
+        >>> x = torch.randn(32, 100, 512)  # (batch, seq, d_model)
+        >>> output = pos_enc(x)
+        >>> output.shape
+        torch.Size([32, 100, 512])
+    """
+
+    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
+        super().__init__()
+        self.dropout = nn.Dropout(p=dropout)
+
+        # Positions [0, 1, 2, ..., max_len-1] as a column vector
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        # Division term 1 / 10000^(2i/d_model), computed via exp/log for numerical stability
+        div_term = torch.exp(
+            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
+        )
+        pe = torch.zeros(max_len, d_model)
+        pe[:, 0::2] = torch.sin(position * div_term)  # sin on even indices
+        pe[:, 1::2] = torch.cos(position * div_term)  # cos on odd indices
+        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
+        # Register as a buffer (not a parameter, but part of state_dict)
+        self.register_buffer("pe", pe)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Add positional encoding to input embeddings.
+
+        Args:
+            x: Input embeddings (batch, seq_len, d_model)
+
+        Returns:
+            x with positional encoding added (batch, seq_len, d_model)
+        """
+        # self.pe contains pre-computed encodings for all positions;
+        # add the first seq_len positions to x, then apply dropout
+        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
+        return self.dropout(x)
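A small check (not part of the commit) that the exp/log formulation of div_term reproduces the textbook formula PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and its cosine counterpart:

    import math
    import torch
    from src.models.positional_encoding import PositionalEncoding

    d_model, pos, i = 64, 7, 3  # arbitrary small values for the spot check
    pe = PositionalEncoding(d_model=d_model, max_len=16, dropout=0.0).pe[0]

    expected_even = math.sin(pos / 10000 ** (2 * i / d_model))
    expected_odd = math.cos(pos / 10000 ** (2 * i / d_model))

    assert torch.isclose(pe[pos, 2 * i], torch.tensor(expected_even), atol=1e-6)
    assert torch.isclose(pe[pos, 2 * i + 1], torch.tensor(expected_odd), atol=1e-6)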
tests/test_models/__init__.py ADDED
File without changes
tests/test_models/test_encoder.py ADDED
@@ -0,0 +1,176 @@
+import math
+import torch
+import pytest
+from src.models.encoder import TransformerEncoder
+
+
+def test_encoder_token_ids_and_padding_mask_and_grad():
+    """
+    Test using token ids as input, automatic padding mask creation when pad_token_id is provided,
+    output shape, and that gradients flow through the model.
+    """
+    torch.manual_seed(0)
+    vocab_size = 50
+    pad_token_id = 0
+    d_model = 64
+    num_layers = 3
+    num_heads = 8
+    d_ff = 128
+    batch_size = 2
+    seq_len = 12
+
+    encoder = TransformerEncoder(
+        vocab_size=vocab_size,
+        d_model=d_model,
+        num_layers=num_layers,
+        num_heads=num_heads,
+        d_ff=d_ff,
+        dropout=0.1,
+        max_len=seq_len,
+        pad_token_id=pad_token_id,
+    )
+
+    # create inputs with some padding at the end
+    input_ids = torch.randint(1, vocab_size, (batch_size, seq_len), dtype=torch.long)
+    input_ids[0, -3:] = pad_token_id  # first sample has its last 3 tokens as padding
+    input_ids[1, -1:] = pad_token_id  # second sample has its last token as padding
+
+    # Forward pass (token ids)
+    out = encoder(input_ids)  # default collect_attn=False
+    assert out.shape == (batch_size, seq_len, d_model)
+
+    # Check that gradients flow
+    loss = out.sum()
+    loss.backward()
+    grads = [p.grad for p in encoder.parameters() if p.requires_grad]
+    assert any(g is not None for g in grads), "No gradients found on any parameter"
+
+
+def test_encoder_embeddings_input_and_collect_attn():
+    """
+    Test passing pre-computed embeddings to the encoder, collecting attention weights,
+    and verify the shapes of the attention maps per layer.
+    """
+    torch.manual_seed(1)
+    vocab_size = 100  # not used in this test
+    d_model = 48
+    num_layers = 4
+    num_heads = 6
+    d_ff = 128
+    batch_size = 1
+    seq_len = 10
+
+    encoder = TransformerEncoder(
+        vocab_size=vocab_size,
+        d_model=d_model,
+        num_layers=num_layers,
+        num_heads=num_heads,
+        d_ff=d_ff,
+        dropout=0.0,
+        max_len=seq_len,
+        pad_token_id=None,
+    )
+
+    # Create random embeddings directly
+    embeddings = torch.randn(batch_size, seq_len, d_model)
+
+    out, attn_list = encoder(embeddings, mask=None, collect_attn=True)
+    assert out.shape == (batch_size, seq_len, d_model)
+    assert isinstance(attn_list, list)
+    assert len(attn_list) == num_layers
+
+    # Each attention weight tensor should have shape (batch, num_heads, seq, seq)
+    for attn in attn_list:
+        assert attn.shape == (batch_size, num_heads, seq_len, seq_len)
+
+
+def test_mask_accepts_3d_and_4d_and_broadcasts():
+    """
+    Test that a provided 3D mask (batch, seq, seq) and an equivalent 4D mask
+    (batch, 1, seq, seq) produce outputs of the same shape and do not error.
+    """
+    torch.manual_seed(2)
+    vocab_size = 40
+    d_model = 32
+    num_layers = 2
+    num_heads = 4
+    d_ff = 64
+    batch_size = 2
+    seq_len = 7
+
+    encoder = TransformerEncoder(
+        vocab_size=vocab_size,
+        d_model=d_model,
+        num_layers=num_layers,
+        num_heads=num_heads,
+        d_ff=d_ff,
+        dropout=0.0,
+        max_len=seq_len,
+        pad_token_id=None,
+    )
+
+    # Create dummy embeddings
+    embeddings = torch.randn(batch_size, seq_len, d_model)
+
+    # 3D mask: True indicates allowed attention
+    mask3 = torch.ones(batch_size, seq_len, seq_len, dtype=torch.bool)
+    mask3[:, :, -2:] = False  # mask out the last two keys
+
+    # 4D mask equivalent
+    mask4 = mask3.unsqueeze(1)  # (B, 1, S, S)
+
+    out3 = encoder(embeddings, mask=mask3)
+    out4 = encoder(embeddings, mask=mask4)
+
+    assert out3.shape == (batch_size, seq_len, d_model)
+    assert out4.shape == (batch_size, seq_len, d_model)
+    # Outputs should be finite and not NaN
+    assert torch.isfinite(out3).all()
+    assert torch.isfinite(out4).all()
+
+
+def test_train_eval_determinism_and_dropout_effect():
+    """
+    Validate that in train mode with dropout enabled, repeated forwards differ,
+    and that in eval mode they are equal (deterministic).
+    """
+    torch.manual_seed(3)
+    vocab_size = 60
+    pad_token_id = 0
+    d_model = 64
+    num_layers = 2
+    num_heads = 8
+    d_ff = 128
+    batch_size = 2
+    seq_len = 9
+
+    encoder = TransformerEncoder(
+        vocab_size=vocab_size,
+        d_model=d_model,
+        num_layers=num_layers,
+        num_heads=num_heads,
+        d_ff=d_ff,
+        dropout=0.4,
+        max_len=seq_len,
+        pad_token_id=pad_token_id,
+    )
+
+    # token ids with occasional padding
+    input_ids = torch.randint(1, vocab_size, (batch_size, seq_len), dtype=torch.long)
+    input_ids[0, -2:] = pad_token_id
+
+    # Training mode: randomness due to dropout -> outputs should (very likely) differ
+    encoder.train()
+    out1 = encoder(input_ids)
+    out2 = encoder(input_ids)
+    assert not torch.allclose(out1, out2), "Outputs identical in train mode despite dropout"
+
+    # Eval mode: deterministic
+    encoder.eval()
+    out3 = encoder(input_ids)
+    out4 = encoder(input_ids)
+    assert torch.allclose(out3, out4), "Outputs differ in eval mode"
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-q"])
tests/test_models/test_encoder_layer.py ADDED
@@ -0,0 +1,78 @@
+import torch
+import pytest
+from src.models.encoder import TransformerEncoderLayer
+
+
+def test_output_shape_and_grad():
+    """
+    The encoder layer should preserve the input shape (batch, seq_len, d_model)
+    and gradients should flow to the parameters.
+    """
+    d_model, num_heads, d_ff = 64, 8, 256
+    batch_size, seq_len = 2, 10
+
+    layer = TransformerEncoderLayer(d_model=d_model, num_heads=num_heads, d_ff=d_ff, dropout=0.0)
+    x = torch.randn(batch_size, seq_len, d_model, requires_grad=True)
+
+    # The layer returns (output, attention_weights) and accepts mask=None by default
+    out, _ = layer(x)
+    assert out.shape == (batch_size, seq_len, d_model)
+
+    # simple backward to ensure gradients propagate
+    loss = out.sum()
+    loss.backward()
+
+    grads = [p.grad for p in layer.parameters() if p.requires_grad]
+    assert any(g is not None for g in grads), "No gradients found on any parameter"
+
+
+def test_dropout_behavior_train_vs_eval():
+    """
+    With dropout > 0, the outputs should differ between two forward calls in train mode
+    and be identical in eval mode.
+    """
+    torch.manual_seed(0)
+    d_model, num_heads, d_ff = 64, 8, 256
+    batch_size, seq_len = 2, 10
+
+    layer = TransformerEncoderLayer(d_model=d_model, num_heads=num_heads, d_ff=d_ff, dropout=0.5)
+    x = torch.randn(batch_size, seq_len, d_model)
+
+    layer.train()
+    out1, _ = layer(x)
+    out2, _ = layer(x)
+    # Training mode with dropout: outputs usually differ
+    assert not torch.allclose(out1, out2), "Outputs identical in train mode despite dropout"
+
+    layer.eval()
+    out3, _ = layer(x)
+    out4, _ = layer(x)
+    # Eval mode is deterministic: outputs should be identical
+    assert torch.allclose(out3, out4), "Outputs differ in eval mode"
+
+
+def test_mask_broadcasting_accepts_3d_and_4d_mask():
+    """
+    The encoder layer should accept a 3D mask (batch, seq_q, seq_k) and a 4D mask
+    (batch, 1, seq_q, seq_k) and handle broadcasting across heads without error.
+    """
+    d_model, num_heads, d_ff = 64, 8, 256
+    batch_size, seq_len = 2, 7
+
+    layer = TransformerEncoderLayer(d_model=d_model, num_heads=num_heads, d_ff=d_ff, dropout=0.0)
+    x = torch.randn(batch_size, seq_len, d_model)
+
+    # 3D mask: (batch, seq, seq)
+    mask3 = torch.ones(batch_size, seq_len, seq_len, dtype=torch.bool)
+    mask3[:, :, -2:] = False  # mask out the last two key positions
+    out3, _ = layer(x, mask=mask3)  # should not raise
+    assert out3.shape == (batch_size, seq_len, d_model)
+
+    # 4D mask: (batch, 1, seq, seq), already including the head dim for broadcasting
+    mask4 = mask3.unsqueeze(1)
+    out4, _ = layer(x, mask=mask4)
+    assert out4.shape == (batch_size, seq_len, d_model)
+
+
+if __name__ == "__main__":
+    # Run tests interactively if needed
+    pytest.main([__file__, "-q"])
tests/test_models/test_feedforward.py ADDED
@@ -0,0 +1,57 @@
+import torch
+import pytest
+from src.models.feedforward import FeedForward
+
+
+class TestFeedForward:
+    def test_output_shape(self):
+        d_model, d_ff = 512, 2048
+        batch_size, seq_len = 2, 10
+
+        ffn = FeedForward(d_model=d_model, d_ff=d_ff, dropout=0.0)
+        x = torch.randn(batch_size, seq_len, d_model)
+        out = ffn(x)
+
+        assert out.shape == (batch_size, seq_len, d_model)
+
+    def test_dropout_changes_output(self):
+        torch.manual_seed(0)
+        d_model, d_ff = 128, 512
+        x = torch.randn(2, 8, d_model)
+
+        ffn = FeedForward(d_model=d_model, d_ff=d_ff, dropout=0.5)
+        ffn.train()
+        out1 = ffn(x)
+        out2 = ffn(x)
+        # With dropout in train mode, outputs should (most likely) differ
+        assert not torch.allclose(out1, out2)
+
+        ffn.eval()
+        out3 = ffn(x)
+        out4 = ffn(x)
+        # In eval mode (no dropout), outputs should be identical for the same input
+        assert torch.allclose(out3, out4)
+
+    def test_parameter_count_and_grads(self):
+        d_model, d_ff = 64, 256
+        ffn = FeedForward(d_model=d_model, d_ff=d_ff, dropout=0.0)
+
+        # Parameter existence
+        param_names = [name for name, _ in ffn.named_parameters()]
+        assert any('linear1' in name for name in param_names)
+        assert any('linear2' in name for name in param_names)
+
+        # Parameter shapes
+        shapes = {name: p.shape for name, p in ffn.named_parameters()}
+        assert shapes.get('linear1.weight') == (d_ff, d_model)
+        assert shapes.get('linear2.weight') == (d_model, d_ff)
+        assert shapes.get('linear1.bias') == (d_ff,)
+        assert shapes.get('linear2.bias') == (d_model,)
+
+        # ensure gradients flow
+        x = torch.randn(3, 5, d_model)
+        out = ffn(x)
+        loss = out.sum()
+        loss.backward()
+        for _, p in ffn.named_parameters():
+            assert p.grad is not None
tests/test_models/test_positional_encoding.py ADDED
@@ -0,0 +1,103 @@
+# tests/test_models/test_positional_encoding.py
+
+"""
+Tests for positional encoding.
+"""
+
+import os
+
+import pytest
+import torch
+import matplotlib.pyplot as plt
+import seaborn as sns
+from src.models.positional_encoding import PositionalEncoding
+
+
+class TestPositionalEncoding:
+    """Test suite for PositionalEncoding."""
+
+    def test_output_shape(self):
+        """Test that the output shape matches the input shape."""
+        d_model, max_len = 512, 5000
+        batch_size, seq_len = 2, 100
+
+        pos_enc = PositionalEncoding(d_model, max_len, dropout=0.0)
+        x = torch.randn(batch_size, seq_len, d_model)
+
+        output = pos_enc(x)
+        assert output.shape == (batch_size, seq_len, d_model)
+
+    def test_different_sequence_lengths(self):
+        """Test with various sequence lengths."""
+        pos_enc = PositionalEncoding(d_model=256, max_len=1000, dropout=0.0)
+
+        for seq_len in [10, 50, 100, 500]:
+            x = torch.randn(1, seq_len, 256)
+            output = pos_enc(x)
+            assert output.shape == (1, seq_len, 256)
+
+    def test_dropout_changes_output(self):
+        """Test that dropout is applied during training."""
+        torch.manual_seed(42)
+        pos_enc = PositionalEncoding(d_model=128, dropout=0.5)
+        pos_enc.train()
+
+        x = torch.randn(2, 10, 128)
+
+        output1 = pos_enc(x)
+        output2 = pos_enc(x)
+
+        # Should be different due to dropout
+        assert not torch.allclose(output1, output2)
+
+        # In eval mode, should be deterministic
+        pos_enc.eval()
+        output3 = pos_enc(x)
+        output4 = pos_enc(x)
+        assert torch.allclose(output3, output4)
+
+    def test_encoding_properties(self):
+        """Test mathematical properties of the encoding."""
+        pos_enc = PositionalEncoding(d_model=128, max_len=100, dropout=0.0)
+
+        # Get the raw encoding (without dropout)
+        pe = pos_enc.pe[0]  # Remove the batch dimension
+
+        # Every entry should lie in [-1, 1] (sin/cos range)
+        assert (pe >= -1).all() and (pe <= 1).all()
+
+        # Different positions should have different encodings
+        assert not torch.allclose(pe[0], pe[1])
+        assert not torch.allclose(pe[0], pe[50])
+
+
+def test_visualize_positional_encoding():
+    """
+    Visualize the positional encoding pattern.
+    Creates a heatmap showing the encoding values.
+    """
+    pos_enc = PositionalEncoding(d_model=128, max_len=100, dropout=0.0)
+
+    # Get the encoding matrix
+    pe = pos_enc.pe.squeeze(0).numpy()  # (max_len, d_model)
+
+    # Plot the first 50 positions and 64 dimensions
+    os.makedirs('outputs', exist_ok=True)  # ensure the output directory exists when run under pytest
+    plt.figure(figsize=(12, 8))
+    sns.heatmap(
+        pe[:50, :64].T,
+        cmap='RdBu_r',
+        center=0,
+        xticklabels=5,
+        yticklabels=8,
+        cbar_kws={'label': 'Encoding Value'}
+    )
+    plt.xlabel('Position in Sequence')
+    plt.ylabel('Embedding Dimension')
+    plt.title('Positional Encoding Pattern\n(Notice the wave patterns with different frequencies)')
+    plt.tight_layout()
+    plt.savefig('outputs/positional_encoding_heatmap.png', dpi=150)
+    print("✅ Saved to outputs/positional_encoding_heatmap.png")
+
+
+if __name__ == "__main__":
+    test_visualize_positional_encoding()