"""Prediction heads for Transformer models. This module provides task-specific output heads: - ClassificationHead: Sequence-level classification with pooling (mean/cls/max) - TokenClassificationHead: Per-token classification (NER, POS tagging) - LMHead: Language modeling logits with optional weight tying - ProjectionHead: MLP for representation learning / contrastive tasks Author: Oliver Perrin Date: 2025-10-23 """ from typing import Literal, Optional import torch import torch.nn as nn class ClassificationHead(nn.Module): """ Sequence-level classification head. Args: d_model: hidden size from encoder/decoder num_labels: number of output classes pooler: one of 'mean', 'cls', 'max' - how to pool the sequence dropout: dropout probability before final linear layer """ def __init__( self, d_model: int, num_labels: int, pooler: Literal["mean", "cls", "max"] = "mean", dropout: float = 0.1, ): super().__init__() assert pooler in ("mean", "cls", "max"), "pooler must be 'mean'|'cls'|'max'" self.pooler = pooler self.dropout = nn.Dropout(dropout) self.out_proj = nn.Linear(d_model, num_labels) def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: """ x: (batch, seq_len, d_model) mask: (batch, seq_len) - True for valid tokens, False for padding returns: (batch, num_labels) """ if self.pooler == "mean": if mask is not None: # mask is (B, S) # x is (B, S, D) # Expand mask to (B, S, 1) mask_expanded = mask.unsqueeze(-1).float() # Zero out padding x = x * mask_expanded # Sum over sequence sum_embeddings = x.sum(dim=1) # Count valid tokens sum_mask = mask_expanded.sum(dim=1) # Avoid division by zero sum_mask = torch.clamp(sum_mask, min=1e-9) pooled = sum_embeddings / sum_mask else: pooled = x.mean(dim=1) elif self.pooler == "cls": pooled = x[:, 0, :] else: # max if mask is not None: # Mask padding with -inf mask_expanded = mask.unsqueeze(-1) x = x.masked_fill(~mask_expanded, float("-inf")) pooled, _ = x.max(dim=1) pooled = self.dropout(pooled) return self.out_proj(pooled) class TokenClassificationHead(nn.Module): """ Per-token classification head. Useful for NER, POS, etc. Args: d_model: hidden size num_labels: number of per-token classes dropout: dropout probability applied before the linear layer """ def __init__(self, d_model: int, num_labels: int, dropout: float = 0.1): super().__init__() self.dropout = nn.Dropout(dropout) self.out_proj = nn.Linear(d_model, num_labels) def forward(self, x: torch.Tensor) -> torch.Tensor: """ x: (batch, seq_len, d_model) returns: (batch, seq_len, num_labels) """ x = self.dropout(x) return self.out_proj(x) class LMHead(nn.Module): """ Language modeling head: maps hidden states to logits over vocabulary. Args: d_model: hidden size vocab_size: vocabulary size tie_embedding: optional nn.Embedding instance to tie weights with """ def __init__(self, d_model: int, vocab_size: int, tie_embedding: Optional[nn.Embedding] = None): super().__init__() self.vocab_size = vocab_size self.d_model = d_model self.proj = nn.Linear(d_model, vocab_size, bias=True) if tie_embedding is not None: # Validate sizes assert tie_embedding.num_embeddings == vocab_size, ( "vocab size mismatch for weight tying" ) assert tie_embedding.embedding_dim == d_model, ( "embedding dim must match d_model for weight tying" ) # Tie weights: point the projection weight to the embedding weight Tensor # Remove the existing projection parameter in favor of the embedding weight # This keeps the same Parameter object, so updates affect both modules. 
            self.proj.weight = tie_embedding.weight

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        hidden_states: (batch, seq_len, d_model)
        returns logits: (batch, seq_len, vocab_size)
        """
        return self.proj(hidden_states)


class ProjectionHead(nn.Module):
    """
    Simple projection head for representation learning.

    Args:
        d_model: input dimension
        proj_dim: output projection dimension
        hidden_dim: intermediate dimension (optional)
        dropout: dropout probability
    """

    def __init__(
        self,
        d_model: int,
        proj_dim: int = 128,
        hidden_dim: Optional[int] = None,
        dropout: float = 0.1,
    ):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = max(d_model, proj_dim)
        self.net = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, proj_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch, d_model) or (batch, seq_len, d_model) - both supported.

        Returns:
            If input is 3D: (batch, seq_len, proj_dim)
            If input is 2D: (batch, proj_dim)
        """
        orig_dim = x.dim()
        if orig_dim == 3:
            B, T, D = x.shape
            # reshape (rather than view) tolerates non-contiguous inputs
            out = self.net(x.reshape(B * T, D))
            return out.view(B, T, -1)
        elif orig_dim == 2:
            return self.net(x)
        else:
            raise ValueError("Input must be a 2D or 3D tensor")
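

if __name__ == "__main__":
    # Minimal usage sketch / smoke test for the heads above. The shapes and
    # hyperparameters here (batch=2, seq_len=8, d_model=32, vocab_size=100,
    # label counts) are illustrative assumptions, not values required by the
    # heads themselves.
    B, S, D, V = 2, 8, 32, 100
    hidden = torch.randn(B, S, D)
    mask = torch.ones(B, S, dtype=torch.bool)
    mask[:, -2:] = False  # pretend the last two positions are padding

    cls_head = ClassificationHead(d_model=D, num_labels=4, pooler="mean")
    print(cls_head(hidden, mask).shape)  # (2, 4)

    tok_head = TokenClassificationHead(d_model=D, num_labels=9)
    print(tok_head(hidden).shape)  # (2, 8, 9)

    embedding = nn.Embedding(V, D)
    lm_head = LMHead(d_model=D, vocab_size=V, tie_embedding=embedding)
    assert lm_head.proj.weight is embedding.weight  # weights are shared
    print(lm_head(hidden).shape)  # (2, 8, 100)

    proj_head = ProjectionHead(d_model=D, proj_dim=64)
    print(proj_head(hidden).shape)        # (2, 8, 64)
    print(proj_head(hidden[:, 0]).shape)  # (2, 64)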