from typing import Any, Dict, List, Optional, Tuple, cast

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.auto import AutoModel
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
from transformers.models.m2m_100.modeling_m2m_100 import M2M100Encoder

from .configuration_nllbllm2vec import NLLBLLM2VecConfig
from .modeling_llama_encoder import LlamaEncoderModel


class NLLBLLM2Vec(PreTrainedModel):
    """
    NLLBLLM2Vec model combining the NLLB and Llama encoders.

    Args:
        config (Optional[NLLBLLM2VecConfig]): Configuration object.
        nllb_encoder (Optional[M2M100Encoder]): Pre-initialized NLLB encoder.
        llm2vec (Optional[LlamaEncoderModel]): Pre-initialized Llama encoder.
        *inputs: Additional positional arguments.
        **kwargs: Additional keyword arguments.
    """

    config_class = NLLBLLM2VecConfig
    model_type = "nllb-llm2vec"

    def __init__(
        self,
        config: Optional[NLLBLLM2VecConfig] = None,
        nllb_encoder: Optional[M2M100Encoder] = None,
        llm2vec: Optional[LlamaEncoderModel] = None,
        *inputs,
        **kwargs,
    ):
        # Ensure that either `config` is provided or both encoders are given
        if config is None and (nllb_encoder is None or llm2vec is None):
            raise ValueError(
                "Either `config` must be provided, or both `nllb_encoder` and `llm2vec` must be specified."
            )

        if config is not None:
            super().__init__(config, *inputs, **kwargs)
            self.nllb_encoder = nllb_encoder or M2M100Encoder(config.nllb_config)
            self.llm2vec = llm2vec or LlamaEncoderModel(config.llm2vec_config)
            self.config = config
        else:
            # Both encoders are provided; derive the config from them
            self.nllb_encoder = cast(M2M100Encoder, nllb_encoder)
            self.llm2vec = cast(LlamaEncoderModel, llm2vec)
            self.config = NLLBLLM2VecConfig(
                nllb_config=self.nllb_encoder.config,  # type: ignore
                llm2vec_config=self.llm2vec.config,  # type: ignore
            )
            super().__init__(self.config, *inputs, **kwargs)

        self.up_proj = nn.Linear(
            self.nllb_encoder.config.d_model,
            self.llm2vec.config.hidden_size,
            bias=False,
        )
        # Additional initialization logic can go here

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        indices: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> BaseModelOutputWithPooling:
        """
        Forward pass of the model.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            attention_mask (torch.Tensor): Attention mask.
            indices (Optional[Tuple[torch.Tensor, torch.Tensor]]):
                Precomputed input indices and offsets for mean pooling.

        Returns:
            BaseModelOutputWithPooling: Model outputs with last hidden state and pooled output.
        """
        # Compute input indices and offsets if not provided
        if indices is None:
            seq_indices, seq_offsets = self._get_input_offsets(attention_mask)
        else:
            seq_indices, seq_offsets = indices

        with torch.inference_mode():
            nllb_outputs = self.nllb_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            nllb_last_hidden_state = nllb_outputs.last_hidden_state
            nllb_last_hidden_state = self.up_proj(nllb_last_hidden_state)

        if self.training:
            # Tensors created under inference mode cannot be used in autograd,
            # so detach and clone them before feeding the trainable encoder.
            nllb_last_hidden_state = nllb_last_hidden_state.detach().clone()

        outputs = self.llm2vec(
            inputs_embeds=nllb_last_hidden_state,
            attention_mask=attention_mask,
        )
        pooler_output = self._mean_embedding(
            hidden_states=outputs.last_hidden_state,
            input_indices=seq_indices,
            offsets=seq_offsets,
        )
        return BaseModelOutputWithPooling(
            last_hidden_state=outputs.last_hidden_state,
            pooler_output=pooler_output,
        )
    @property
    def tokenizer(self):
        """
        Get the tokenizer associated with the model.

        Returns:
            PreTrainedTokenizer: The tokenizer instance.
        """
        if not hasattr(self, "_tokenizer"):
            from transformers import AutoTokenizer

            self._tokenizer = AutoTokenizer.from_pretrained(
                "facebook/nllb-200-distilled-600M", padding_side="right"
            )
        return self._tokenizer

    def encode(
        self,
        inputs: List[str],
        src_lang: str = "eng_Latn",
        tokenize_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """
        Encode input texts into embeddings.

        Args:
            inputs (List[str]): List of input texts.
            src_lang (str): Source language code.
            tokenize_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments
                for the tokenizer. Defaults to::

                    tokenize_kwargs = {
                        "padding": True,
                        "truncation": True,
                        "max_length": 512,
                        "return_tensors": "pt",
                    }

        Returns:
            torch.Tensor: Mean-pooled sequence embeddings of the inputs.
        """
        if tokenize_kwargs is None:
            tokenize_kwargs = {
                "padding": True,
                "truncation": True,
                "max_length": 512,
                "return_tensors": "pt",
            }
        tokenizer = self.tokenizer
        tokenizer.src_lang = src_lang
        device = next(self.parameters()).device
        batch = tokenizer(inputs, **tokenize_kwargs).to(device)
        device_type = device.type  # e.g., 'cuda' or 'cpu'
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            return self(**batch).pooler_output

    @staticmethod
    def _get_input_offsets(
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute indices and offsets for mean pooling using `embedding_bag`.

        Args:
            attention_mask (torch.Tensor): Attention mask of shape (batch_size, seq_len).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - input_indices: Indices of non-padded tokens in the flattened input.
                - offsets: Offsets indicating the start index of each sequence in the flattened input.
        """
        # Find the indices of non-padded tokens in the flattened hidden states
        input_indices = attention_mask.view(-1).nonzero(as_tuple=False).squeeze()

        # Compute the offsets: for each sequence, where it starts in the flattened input
        non_padded_lengths = attention_mask.sum(dim=1)  # Count non-padded tokens per sequence
        offsets = torch.cat(
            [
                torch.tensor([0], device=attention_mask.device),
                non_padded_lengths.cumsum(dim=0)[:-1],
            ]
        )
        return input_indices, offsets
""" # Flatten hidden_states to 2D: shape (batch_size * seq_len, embedding_dim) batch_size, seq_len, embed_dim = hidden_states.shape token_embeds = hidden_states.view(-1, embed_dim) # Use embedding_bag with mode 'mean' and appropriate indices return F.embedding_bag( input=input_indices, # Indices of non-padded tokens in flattened form weight=token_embeds, # The flattened hidden states as embedding matrix offsets=offsets, # Offsets specifying start of each sequence mode="mean", # Aggregation mode ) AutoModel.register(NLLBLLM2VecConfig, NLLBLLM2Vec) def repl(): cfg = NLLBLLM2VecConfig() model = NLLBLLM2Vec(cfg) from peft.mapping import get_peft_model from peft.tuners.lora.config import LoraConfig lora_config = LoraConfig( r=16, lora_alpha=32, lora_dropout=0.0, bias="none", task_type="FEATURE_EXTRACTION", target_modules=[ "llm2vec.layers.0.self_attn.q_proj", "llm2vec.layers.0.self_attn.k_proj", "llm2vec.layers.0.self_attn.v_proj", "llm2vec.layers.0.self_attn.o_proj", "llm2vec.layers.0.mlp.gate_proj", "llm2vec.layers.0.mlp.up_proj", "llm2vec.layers.0.mlp.down_proj", "llm2vec.layers.1.self_attn.q_proj", "llm2vec.layers.1.self_attn.k_proj", "llm2vec.layers.1.self_attn.v_proj", "llm2vec.layers.1.self_attn.o_proj", "llm2vec.layers.1.mlp.gate_proj", "llm2vec.layers.1.mlp.up_proj", "llm2vec.layers.1.mlp.down_proj", "llm2vec.layers.2.self_attn.q_proj", "llm2vec.layers.2.self_attn.k_proj", "llm2vec.layers.2.self_attn.v_proj", "llm2vec.layers.2.self_attn.o_proj", "llm2vec.layers.2.mlp.gate_proj", "llm2vec.layers.2.mlp.up_proj", "llm2vec.layers.2.mlp.down_proj", "llm2vec.layers.3.self_attn.q_proj", "llm2vec.layers.3.self_attn.k_proj", "llm2vec.layers.3.self_attn.v_proj", "llm2vec.layers.3.self_attn.o_proj", "llm2vec.layers.3.mlp.gate_proj", "llm2vec.layers.3.mlp.up_proj", "llm2vec.layers.3.mlp.down_proj", "llm2vec.layers.4.self_attn.q_proj", "llm2vec.layers.4.self_attn.k_proj", "llm2vec.layers.4.self_attn.v_proj", "llm2vec.layers.4.self_attn.o_proj", "llm2vec.layers.4.mlp.gate_proj", "llm2vec.layers.4.mlp.up_proj", "llm2vec.layers.4.mlp.down_proj", "llm2vec.layers.5.self_attn.q_proj", "llm2vec.layers.5.self_attn.k_proj", "llm2vec.layers.5.self_attn.v_proj", "llm2vec.layers.5.self_attn.o_proj", "llm2vec.layers.5.mlp.gate_proj", "llm2vec.layers.5.mlp.up_proj", "llm2vec.layers.5.mlp.down_proj", "llm2vec.layers.6.self_attn.q_proj", "llm2vec.layers.6.self_attn.k_proj", "llm2vec.layers.6.self_attn.v_proj", "llm2vec.layers.6.self_attn.o_proj", "llm2vec.layers.6.mlp.gate_proj", "llm2vec.layers.6.mlp.up_proj", "llm2vec.layers.6.mlp.down_proj", "llm2vec.layers.7.self_attn.q_proj", "llm2vec.layers.7.self_attn.k_proj", "llm2vec.layers.7.self_attn.v_proj", "llm2vec.layers.7.self_attn.o_proj", "llm2vec.layers.7.mlp.gate_proj", "llm2vec.layers.7.mlp.up_proj", "llm2vec.layers.7.mlp.down_proj", "llm2vec.layers.8.self_attn.q_proj", "llm2vec.layers.8.self_attn.k_proj", "llm2vec.layers.8.self_attn.v_proj", "llm2vec.layers.8.self_attn.o_proj", "llm2vec.layers.8.mlp.gate_proj", "llm2vec.layers.8.mlp.up_proj", "llm2vec.layers.8.mlp.down_proj", "llm2vec.layers.9.self_attn.q_proj", "llm2vec.layers.9.self_attn.k_proj", "llm2vec.layers.9.self_attn.v_proj", "llm2vec.layers.9.self_attn.o_proj", "llm2vec.layers.9.mlp.gate_proj", "llm2vec.layers.9.mlp.up_proj", "llm2vec.layers.9.mlp.down_proj", "llm2vec.layers.10.self_attn.q_proj", "llm2vec.layers.10.self_attn.k_proj", "llm2vec.layers.10.self_attn.v_proj", "llm2vec.layers.10.self_attn.o_proj", "llm2vec.layers.10.mlp.gate_proj", "llm2vec.layers.10.mlp.up_proj", 
"llm2vec.layers.10.mlp.down_proj", "llm2vec.layers.11.self_attn.q_proj", "llm2vec.layers.11.self_attn.k_proj", "llm2vec.layers.11.self_attn.v_proj", "llm2vec.layers.11.self_attn.o_proj", "llm2vec.layers.11.mlp.gate_proj", "llm2vec.layers.11.mlp.up_proj", "llm2vec.layers.11.mlp.down_proj", "llm2vec.layers.12.self_attn.q_proj", "llm2vec.layers.12.self_attn.k_proj", "llm2vec.layers.12.self_attn.v_proj", "llm2vec.layers.12.self_attn.o_proj", "llm2vec.layers.12.mlp.gate_proj", "llm2vec.layers.12.mlp.up_proj", "llm2vec.layers.12.mlp.down_proj", "llm2vec.layers.13.self_attn.q_proj", "llm2vec.layers.13.self_attn.k_proj", "llm2vec.layers.13.self_attn.v_proj", "llm2vec.layers.13.self_attn.o_proj", "llm2vec.layers.13.mlp.gate_proj", "llm2vec.layers.13.mlp.up_proj", "llm2vec.layers.13.mlp.down_proj", "llm2vec.layers.14.self_attn.q_proj", "llm2vec.layers.14.self_attn.k_proj", "llm2vec.layers.14.self_attn.v_proj", "llm2vec.layers.14.self_attn.o_proj", "llm2vec.layers.14.mlp.gate_proj", "llm2vec.layers.14.mlp.up_proj", "llm2vec.layers.14.mlp.down_proj", "llm2vec.layers.15.self_attn.q_proj", "llm2vec.layers.15.self_attn.k_proj", "llm2vec.layers.15.self_attn.v_proj", "llm2vec.layers.15.self_attn.o_proj", "llm2vec.layers.15.mlp.gate_proj", "llm2vec.layers.15.mlp.up_proj", "llm2vec.layers.15.mlp.down_proj", "llm2vec.layers.16.self_attn.q_proj", "llm2vec.layers.16.self_attn.k_proj", "llm2vec.layers.16.self_attn.v_proj", "llm2vec.layers.16.self_attn.o_proj", "llm2vec.layers.16.mlp.gate_proj", "llm2vec.layers.16.mlp.up_proj", "llm2vec.layers.16.mlp.down_proj", "llm2vec.layers.17.self_attn.q_proj", "llm2vec.layers.17.self_attn.k_proj", "llm2vec.layers.17.self_attn.v_proj", "llm2vec.layers.17.self_attn.o_proj", "llm2vec.layers.17.mlp.gate_proj", "llm2vec.layers.17.mlp.up_proj", "llm2vec.layers.17.mlp.down_proj", "llm2vec.layers.18.self_attn.q_proj", "llm2vec.layers.18.self_attn.k_proj", "llm2vec.layers.18.self_attn.v_proj", "llm2vec.layers.18.self_attn.o_proj", "llm2vec.layers.18.mlp.gate_proj", "llm2vec.layers.18.mlp.up_proj", "llm2vec.layers.18.mlp.down_proj", "llm2vec.layers.19.self_attn.q_proj", "llm2vec.layers.19.self_attn.k_proj", "llm2vec.layers.19.self_attn.v_proj", "llm2vec.layers.19.self_attn.o_proj", "llm2vec.layers.19.mlp.gate_proj", "llm2vec.layers.19.mlp.up_proj", "llm2vec.layers.19.mlp.down_proj", "llm2vec.layers.20.self_attn.q_proj", "llm2vec.layers.20.self_attn.k_proj", "llm2vec.layers.20.self_attn.v_proj", "llm2vec.layers.20.self_attn.o_proj", "llm2vec.layers.20.mlp.gate_proj", "llm2vec.layers.20.mlp.up_proj", "llm2vec.layers.20.mlp.down_proj", "llm2vec.layers.21.self_attn.q_proj", "llm2vec.layers.21.self_attn.k_proj", "llm2vec.layers.21.self_attn.v_proj", "llm2vec.layers.21.self_attn.o_proj", "llm2vec.layers.21.mlp.gate_proj", "llm2vec.layers.21.mlp.up_proj", "llm2vec.layers.21.mlp.down_proj", "llm2vec.layers.22.self_attn.q_proj", "llm2vec.layers.22.self_attn.k_proj", "llm2vec.layers.22.self_attn.v_proj", "llm2vec.layers.22.self_attn.o_proj", "llm2vec.layers.22.mlp.gate_proj", "llm2vec.layers.22.mlp.up_proj", "llm2vec.layers.22.mlp.down_proj", "llm2vec.layers.23.self_attn.q_proj", "llm2vec.layers.23.self_attn.k_proj", "llm2vec.layers.23.self_attn.v_proj", "llm2vec.layers.23.self_attn.o_proj", "llm2vec.layers.23.mlp.gate_proj", "llm2vec.layers.23.mlp.up_proj", "llm2vec.layers.23.mlp.down_proj", "llm2vec.layers.24.self_attn.q_proj", "llm2vec.layers.24.self_attn.k_proj", "llm2vec.layers.24.self_attn.v_proj", "llm2vec.layers.24.self_attn.o_proj", "llm2vec.layers.24.mlp.gate_proj", 
"llm2vec.layers.24.mlp.up_proj", "llm2vec.layers.24.mlp.down_proj", "llm2vec.layers.25.self_attn.q_proj", "llm2vec.layers.25.self_attn.k_proj", "llm2vec.layers.25.self_attn.v_proj", "llm2vec.layers.25.self_attn.o_proj", "llm2vec.layers.25.mlp.gate_proj", "llm2vec.layers.25.mlp.up_proj", "llm2vec.layers.25.mlp.down_proj", "llm2vec.layers.26.self_attn.q_proj", "llm2vec.layers.26.self_attn.k_proj", "llm2vec.layers.26.self_attn.v_proj", "llm2vec.layers.26.self_attn.o_proj", "llm2vec.layers.26.mlp.gate_proj", "llm2vec.layers.26.mlp.up_proj", "llm2vec.layers.26.mlp.down_proj", "llm2vec.layers.27.self_attn.q_proj", "llm2vec.layers.27.self_attn.k_proj", "llm2vec.layers.27.self_attn.v_proj", "llm2vec.layers.27.self_attn.o_proj", "llm2vec.layers.27.mlp.gate_proj", "llm2vec.layers.27.mlp.up_proj", "llm2vec.layers.27.mlp.down_proj", "llm2vec.layers.28.self_attn.q_proj", "llm2vec.layers.28.self_attn.k_proj", "llm2vec.layers.28.self_attn.v_proj", "llm2vec.layers.28.self_attn.o_proj", "llm2vec.layers.28.mlp.gate_proj", "llm2vec.layers.28.mlp.up_proj", "llm2vec.layers.28.mlp.down_proj", "llm2vec.layers.29.self_attn.q_proj", "llm2vec.layers.29.self_attn.k_proj", "llm2vec.layers.29.self_attn.v_proj", "llm2vec.layers.29.self_attn.o_proj", "llm2vec.layers.29.mlp.gate_proj", "llm2vec.layers.29.mlp.up_proj", "llm2vec.layers.29.mlp.down_proj", "llm2vec.layers.30.self_attn.q_proj", "llm2vec.layers.30.self_attn.k_proj", "llm2vec.layers.30.self_attn.v_proj", "llm2vec.layers.30.self_attn.o_proj", "llm2vec.layers.30.mlp.gate_proj", "llm2vec.layers.30.mlp.up_proj", "llm2vec.layers.30.mlp.down_proj", "llm2vec.layers.31.self_attn.q_proj", "llm2vec.layers.31.self_attn.k_proj", "llm2vec.layers.31.self_attn.v_proj", "llm2vec.layers.31.self_attn.o_proj", "llm2vec.layers.31.mlp.gate_proj", "llm2vec.layers.31.mlp.up_proj", "llm2vec.layers.31.mlp.down_proj", ], ) peft_model = get_peft_model(model, lora_config) peft_model.save_pretrained("../nllb-llm2vec-saved") import json with open("./model.safetensors.index.json", "r") as f: print(json.load(f))