# fdschmidt93's picture
# initial commit
# b0221f6
# raw
# history blame
# 20.1 kB
from typing import Any, Dict, List, Optional, Tuple, cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.auto import AutoModel
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
from transformers.models.m2m_100.modeling_m2m_100 import M2M100Encoder
from .configuration_nllbllm2vec import NLLBLLM2VecConfig
from .modeling_llama_encoder import LlamaEncoderModel
class NLLBLLM2Vec(PreTrainedModel):
    """
    NLLBLLM2Vec model combining NLLB and LLama encoders.

    The NLLB encoder output is linearly projected (``up_proj``) to the hidden
    size of the Llama encoder, passed to it as ``inputs_embeds``, and the
    resulting hidden states are mean-pooled over the non-padded positions.

    Args:
        config (Optional[NLLBLLM2VecConfig]): Configuration object.
        nllb_encoder (Optional[M2M100Encoder]): Pre-initialized NLLB encoder.
        llm2vec (Optional[LlamaEncoderModel]): Pre-initialized LLama encoder.
        *inputs: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Raises:
        ValueError: If ``config`` is None and at least one of the two encoders
            is missing.
    """

    config_class = NLLBLLM2VecConfig
    model_type = "nllb-llm2vec"

    def __init__(
        self,
        config: Optional[NLLBLLM2VecConfig] = None,
        nllb_encoder: Optional[M2M100Encoder] = None,
        llm2vec: Optional[LlamaEncoderModel] = None,
        *inputs,
        **kwargs,
    ):
        # Ensure that either config is not None or both encoders are provided
        if config is None and (nllb_encoder is None or llm2vec is None):
            raise ValueError(
                "Either `config` must be provided, or both `nllb_encoder` and `llm2vec` must be specified."
            )

        if config is None:
            # Both encoders are provided: derive the config from them. Work on
            # locals only -- submodules must not be assigned to `self` before
            # `nn.Module.__init__` has run.
            nllb_encoder = cast(M2M100Encoder, nllb_encoder)
            llm2vec = cast(LlamaEncoderModel, llm2vec)
            config = NLLBLLM2VecConfig(
                nllb_config=nllb_encoder.config,  # type: ignore
                llm2vec_config=llm2vec.config,  # type: ignore
            )

        super().__init__(config, *inputs, **kwargs)
        self.config = config

        # Use the supplied encoders where given, otherwise build them from config
        if nllb_encoder is None:
            nllb_encoder = M2M100Encoder(config.nllb_config)
        if llm2vec is None:
            llm2vec = LlamaEncoderModel(config.llm2vec_config)
        self.nllb_encoder = nllb_encoder
        self.llm2vec = llm2vec

        # Projects NLLB hidden states (d_model) to the Llama hidden size
        self.up_proj = nn.Linear(
            self.nllb_encoder.config.d_model,
            self.llm2vec.config.hidden_size,
            bias=False,
        )
        # Additional initialization logic can go here

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        indices: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> BaseModelOutputWithPooling:
        """
        Forward pass of the model.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            attention_mask (torch.Tensor): Attention mask.
            indices (Optional[Tuple[torch.Tensor, torch.Tensor]]): Precomputed
                input indices and offsets, in the form returned by
                `_get_input_offsets`. Computed from `attention_mask` if None.
            *args: Accepted but unused.
            **kwargs: Accepted but unused.

        Returns:
            BaseModelOutputWithPooling: Model outputs with last hidden state and pooled output.
        """
        # Compute input indices and offsets if not provided
        if indices is None:
            seq_indices, seq_offsets = self._get_input_offsets(attention_mask)
        else:
            seq_indices, seq_offsets = indices

        # The NLLB encoder and the up-projection run without autograd tracking,
        # so neither receives gradients from this forward pass.
        with torch.inference_mode():
            nllb_outputs = self.nllb_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            nllb_last_hidden_state = nllb_outputs.last_hidden_state
            nllb_last_hidden_state = self.up_proj(nllb_last_hidden_state)

        if self.training:
            # Tensors created under inference mode cannot take part in
            # autograd; a detached clone yields a regular tensor that the
            # downstream encoder can be trained on.
            nllb_last_hidden_state = nllb_last_hidden_state.detach().clone()

        outputs = self.llm2vec(
            inputs_embeds=nllb_last_hidden_state,
            attention_mask=attention_mask,
        )
        pooler_output = self._mean_embedding(
            hidden_states=outputs.last_hidden_state,
            input_indices=seq_indices,
            offsets=seq_offsets,
        )
        return BaseModelOutputWithPooling(
            last_hidden_state=outputs.last_hidden_state,
            pooler_output=pooler_output,
        )

    @property
    def tokenizer(self):
        """
        Get the tokenizer associated with the model.

        The tokenizer is loaded on first access and cached on the instance.

        Returns:
            PreTrainedTokenizer: The tokenizer instance.
        """
        if not hasattr(self, "_tokenizer"):
            from transformers import AutoTokenizer

            self._tokenizer = AutoTokenizer.from_pretrained(
                "facebook/nllb-200-distilled-600M", padding_side="right"
            )
        return self._tokenizer

    def encode(
        self,
        inputs: List[str],
        src_lang: str = "eng_Latn",
        tokenize_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """
        Encode input texts into embeddings.

        Args:
            inputs (List[str]): List of input texts.
            src_lang (str): Source language code.
            tokenize_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the tokenizer.
                Defaults to:
                >> tokenize_kwargs = {
                >>     "padding": True,
                >>     "truncation": True,
                >>     "max_length": 512,
                >>     "return_tensors": "pt",
                >> }

        Returns:
            torch.Tensor: Mean-pooled sequence embeddings of the inputs.
        """
        if tokenize_kwargs is None:
            tokenize_kwargs = {
                "padding": True,
                "truncation": True,
                "max_length": 512,
                "return_tensors": "pt",
            }

        tokenizer = self.tokenizer
        # Note: this sets the source language on the cached tokenizer instance
        tokenizer.src_lang = src_lang

        # Run on whichever device the model parameters live on
        device = next(self.parameters()).device
        batch = tokenizer(inputs, **tokenize_kwargs).to(device)
        device_type = device.type  # e.g., 'cuda' or 'cpu'
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            return self(**batch).pooler_output

    @staticmethod
    def _get_input_offsets(
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute indices and offsets for mean pooling using EmbeddingBag.

        Args:
            attention_mask (torch.Tensor): Attention mask of shape (batch_size, seq_len).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - input_indices: Indices of non-padded tokens in the flattened input.
                - offsets: Offsets indicating the start index of each sequence in the flattened input.
        """
        # Find the indices of non-padded tokens in flattened hidden_states.
        # `nonzero` returns shape (N, 1); only the trailing dim is squeezed so
        # the result stays 1-D even when N == 1. `reshape` also accepts
        # non-contiguous masks.
        input_indices = attention_mask.reshape(-1).nonzero(as_tuple=False).squeeze(-1)

        # Count non-padded tokens per sequence; cast so that offsets are
        # integer-typed regardless of the mask dtype.
        non_padded_lengths = attention_mask.to(torch.long).sum(dim=1)

        # Compute the offsets: for each sequence, where it starts in the flattened input
        offsets = torch.cat(
            [
                torch.tensor([0], device=attention_mask.device),
                non_padded_lengths.cumsum(dim=0)[:-1],
            ]
        )
        return input_indices, offsets

    @staticmethod
    def _mean_embedding(
        hidden_states: torch.Tensor,
        input_indices: torch.Tensor,
        offsets: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the mean of non-padded embeddings using `embedding_bag`,
        properly handling padding with offsets.

        Args:
            hidden_states (torch.Tensor): Hidden states of shape (batch_size, seq_len, embed_dim).
            input_indices (torch.Tensor): Indices of non-padded tokens in flattened form.
            offsets (torch.Tensor): Offsets specifying the start of each sequence.

        Returns:
            torch.Tensor: Pooled mean embeddings of shape (batch_size, embed_dim).
        """
        # Flatten hidden_states to 2D: shape (batch_size * seq_len, embedding_dim).
        # `reshape` rather than `view` so non-contiguous inputs are accepted.
        embed_dim = hidden_states.shape[-1]
        token_embeds = hidden_states.reshape(-1, embed_dim)

        # Use embedding_bag with mode 'mean' and appropriate indices
        return F.embedding_bag(
            input=input_indices,  # Indices of non-padded tokens in flattened form
            weight=token_embeds,  # The flattened hidden states as embedding matrix
            offsets=offsets,  # Offsets specifying start of each sequence
            mode="mean",  # Aggregation mode
        )
# Register the config/model pair with AutoModel so the auto class can map
# NLLBLLM2VecConfig to NLLBLLM2Vec.
AutoModel.register(NLLBLLM2VecConfig, NLLBLLM2Vec)
def repl():
    """
    Scratch routine: wrap a default-configured model with LoRA and save it.

    Builds an `NLLBLLM2Vec` from a default `NLLBLLM2VecConfig`, attaches LoRA
    adapters to the attention and MLP projections of `llm2vec` layers 0-31,
    saves the PEFT model to "../nllb-llm2vec-saved", and finally prints the
    parsed contents of "./model.safetensors.index.json".
    """
    import json

    from peft.mapping import get_peft_model
    from peft.tuners.lora.config import LoraConfig

    cfg = NLLBLLM2VecConfig()
    model = NLLBLLM2Vec(cfg)

    # Every targeted projection, for each of the 32 decoder layers, in
    # layer-major order (layer 0 q/k/v/o/gate/up/down, then layer 1, ...).
    num_layers = 32
    projections = (
        "self_attn.q_proj",
        "self_attn.k_proj",
        "self_attn.v_proj",
        "self_attn.o_proj",
        "mlp.gate_proj",
        "mlp.up_proj",
        "mlp.down_proj",
    )
    target_modules = [
        f"llm2vec.layers.{layer}.{projection}"
        for layer in range(num_layers)
        for projection in projections
    ]

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.0,
        bias="none",
        task_type="FEATURE_EXTRACTION",
        target_modules=target_modules,
    )
    peft_model = get_peft_model(model, lora_config)
    peft_model.save_pretrained("../nllb-llm2vec-saved")

    with open("./model.safetensors.index.json", "r", encoding="utf-8") as f:
        print(json.load(f))