Pipeline tag: Sentence Similarity
Libraries: Transformers, Safetensors
Languages: multilingual
Tags: nllb-llm2vec, feature-extraction, text-embedding, embeddings, information-retrieval, beir, text-classification, language-model, text-clustering, text-semantic-similarity, text-evaluation, text-reranking, mteb, custom_code
Datasets: natural_questions, ms_marco, fever, hotpot_qa
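The modeling code below defines the `NLLBLLM2Vec` wrapper (an NLLB encoder projected into a LLaMA-based LLM2Vec encoder) and registers it with `AutoModel`. A minimal usage sketch follows; the repository id is a placeholder, and `trust_remote_code=True` is assumed because the model ships custom modeling code (see the `custom_code` tag above).

```python
import torch
from transformers import AutoModel

# Placeholder repo id; substitute the actual Hub repository for this model.
model = AutoModel.from_pretrained(
    "your-org/nllb-llm2vec",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
model.eval()

# `encode` tokenizes with the NLLB tokenizer and returns mean-pooled embeddings.
embeddings = model.encode(
    ["A quick example sentence.", "Another example sentence."],
    src_lang="eng_Latn",  # FLORES-200 language code used by the NLLB tokenizer
)
print(embeddings.shape)  # (batch_size, hidden_size)
```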
from typing import Any, Dict, List, Optional, Tuple, cast

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.auto import AutoModel
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
from transformers.models.m2m_100.modeling_m2m_100 import M2M100Encoder

from .configuration_nllbllm2vec import NLLBLLM2VecConfig
from .modeling_llama_encoder import LlamaEncoderModel

class NLLBLLM2Vec(PreTrainedModel):
    """
    NLLBLLM2Vec model combining the NLLB encoder with a LLaMA-based LLM2Vec encoder.

    Args:
        config (Optional[NLLBLLM2VecConfig]): Configuration object.
        nllb_encoder (Optional[M2M100Encoder]): Pre-initialized NLLB encoder.
        llm2vec (Optional[LlamaEncoderModel]): Pre-initialized LLaMA encoder.
        *inputs: Additional positional arguments.
        **kwargs: Additional keyword arguments.
    """

    config_class = NLLBLLM2VecConfig
    model_type = "nllb-llm2vec"
    def __init__(
        self,
        config: Optional[NLLBLLM2VecConfig] = None,
        nllb_encoder: Optional[M2M100Encoder] = None,
        llm2vec: Optional[LlamaEncoderModel] = None,
        *inputs,
        **kwargs,
    ):
        # Either a config or both pre-initialized encoders must be provided.
        if config is None and (nllb_encoder is None or llm2vec is None):
            raise ValueError(
                "Either `config` must be provided, or both `nllb_encoder` and `llm2vec` must be specified."
            )

        if config is not None:
            super().__init__(config, *inputs, **kwargs)
            self.nllb_encoder = nllb_encoder or M2M100Encoder(config.nllb_config)
            self.llm2vec = llm2vec or LlamaEncoderModel(config.llm2vec_config)
            self.config = config
        else:
            # Both encoders are provided; derive the config from them.
            self.nllb_encoder = cast(M2M100Encoder, nllb_encoder)
            self.llm2vec = cast(LlamaEncoderModel, llm2vec)
            self.config = NLLBLLM2VecConfig(
                nllb_config=self.nllb_encoder.config,  # type: ignore
                llm2vec_config=self.llm2vec.config,  # type: ignore
            )
            super().__init__(self.config, *inputs, **kwargs)

        # Project NLLB hidden states up to the LLM2Vec hidden size.
        self.up_proj = nn.Linear(
            self.nllb_encoder.config.d_model,
            self.llm2vec.config.hidden_size,
            bias=False,
        )
        # Additional initialization logic can go here
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        indices: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> BaseModelOutputWithPooling:
        """
        Forward pass of the model.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            attention_mask (torch.Tensor): Attention mask.
            indices (Optional[Tuple[torch.Tensor, torch.Tensor]]): Precomputed input indices and offsets.

        Returns:
            BaseModelOutputWithPooling: Model outputs with last hidden state and pooled output.
        """
        # Compute input indices and offsets if not provided.
        if indices is None:
            seq_indices, seq_offsets = self._get_input_offsets(attention_mask)
        else:
            seq_indices, seq_offsets = indices

        # The NLLB encoder and the up-projection run in inference mode (frozen).
        with torch.inference_mode():
            nllb_outputs = self.nllb_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            nllb_last_hidden_state = nllb_outputs.last_hidden_state
            nllb_last_hidden_state = self.up_proj(nllb_last_hidden_state)

        if self.training:
            # Tensors created in inference mode cannot participate in autograd;
            # detach and clone so the downstream encoder can be trained on them.
            nllb_last_hidden_state = nllb_last_hidden_state.detach().clone()

        outputs = self.llm2vec(
            inputs_embeds=nllb_last_hidden_state,
            attention_mask=attention_mask,
        )
        pooler_output = self._mean_embedding(
            hidden_states=outputs.last_hidden_state,
            input_indices=seq_indices,
            offsets=seq_offsets,
        )
        return BaseModelOutputWithPooling(
            last_hidden_state=outputs.last_hidden_state,
            pooler_output=pooler_output,
        )
    @property
    def tokenizer(self):
        """
        The NLLB tokenizer associated with the model (lazily instantiated).

        Returns:
            PreTrainedTokenizer: The tokenizer instance.
        """
        if not hasattr(self, "_tokenizer"):
            from transformers import AutoTokenizer

            self._tokenizer = AutoTokenizer.from_pretrained(
                "facebook/nllb-200-distilled-600M", padding_side="right"
            )
        return self._tokenizer
    def encode(
        self,
        inputs: List[str],
        src_lang: str = "eng_Latn",
        tokenize_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """
        Encode input texts into embeddings.

        Args:
            inputs (List[str]): List of input texts.
            src_lang (str): Source language code.
            tokenize_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the tokenizer.
                Defaults to:

                    tokenize_kwargs = {
                        "padding": True,
                        "truncation": True,
                        "max_length": 512,
                        "return_tensors": "pt",
                    }

        Returns:
            torch.Tensor: Mean-pooled sequence embeddings of the inputs.
        """
        if tokenize_kwargs is None:
            tokenize_kwargs = {
                "padding": True,
                "truncation": True,
                "max_length": 512,
                "return_tensors": "pt",
            }
        tokenizer = self.tokenizer
        tokenizer.src_lang = src_lang
        device = next(self.parameters()).device
        batch = tokenizer(inputs, **tokenize_kwargs).to(device)
        device_type = device.type  # e.g., "cuda" or "cpu"
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            return self(**batch).pooler_output
    @staticmethod
    def _get_input_offsets(
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute indices and offsets for mean pooling with `F.embedding_bag`.

        Args:
            attention_mask (torch.Tensor): Attention mask of shape (batch_size, seq_len).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - input_indices: Indices of non-padded tokens in the flattened input.
                - offsets: Offsets indicating the start index of each sequence in the flattened input.
        """
        # Indices of non-padded tokens in the flattened (batch_size * seq_len) view.
        input_indices = attention_mask.view(-1).nonzero(as_tuple=False).squeeze()

        # Offsets: where each sequence starts in the flattened input.
        non_padded_lengths = attention_mask.sum(dim=1)  # non-padded tokens per sequence
        offsets = torch.cat(
            [
                torch.tensor([0], device=attention_mask.device),
                non_padded_lengths.cumsum(dim=0)[:-1],
            ]
        )
        return input_indices, offsets
    @staticmethod
    def _mean_embedding(
        hidden_states: torch.Tensor,
        input_indices: torch.Tensor,
        offsets: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the mean of non-padded embeddings using `embedding_bag`,
        properly handling padding with offsets.

        Args:
            hidden_states (torch.Tensor): Hidden states of shape (batch_size, seq_len, embed_dim).
            input_indices (torch.Tensor): Indices of non-padded tokens in flattened form.
            offsets (torch.Tensor): Offsets specifying the start of each sequence.

        Returns:
            torch.Tensor: Pooled mean embeddings of shape (batch_size, embed_dim).
        """
        # Flatten hidden_states to 2D: (batch_size * seq_len, embed_dim).
        batch_size, seq_len, embed_dim = hidden_states.shape
        token_embeds = hidden_states.view(-1, embed_dim)

        # embedding_bag with mode="mean" averages the selected token embeddings per sequence.
        return F.embedding_bag(
            input=input_indices,  # indices of non-padded tokens in flattened form
            weight=token_embeds,  # flattened hidden states used as the embedding matrix
            offsets=offsets,      # start of each sequence in the flattened input
            mode="mean",          # aggregation mode
        )

AutoModel.register(NLLBLLM2VecConfig, NLLBLLM2Vec)
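The pooling helpers above implement mean pooling over non-padded tokens by calling `F.embedding_bag` on the flattened hidden states. The following self-contained sketch (not part of the model code) shows that this is equivalent to an ordinary masked mean:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, embed_dim = 2, 5, 4
hidden_states = torch.randn(batch_size, seq_len, embed_dim)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

# Indices of non-padded tokens in the flattened (batch_size * seq_len) view.
input_indices = attention_mask.view(-1).nonzero(as_tuple=False).squeeze()
# Start offset of each sequence within those flattened indices.
lengths = attention_mask.sum(dim=1)
offsets = torch.cat([torch.tensor([0]), lengths.cumsum(dim=0)[:-1]])

pooled = F.embedding_bag(
    input=input_indices,
    weight=hidden_states.view(-1, embed_dim),
    offsets=offsets,
    mode="mean",
)

# Reference: plain masked mean over the sequence dimension.
mask = attention_mask.unsqueeze(-1).float()
reference = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1)
print(torch.allclose(pooled, reference, atol=1e-6))  # True
```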
def repl():
    cfg = NLLBLLM2VecConfig()
    model = NLLBLLM2Vec(cfg)

    from peft.mapping import get_peft_model
    from peft.tuners.lora.config import LoraConfig

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.0,
        bias="none",
        task_type="FEATURE_EXTRACTION",
        # LoRA over all attention and MLP projections in the 32 LLM2Vec layers
        # (identical to listing every module path explicitly).
        target_modules=[
            f"llm2vec.layers.{i}.{name}"
            for i in range(32)
            for name in (
                "self_attn.q_proj",
                "self_attn.k_proj",
                "self_attn.v_proj",
                "self_attn.o_proj",
                "mlp.gate_proj",
                "mlp.up_proj",
                "mlp.down_proj",
            )
        ],
    )
    peft_model = get_peft_model(model, lora_config)
    peft_model.save_pretrained("../nllb-llm2vec-saved")

    import json

    with open("./model.safetensors.index.json", "r") as f:
        print(json.load(f))
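To round this off, here is a hedged sketch of loading the adapter that `repl()` saves back onto a freshly initialized base model; the adapter path simply mirrors the one used above, and merging the LoRA weights is optional.

```python
from peft import PeftModel

# Rebuild the base model and attach the LoRA adapter saved by `repl()`.
base_model = NLLBLLM2Vec(NLLBLLM2VecConfig())
peft_model = PeftModel.from_pretrained(base_model, "../nllb-llm2vec-saved")

# Optionally fold the LoRA weights into the base weights for plain inference.
merged_model = peft_model.merge_and_unload()
```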