import math

import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import MaskedLMOutput

from .configuration_nicheformer import NicheformerConfig


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding; expects batch-first input of shape (batch, seq, dim)."""

    def __init__(self, d_model: int, max_seq_len: int):
        super().__init__()
        encoding = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        encoding[:, 0::2] = torch.sin(position * div_term)
        encoding[:, 1::2] = torch.cos(position * div_term)
        encoding = encoding.unsqueeze(0)
        self.register_buffer('encoding', encoding, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encoding to the input tensor."""
        return x + self.encoding[:, :x.size(1)]


class NicheformerPreTrainedModel(PreTrainedModel):
    """Base class for Nicheformer models."""

    config_class = NicheformerConfig
    base_model_prefix = "nicheformer"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)


class NicheformerModel(NicheformerPreTrainedModel):
    def __init__(self, config: NicheformerConfig):
        super().__init__(config)

        # Core transformer components
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.dim_model,
            nhead=config.nheads,
            dim_feedforward=config.dim_feedforward,
            batch_first=config.batch_first,
            dropout=config.dropout,
            layer_norm_eps=1e-12,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer=self.encoder_layer,
            num_layers=config.nlayers,
            enable_nested_tensor=False,
        )

        # Embedding layers (vocabulary plus reserved special-token slots)
        self.embeddings = nn.Embedding(
            num_embeddings=config.n_tokens + 5,
            embedding_dim=config.dim_model,
            padding_idx=1,
        )

        if config.learnable_pe:
            self.positional_embedding = nn.Embedding(
                num_embeddings=config.context_length,
                embedding_dim=config.dim_model,
            )
            self.dropout = nn.Dropout(p=config.dropout)
            self.register_buffer('pos', torch.arange(0, config.context_length, dtype=torch.long))
        else:
            self.positional_embedding = PositionalEncoding(
                d_model=config.dim_model,
                max_seq_len=config.context_length,
            )

        # Initialize weights
        self.post_init()

    def forward(self, input_ids, attention_mask=None):
        token_embedding = self.embeddings(input_ids)

        if self.config.learnable_pe:
            # Slice the position ids to the actual sequence length so inputs
            # shorter than context_length still broadcast correctly.
            positions = self.pos[:input_ids.size(1)].to(token_embedding.device)
            pos_embedding = self.positional_embedding(positions)
            embeddings = self.dropout(token_embedding + pos_embedding)
        else:
            embeddings = self.positional_embedding(token_embedding)

        # Convert attention_mask (1 = attend, 0 = padding) to the boolean
        # src_key_padding_mask convention, where True marks positions to ignore.
        padding_mask = ~attention_mask.bool() if attention_mask is not None else None

        transformer_output = self.encoder(
            embeddings,
            src_key_padding_mask=padding_mask,
            is_causal=False,
        )

        return transformer_output

    def get_embeddings(self, input_ids, attention_mask=None, layer: int = -1, with_context: bool = False) -> torch.Tensor:
        """Get embeddings from the model.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            layer: Which transformer layer to extract embeddings from (-1 means last layer)
            with_context: Whether to include context tokens in the embeddings

        Returns:
            torch.Tensor: Embeddings tensor
        """
        # Get token embeddings and positional encodings
        token_embedding = self.embeddings(input_ids)

        if self.config.learnable_pe:
            positions = self.pos[:input_ids.size(1)].to(token_embedding.device)
            pos_embedding = self.positional_embedding(positions)
            embeddings = self.dropout(token_embedding + pos_embedding)
        else:
            embeddings = self.positional_embedding(token_embedding)

        # Resolve negative layer indices (-1 means last layer)
        if layer < 0:
            layer = self.config.nlayers + layer

        # Convert attention_mask to boolean and invert it for src_key_padding_mask
        padding_mask = ~attention_mask.bool() if attention_mask is not None else None

        # Run each encoder layer up to and including the requested one
        for i in range(layer + 1):
            embeddings = self.encoder.layers[i](
                embeddings,
                src_key_padding_mask=padding_mask,
                is_causal=False,
            )

        # Drop the context tokens (first 3 positions) unless requested
        if not with_context:
            embeddings = embeddings[:, 3:, :]

        # Mean pooling over the sequence dimension
        embeddings = embeddings.mean(dim=1)

        return embeddings
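

# --- Illustrative usage sketch (not part of the model implementation) -------
# The function below shows how the encoder above is meant to be driven: build a
# config, embed a batch of token ids, and pool per-sequence embeddings. It
# assumes NicheformerConfig accepts the field names read above (dim_model,
# nheads, dim_feedforward, nlayers, context_length, n_tokens, learnable_pe,
# batch_first, dropout) as keyword arguments; adapt to the real signature.
def _example_encode():
    config = NicheformerConfig(
        dim_model=128,
        nheads=4,
        dim_feedforward=256,
        nlayers=2,
        context_length=32,
        n_tokens=1000,
        learnable_pe=True,
        batch_first=True,
        dropout=0.1,
    )
    model = NicheformerModel(config)
    model.eval()

    # Dummy batch: token ids outside the special range, full attention.
    input_ids = torch.randint(3, config.n_tokens, (2, config.context_length))
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        hidden = model(input_ids, attention_mask=attention_mask)            # (2, 32, 128)
        pooled = model.get_embeddings(input_ids, attention_mask, layer=-1)  # (2, 128)
    return hidden.shape, pooled.shape
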
class NicheformerForMaskedLM(NicheformerPreTrainedModel):
    def __init__(self, config: NicheformerConfig):
        super().__init__(config)

        self.nicheformer = NicheformerModel(config)
        self.classifier_head = nn.Linear(config.dim_model, config.n_tokens, bias=False)
        self.classifier_head.bias = nn.Parameter(torch.zeros(config.n_tokens))

        # Initialize weights
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        return_dict=None,
        apply_masking=False,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Apply masking if requested (typically during training)
        if apply_masking:
            batch = {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
            }
            masked_batch = complete_masking(batch, self.config.masking_p, self.config.n_tokens)
            input_ids = masked_batch['masked_indices']
            labels = masked_batch['input_ids']  # Original tokens become labels
            mask = masked_batch['mask']

            # Only compute loss on masked tokens and ensure labels are long
            labels = torch.where(mask, labels, torch.tensor(-100, device=labels.device)).long()

        transformer_output = self.nicheformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        prediction_scores = self.classifier_head(transformer_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.n_tokens),
                labels.view(-1),
            )

        if not return_dict:
            output = (prediction_scores,) + (transformer_output,)
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=transformer_output,
        )

    def get_embeddings(self, input_ids, attention_mask=None, layer: int = -1, with_context: bool = False) -> torch.Tensor:
        """Get embeddings from the model.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            layer: Which transformer layer to extract embeddings from (-1 means last layer)
            with_context: Whether to include context tokens in the embeddings

        Returns:
            torch.Tensor: Embeddings tensor
        """
        return self.nicheformer.get_embeddings(
            input_ids=input_ids,
            attention_mask=attention_mask,
            layer=layer,
            with_context=with_context,
        )
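

# --- Illustrative training-step sketch (not part of the model implementation) ---
# A minimal masked-LM step using apply_masking=True, which masks the inputs via
# complete_masking (defined below) and computes the loss only on masked
# positions. As above, the NicheformerConfig keyword arguments (including
# masking_p) are assumptions; adapt to the real config signature.
def _example_mlm_step():
    config = NicheformerConfig(
        dim_model=128,
        nheads=4,
        dim_feedforward=256,
        nlayers=2,
        context_length=32,
        n_tokens=1000,
        learnable_pe=True,
        batch_first=True,
        dropout=0.1,
        masking_p=0.15,
    )
    model = NicheformerForMaskedLM(config)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    input_ids = torch.randint(3, config.n_tokens, (2, config.context_length))
    attention_mask = torch.ones_like(input_ids)

    # The model masks its own inputs and derives labels internally.
    output = model(input_ids=input_ids, attention_mask=attention_mask, apply_masking=True)
    output.loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return output.loss.item()
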
# Special-token ids assumed by complete_masking. PAD_TOKEN follows the
# padding_idx=1 used by the embedding layer above; the MASK and CLS ids are
# assumptions and must match the tokenizer's vocabulary (random replacement
# tokens below start at id 3, so ids 0-2 are treated as special).
MASK_TOKEN = 0
PAD_TOKEN = 1
CLS_TOKEN = 2


def complete_masking(batch, masking_p, n_tokens):
    """Apply masking to an input batch for masked language modeling.

    Args:
        batch (dict): Input batch containing 'input_ids' and 'attention_mask'
        masking_p (float): Probability of masking a token
        n_tokens (int): Total number of tokens in the vocabulary

    Returns:
        dict: Batch with masked indices and masking information
    """
    device = batch['input_ids'].device
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']

    # Boolean mask: True for tokens selected for masking; padding and CLS
    # tokens are never selected
    prob = torch.rand(input_ids.shape, device=device)
    mask = (prob < masking_p) & (input_ids != PAD_TOKEN) & (input_ids != CLS_TOKEN)

    # For the selected tokens:
    # - 80% are replaced with the MASK token
    # - 10% are replaced with a random token
    # - 10% are kept unchanged
    masked_indices = input_ids.clone()

    # Number of tokens selected for masking
    num_tokens_to_mask = mask.sum().item()

    # Decide which kind of corruption each selected token receives: half of the
    # ~20% that are not MASK-replaced (i.e. ~10% overall) get a random token.
    mask_mask = torch.rand(num_tokens_to_mask, device=device) < 0.8
    random_mask = (torch.rand(num_tokens_to_mask, device=device) < 0.5) & ~mask_mask

    # Work on the gathered values and scatter them back once; chained advanced
    # indexing (masked_indices[mask][random_mask] = ...) would write to a copy.
    masked_values = masked_indices[mask]

    # Replace with the MASK token (80% of selected tokens)
    masked_values = torch.where(
        mask_mask,
        torch.tensor(MASK_TOKEN, device=device, dtype=torch.long),
        masked_values,
    )

    # Replace with random tokens (10% of selected tokens); start from 3 to
    # avoid the special tokens
    random_tokens = torch.randint(
        3, n_tokens,
        (int(random_mask.sum().item()),),
        device=device, dtype=torch.long,
    )
    masked_values[random_mask] = random_tokens

    # The remaining 10% keep their original ids
    masked_indices[mask] = masked_values

    return {
        'masked_indices': masked_indices,
        'attention_mask': attention_mask,
        'mask': mask,
        'input_ids': input_ids,
    }
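

# --- Illustrative sanity check for complete_masking (not part of the library) ---
# On a large random batch containing no special tokens, roughly masking_p of
# the positions should be selected, and about 80% of the selected positions
# should carry MASK_TOKEN afterwards. The batch shape and vocabulary size used
# here are arbitrary.
def _example_masking_statistics(masking_p: float = 0.15, n_tokens: int = 1000):
    input_ids = torch.randint(3, n_tokens, (64, 256))
    batch = {
        'input_ids': input_ids,
        'attention_mask': torch.ones_like(input_ids),
    }
    masked = complete_masking(batch, masking_p, n_tokens)

    selected = masked['mask']
    frac_selected = selected.float().mean().item()  # approximately masking_p
    frac_mask_token = (masked['masked_indices'][selected] == MASK_TOKEN).float().mean().item()  # ~0.8
    return frac_selected, frac_mask_token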