# EDJE: Efficient Discriminative Joint Encoders for Large Scale Vision-Language Re-ranking
A high-throughput vision-language re-ranker that combines a SigLIP vision encoder with MiniLM for efficient image-text matching at scale.
## Overview
Multimodal retrieval typically relies on embedding-based models like CLIP for fast vector search over pre-computed image embeddings. However, unlike text retrieval where joint-encoder re-rankers are standard, comparable vision-language re-rankers have been largely absent due to efficiency bottlenecks.
EDJE addresses this gap by introducing an Efficient Discriminative Joint Encoder that:
- Precomputes vision tokens offline - Images are encoded once and stored on disk
- Compresses visual features via a lightweight attention-based adapter using learnable queries
- Runs only a compact joint encoder online over a small set of visual tokens plus text
This design enables fine-grained cross-modal interactions (unlike embedding-only models that simply compare vectors) while maintaining the efficiency required for large-scale retrieval.
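The compression step can be pictured as a single cross-attention layer whose learnable queries become the output tokens. The sketch below is only an illustration of that idea, not the actual EDJE adapter: the layer count, head count, and the 1024-dimensional ViT-L feature width are assumptions.

```python
import torch
import torch.nn as nn

class TokenCompressionAdapter(nn.Module):
    """Schematic sketch: compress a long sequence of vision tokens into a
    fixed number of tokens via cross-attention with learnable queries."""

    def __init__(self, vision_dim=1024, out_dim=384, num_queries=64, num_heads=8):
        super().__init__()
        # One learnable query per output token
        self.queries = nn.Parameter(torch.randn(num_queries, out_dim) * 0.02)
        # Project ViT features down to the joint encoder's width
        self.proj = nn.Linear(vision_dim, out_dim)
        self.attn = nn.MultiheadAttention(out_dim, num_heads, batch_first=True)

    def forward(self, vision_tokens):  # (B, 576, vision_dim)
        kv = self.proj(vision_tokens)  # (B, 576, out_dim)
        q = self.queries.unsqueeze(0).expand(vision_tokens.size(0), -1, -1)
        compressed, _ = self.attn(q, kv, kv)  # queries attend over all patches
        return compressed  # (B, 64, out_dim)

adapter = TokenCompressionAdapter()
tokens = torch.randn(2, 576, 1024)  # ViT-L/16 @ 384px yields 576 patch tokens
print(adapter(tokens).shape)        # torch.Size([2, 64, 384])
```

Whatever its exact internals, the adapter's job is the same: 576 patch tokens in, 64 compact tokens out, so the online joint encoder never touches the full ViT feature map.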
## Why Re-ranking?
Embedding-based models (CLIP, SigLIP) enable efficient similarity search through simple vector comparisons, but they process image and text independently. Joint encoders process both modalities together, allowing richer cross-modal interactions that can significantly improve retrieval quality.
EDJE is designed as a second-stage re-ranker: given the top-k candidates retrieved by an embedding model, EDJE scores each image-text pair to produce a refined ranking.
## Key Features
| Feature | Value |
|---|---|
| Throughput | ~50k image-text pairs/second |
| Storage | ~49 kB per image (64 compressed tokens) |
| Compression | 576 → 64 tokens via attention-based adapter |
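The ~49 kB storage figure follows directly from the token geometry, assuming tokens are stored as 16-bit floats:

```python
num_tokens = 64       # compressed visual tokens per image
hidden_dim = 384      # joint-encoder (MiniLM) hidden size
bytes_per_value = 2   # float16 storage (assumption)

bytes_per_image = num_tokens * hidden_dim * bytes_per_value
print(bytes_per_image)  # 49152 bytes = 48 KiB, i.e. ~49 kB
```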
## Architecture
- Vision Encoder: SigLIP2 ViT-L/16 @ 384px (`google/siglip2-large-patch16-384`)
- Language Model: MiniLM-L12-H384 (`microsoft/MiniLM-L12-H384-uncased`)
- Token Compression: Cross-attention adapter with 64 learnable queries
The model is split into two components for efficient deployment:
```
┌────────────────────────────────────────────────────────────────────┐
│ OFFLINE (Index Time)                                               │
│  ┌───────┐    ┌──────────────────┐    ┌─────────────────────┐      │
│  │ Image │──▶ │ SigLIP ViT-L/16  │──▶ │ Token Compression   │──▶ Store
│  └───────┘    │ (576 tokens)     │    │ Adapter (64 tokens) │      │
│               └──────────────────┘    └─────────────────────┘      │
└────────────────────────────────────────────────────────────────────┘

┌────────────────────────────────────────────────────────────────────┐
│ ONLINE (Query Time)                                                │
│  ┌───────────────────┐     ┌──────────────────┐    ┌────────────┐  │
│  │ Compressed Tokens │────▶│                  │    │            │  │
│  │ (from index)      │     │   MiniLM Joint   │───▶│  Matching  │  │
│  ├───────────────────┤────▶│   Encoder        │    │  Score     │  │
│  │ Text Query        │     │                  │    │            │  │
│  └───────────────────┘     └──────────────────┘    └────────────┘  │
└────────────────────────────────────────────────────────────────────┘
```
## Model Components
- `EDJEModelForIndexing`: Encodes images into compressed visual tokens (offline, run once per image)
- `EDJEModelForRanking`: Scores image-text pairs using pre-computed visual tokens (online, run at query time)
## Usage
```python
import torch
import requests
from PIL import Image
from io import BytesIO
from transformers import AutoProcessor, AutoTokenizer
from huggingface_hub import hf_hub_download

from pretrain_model import EDJEModelForIndexing, EDJEModelForRanking

# =============================================================================
# Download and load model checkpoint
# =============================================================================
checkpoint_path = hf_hub_download(
    repo_id="shahafw/edje-vl-image-retrieval-reranker",
    filename="pytorch_model.pth",
)
checkpoint = torch.load(checkpoint_path, map_location="cpu")

# Initialize both models (must match the checkpoint's architecture)
indexing_model = EDJEModelForIndexing(
    siglip_path="google/siglip2-large-patch16-384",  # Large model, 384px
    language_model_path="microsoft/MiniLM-L12-H384-uncased",
    num_compressed_tokens=64,
)
ranking_model = EDJEModelForRanking(
    language_model_path="microsoft/MiniLM-L12-H384-uncased",
    num_compressed_tokens=64,
)

# Load trained weights (strict=False since each model only uses a subset of weights)
indexing_model.load_state_dict(checkpoint["model"], strict=False)
ranking_model.load_state_dict(checkpoint["model"], strict=False)
indexing_model.eval()
ranking_model.eval()

# =============================================================================
# OFFLINE INDEXING PHASE
# Run once per image - encode and store compressed tokens in your database/index
# =============================================================================
# Load an image from URL
image_url = "https://images.unsplash.com/photo-1558788353-f76d92427f16?w=400"
response = requests.get(image_url)
image = Image.open(BytesIO(response.content)).convert("RGB")

# Process image
processor = AutoProcessor.from_pretrained("google/siglip2-large-patch16-384", use_fast=True)
pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]

# Generate compressed visual tokens (store these in your vector index)
with torch.no_grad():
    compressed_tokens = indexing_model(pixel_values)

# compressed_tokens shape: (1, 64, 384) -> ~49 kB per image
print(f"Compressed tokens shape: {compressed_tokens.shape}")

# =============================================================================
# ONLINE RE-RANKING PHASE
# Given candidates retrieved by first-stage model (e.g., CLIP), re-rank them
# =============================================================================
# Example: rank the image against multiple candidate captions
candidate_captions = [
    "a cute golden retriever puppy",
    "a cat sleeping on a sofa",
    "a beautiful sunset over the ocean",
]

# Tokenize captions
tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
text_inputs = tokenizer(
    candidate_captions,
    padding=True,
    truncation=True,
    max_length=64,
    return_tensors="pt",
)

# Expand compressed tokens to match batch size (one image vs multiple captions)
compressed_tokens_batch = compressed_tokens.expand(len(candidate_captions), -1, -1)

# Compute matching scores for re-ranking
with torch.no_grad():
    scores = ranking_model(
        compressed_tokens=compressed_tokens_batch,
        input_ids=text_inputs["input_ids"],
        attention_mask=text_inputs["attention_mask"],
    )

# Display re-ranked results (higher score = better match)
print("\nRe-ranked Results:")
for caption, score in sorted(zip(candidate_captions, scores.tolist()), key=lambda x: x[1], reverse=True):
    print(f"  {score:.4f}: {caption}")
```
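In a real deployment, the offline phase would persist each image's compressed tokens to disk rather than keep them in memory. A minimal sketch of that persistence step, assuming float16 storage; the file naming and helper functions are illustrative, not part of the EDJE API:

```python
import torch

def save_compressed_tokens(tokens: torch.Tensor, path: str) -> None:
    # Store in float16: 64 x 384 x 2 bytes = ~49 kB per image
    torch.save(tokens.to(torch.float16).cpu(), path)

def load_compressed_tokens(path: str) -> torch.Tensor:
    # Cast back to float32 before feeding the ranking model
    return torch.load(path).to(torch.float32)

tokens = torch.randn(1, 64, 384)  # stand-in for indexing_model output
save_compressed_tokens(tokens, "image_0001.pt")
restored = load_compressed_tokens("image_0001.pt")
print(restored.shape)  # torch.Size([1, 64, 384])
```

One file per image is the simplest layout; a production index would more likely pack many images' tokens into a memory-mapped array keyed by image id.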
## Typical Retrieval Pipeline
```
Query ──▶ First-Stage Retrieval (CLIP/SigLIP) ──▶ Top-K Candidates ──▶ EDJE Re-ranking ──▶ Final Results
          (fast, embedding-based)                                      (accurate, joint-encoder)
```
- First stage: Use an embedding model (CLIP, SigLIP) to retrieve top-K candidates via approximate nearest neighbor search
- Second stage: Use EDJE to re-rank the candidates with fine-grained cross-modal scoring
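The two stages compose as below. This is a toy sketch of the control flow only: `first_stage_topk` stands in for approximate nearest neighbor search over precomputed embeddings, and the `edje_score` callable stands in for the EDJE ranking model.

```python
import numpy as np

def first_stage_topk(query_emb, image_embs, k):
    """Stage 1: cosine similarity over precomputed embeddings
    (a stand-in for a real ANN index such as FAISS)."""
    sims = image_embs @ query_emb / (
        np.linalg.norm(image_embs, axis=1) * np.linalg.norm(query_emb) + 1e-8
    )
    return np.argsort(-sims)[:k]

def rerank(candidates, edje_score):
    """Stage 2: re-score only the surviving candidates with the joint encoder."""
    return sorted(candidates, key=edje_score, reverse=True)

# Toy data: 1000 image embeddings, keep top-50, re-rank with a dummy scorer
rng = np.random.default_rng(0)
image_embs = rng.normal(size=(1000, 128))
query_emb = rng.normal(size=128)

topk = first_stage_topk(query_emb, image_embs, k=50)
final = rerank(list(topk), edje_score=lambda i: float(image_embs[i] @ query_emb))
print(len(final))  # 50
```

The key property is that the expensive joint-encoder scoring touches only K candidates per query, never the full collection.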
## Training
This model was trained using:
- Image-Text Matching (ITM) loss - Binary classification of matched/unmatched pairs
- Image-Text Contrastive (ITC) loss - Alignment with SigLIP's embedding space
- Masked Language Modeling (MLM) loss - Language understanding
- Knowledge distillation - From a larger teacher model
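As an illustration of the first objective, ITM reduces to binary cross-entropy over matched/unmatched pairs. The sketch assumes the joint encoder emits one scalar logit per image-text pair; the logit and label values are made up:

```python
import torch
import torch.nn.functional as F

# Hypothetical matching logits from the joint encoder for 4 image-text pairs
logits = torch.tensor([2.1, -1.3, 0.4, -2.0])
# 1 = matched pair, 0 = unmatched (e.g., an in-batch negative)
labels = torch.tensor([1.0, 0.0, 1.0, 0.0])

# ITM objective: binary cross-entropy on the raw logits
itm_loss = F.binary_cross_entropy_with_logits(logits, labels)
print(itm_loss.item())
```

The ITC and MLM terms follow their usual formulations (an InfoNCE-style contrastive loss and masked-token cross-entropy, respectively) and are omitted here for brevity.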
## Results
### Main Retrieval Performance (Recall@1)
EDJE matches prior joint encoders while being up to 53Γ faster with 36Γ less storage:
| Method | Training Data | Flickr-ZS T2I | Flickr-ZS I2T | COCO-FT T2I | COCO-FT I2T | Storage per image | Params | Inference time (ms) |
|---|---|---|---|---|---|---|---|---|
| *Prior Joint Encoders* |  |  |  |  |  |  |  |  |
| ALBEF ViT-B/16 | 12M | 82.8 | 94.1 | 60.7 | 77.6 | 1,769 kB | 147M | 45.92 |
| BLIP ViT-B/16 | 12M | 84.9 | 94.8 | 63.1 | 80.6 | 1,769 kB | 139M | 83.27 |
| BLIP ViT-L/16 | 129M | 86.7 | 96.7 | 65.1 | 82.4 | 2,359 kB | 139M | 101.61 |
| *EDJE (Ours)* |  |  |  |  |  |  |  |  |
| Local ViT-B/16 | 12M | 84.3 | 94.3 | 60.9 | 76.1 | 442 kB | 33M | 2.86 |
| Local ViT-L/16 | 12M | 87.8 | 96.5 | 64.9 | 81.0 | 442 kB | 33M | 4.14 |
| Compressed-128 ViT-L/16 | 12M | 87.1 | 96.3 | 64.6 | 81.0 | 98 kB | 33M | 2.04 |
| Compressed-64 ViT-L/16 | 12M | 86.9 | 96.4 | 64.6 | 80.9 | 49 kB | 33M | 1.91 |
### Full Dataset Retrieval
Evaluation on the full Flickr and COCO datasets (retrieval against all images/captions):
#### Flickr Full (Zero-Shot)
| Model | T2I R@5 | T2I R@10 | T2I R@20 | I2T R@5 | I2T R@10 | I2T R@20 |
|---|---|---|---|---|---|---|
| LightningDOT | 60.1 | 69.5 | 78.3 | 75.1 | 83.9 | 90.5 |
| EDJE | 78.3 | 84.5 | 89.6 | 92.4 | 95.9 | 97.7 |
#### MS-COCO Full (Fine-tuned)
| Model | T2I R@5 | T2I R@10 | T2I R@20 | I2T R@5 | I2T R@10 | I2T R@20 |
|---|---|---|---|---|---|---|
| LightningDOT | 37.3 | 46.8 | 56.4 | 48.0 | 59.0 | 68.9 |
| EDJE | 52.2 | 60.6 | 68.1 | 69.9 | 77.0 | 82.6 |
### Efficiency Comparison
| Metric | EDJE | Prior Joint Encoders (BLIP) |
|---|---|---|
| Throughput | ~50k pairs/sec | ~1k pairs/sec |
| Relative Speedup | Up to 53× faster | 1× (baseline) |
| Storage per image | 49 kB | ~1.2 MB (full ViT features) |
| Online ViT forward | Not needed | Required |
## Citation
```bibtex
@inproceedings{edje2026,
  author    = {Mitchell Keren Taraday and Shahaf Wagner and Chaim Baskin},
  title     = {Efficient Discriminative Joint Encoders for Large Scale Vision-Language Re-ranking},
  booktitle = {ICLR},
  year      = {2026},
}
```
## License
BSD-3-Clause