splade-code-8B / splade.py
Tom Aarsen
Patch loading SparseEncoder from Hub
ebcd7f4
raw
history blame
6.5 kB
"""
Unlike standard Qwen3, this model uses bidirectional rather than causal attention; this is
enabled through `is_causal=False` in the config.
This file supports two loading paths:
1. Sentence Transformers: `SparseEncoder("naver/splade-code-8B", trust_remote_code=True)` via SpladeCodeMLMTransformer -> AutoModelForMaskedLM -> Qwen3ForCausalLM
2. Transformers: `AutoModelForCausalLM.from_pretrained("naver/splade-code-8B", trust_remote_code=True)` -> Splade
The checkpoint is distributed as a LoRA adapter on top of Qwen/Qwen3-8B; `Qwen3ForCausalLM.from_pretrained`
loads the base model and applies the adapter.
"""
import torch
from transformers import Qwen3ForCausalLM as TransformersQwen3ForCausalLM
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig
from transformers.utils import is_flash_attn_2_available
from .utils import prepare_tokenizer, splade_max, similarity, encode
class Qwen3ForCausalLM(TransformersQwen3ForCausalLM):
    """Qwen3 causal LM that can load this repo's LoRA adapter on top of its base model.

    It also ties ``lm_head`` to ``embed_tokens`` explicitly when
    ``tie_word_embeddings`` is set, and skips initializing ``lm_head`` in that case.
    """

    def tie_weights(self, *args, **kwargs):
        """Explicitly re-tie lm_head to embed_tokens to hopefully avoid meta tensor errors.

        When ``config.tie_word_embeddings`` is set and both ``lm_head`` and ``model``
        exist, ``lm_head`` shares the embedding weight. ``"lm_head.weight"`` is then
        removed from the caller-supplied ``missing_keys`` (if any), because the tie
        provides it. Otherwise the parent implementation is used.
        """
        if (
            self.config.tie_word_embeddings
            and hasattr(self, "lm_head")
            and hasattr(self, "model")
        ):
            self.lm_head.weight = self.model.embed_tokens.weight
            missing_keys = kwargs.get("missing_keys")
            if missing_keys is not None:
                # `missing_keys` is expected to be a set (`discard` is used).
                missing_keys.discard("lm_head.weight")
        else:
            super().tie_weights(*args, **kwargs)

    def _init_weights(self, module):
        """Skip lm_head init when it will be tied to embed_tokens later."""
        if module is getattr(self, "lm_head", None) and self.config.tie_word_embeddings:
            return
        super()._init_weights(module)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load the model, applying the LoRA adapter when the path points to one.

        If no PEFT adapter config can be read from ``pretrained_model_name_or_path``,
        this defers unchanged to the regular ``from_pretrained``. Otherwise it:

        1. loads the base model named by the adapter config (``base_model_name_or_path``),
           using the caller's ``config`` or the config stored in the adapter repo;
        2. applies the adapter with ``PeftModel.from_pretrained``.

        ``revision`` and ``subfolder`` (when given) refer to the adapter repository.
        They are used to read the adapter config, the model config and the adapter
        weights. They are not forwarded to the base-model load, because the base
        model lives in a different repository.

        Returns:
            A ``PeftModel`` wrapping the base model for adapter checkpoints,
            otherwise whatever the parent ``from_pretrained`` returns.
        """
        from peft import PeftConfig, PeftModel

        token = kwargs.get("token")
        # Only forward adapter-repo options the caller actually set.
        adapter_hub_kwargs = {
            key: kwargs[key] for key in ("revision", "subfolder") if kwargs.get(key)
        }
        try:
            peft_config = PeftConfig.from_pretrained(
                pretrained_model_name_or_path, token=token, **adapter_hub_kwargs
            )
        except Exception:
            # Deliberate best effort: anything without a readable adapter config is
            # treated as a regular (non-PEFT) checkpoint and loaded below.
            peft_config = None
        if peft_config is None:
            return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

        # The adapter repo's revision/subfolder do not exist in the base-model repo.
        for key in ("revision", "subfolder"):
            kwargs.pop(key, None)

        # Use provided splade config (has is_causal=False) or load it from the adapter repo
        config = kwargs.pop("config", None)
        if config is None or not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(
                pretrained_model_name_or_path, token=token, **adapter_hub_kwargs
            )
        # We apply the adapter manually below, so drop any auto-PEFT hints to avoid double loading
        kwargs.pop("adapter_kwargs", None)
        base_model = super().from_pretrained(
            peft_config.base_model_name_or_path,
            *model_args,
            config=config,
            **kwargs,
        )
        return PeftModel.from_pretrained(
            base_model, pretrained_model_name_or_path, token=token, **adapter_hub_kwargs
        )
class SpladeConfig(PretrainedConfig):
    """Configuration for the ``Splade`` wrapper.

    Holds the backbone name, the requested attention backend, the bidirectional
    flag and the tokenizer padding side. All other keyword arguments are handed
    to ``PretrainedConfig``.
    """

    model_type = "qwen3"

    def __init__(
        self,
        model_name_or_path: str = "Qwen/Qwen3-8B",
        attn_implementation: str = "flash_attention_2",
        bidirectional: bool = True,  # only for decoder models
        padding_side: str = "left",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Assigned in this fixed order, after the base config is initialized.
        settings = {
            "model_name_or_path": model_name_or_path,
            "attn_implementation": attn_implementation,
            "bidirectional": bidirectional,
            "padding_side": padding_side,
        }
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)
class Splade(PreTrainedModel):
    """SPLADE encoder built on a Qwen3 causal LM (loaded via ``Qwen3ForCausalLM``).

    ``forward`` runs the language model and pools its logits with ``splade_max``.
    ``similarity`` and ``encode`` are attached for MTEB's interface.
    """

    config_class = SpladeConfig
    # methods for MTEB's interface
    similarity = similarity
    encode = encode

    def __init__(self, config, weights_path=None, token=None):
        """Build the tokenizer and load the underlying language model.

        Args:
            config: ``SpladeConfig``. Its ``attn_implementation`` is overwritten with
                the backend actually used: "flash_attention_2" when flash-attn 2 is
                available, else "sdpa".
            weights_path: Hub repo id or local path used for ``AutoConfig``,
                ``prepare_tokenizer`` and ``Qwen3ForCausalLM.from_pretrained``.
            token: Optional Hugging Face access token.
        """
        super().__init__(config)
        self.name = "splade"
        # Resolve the attention backend before anything is loaded, so that the base
        # config and the model load below agree on it (previously the base config
        # was built with the unresolved, possibly unavailable, backend).
        if is_flash_attn_2_available():
            config.attn_implementation = "flash_attention_2"
        else:
            config.attn_implementation = "sdpa"
        base_cfg = AutoConfig.from_pretrained(
            weights_path,
            attn_implementation=config.attn_implementation,
            torch_dtype="auto",
            token=token,
        )
        self.tokenizer = prepare_tokenizer(
            weights_path, padding_side=config.padding_side
        )
        self.model = Qwen3ForCausalLM.from_pretrained(
            weights_path,
            config=base_cfg,
            torch_dtype=torch.bfloat16,
            attn_implementation=config.attn_implementation,
            token=token,
        )

    def save_pretrained(self, save_directory, *args, **kwargs):
        """Save the wrapped model and this wrapper's config to ``save_directory``.

        Extra positional/keyword arguments are accepted but not used.
        """
        self.model.save_pretrained(save_directory)
        self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, model_name_or_path, *args, **kwargs):
        """Load a ``SpladeConfig`` from ``model_name_or_path`` and build the model.

        Only the ``token`` keyword is used; other arguments are ignored. The returned
        model also gets ``reverse_voc``, the inverse of ``tokenizer.vocab``
        (id -> token).
        """
        token = kwargs.get("token", None)
        config = SpladeConfig.from_pretrained(
            model_name_or_path,
            token=token,
        )
        model = cls(config, weights_path=model_name_or_path, token=token)
        model.reverse_voc = {v: k for k, v in model.tokenizer.vocab.items()}
        return model

    def forward(self, **tokens):
        """Encode a tokenized batch.

        Args:
            **tokens: Model inputs (e.g. from ``create_batch_dict``); must include
                ``attention_mask``, which is also passed to ``splade_max``.

        Returns:
            A 1-tuple holding the first value returned by ``splade_max`` on the LM
            logits (presumably one vocabulary-sized vector per input; TODO confirm).
        """
        output = self.model(**tokens)
        splade_reps, _ = splade_max(output.logits, tokens["attention_mask"])
        return (splade_reps,)

    def get_width(self):
        """Return ``vocab_size`` from the underlying model's config."""
        return self.model.config.vocab_size

    def create_batch_dict(self, input_texts, max_length):
        """Tokenize ``input_texts`` into PyTorch tensors.

        Special tokens are added, sequences are truncated to ``max_length`` and
        padded to the longest in the batch, and an attention mask is returned.
        """
        return self.tokenizer(
            input_texts,
            add_special_tokens=True,
            padding="longest",
            truncation=True,
            max_length=max_length,
            return_attention_mask=True,
            return_tensors="pt",
        )
__all__ = ["Qwen3ForCausalLM", "Splade"]

# Optional Sentence Transformers integration. Overriding ST's `_load_config` makes it
# return our `Qwen3Config` (the one carrying `auto_map`) rather than a `PeftConfig`.
# Hub-path loads therefore route to `splade.Qwen3ForCausalLM` instead of failing in
# `AutoModelForMaskedLM`. The LoRA adapter is still applied by transformers'
# built-in PEFT path.
try:
    from sentence_transformers.sparse_encoder.models import MLMTransformer
except ImportError:
    # sentence-transformers is not installed: the ST-specific module is not defined.
    pass
else:

    class SpladeCodeMLMTransformer(MLMTransformer):
        """``MLMTransformer`` that loads this repo's plain transformers config."""

        def _load_config(self, model_name_or_path, backend, config_kwargs):
            """Return ``(AutoConfig, False)`` for ``model_name_or_path``.

            ``backend`` is unused. The ``False`` flag presumably tells ST the config
            is not a PEFT config (TODO confirm against the ST version in use).
            """
            loaded_config = AutoConfig.from_pretrained(model_name_or_path, **config_kwargs)
            return loaded_config, False

    __all__.append("SpladeCodeMLMTransformer")