---
license: other
license_name: health-ai-developer-foundations
license_link: https://developers.google.com/health-ai-developer-foundations/terms
language:
- en
tags:
- vision-language
- pathology
- computational-pathology
- whole-slide-imaging
base_model:
- google/medgemma-4b-it
pipeline_tag: image-text-to-text
library_name: transformers
---
ANTONI-Alpha (Pretrain)
Note: This is the pretrain checkpoint (Stage 1: projector alignment only). For the instruction-tuned model, use the main branch.
Resources
- Paper: OpenReview (under review)
- Code: GitHub
- Dataset: SaltySander/HISTAI-Instruct
- Data Generation Framework: Polysome
- Base Model: MedGemma-4B-IT
Authors
Sander Moonemans, Sebastiaan Ram, Frédérique Meeuwsen, Jeroen van der Laak, Geert Litjens, Francesco Ciompi
Model Information
This checkpoint represents Stage 1 of ANTONI-Alpha training: projector alignment with a frozen language model.
Architecture:
- Vision encoder: Prism (1280-dim tile embeddings)
- Language model: MedGemma-4B-IT (4-bit quantized, frozen)
- Projector: Cross-attention (256 output tokens, 8 query/KV heads)
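To make the projector description concrete, here is a minimal PyTorch sketch of a cross-attention projector of this general kind: a fixed set of 256 learnable query tokens attends over a variable number of 1280-dim tile embeddings, so every slide is compressed to the same number of tokens. This is not the ANTONI-Alpha implementation. The class name, the language-model width of 2560 (assumed hidden size of MedGemma-4B), the single attention layer, the final LayerNorm and the optional padding mask are all illustrative assumptions.

```python
from typing import Optional

import torch
import torch.nn as nn


class CrossAttentionProjector(nn.Module):
    """Hypothetical sketch: map [batch, num_tiles, 1280] Prism tile embeddings
    to [batch, num_queries, lm_dim] tokens for the language model."""

    def __init__(self, vision_dim: int = 1280, lm_dim: int = 2560,
                 num_queries: int = 256, num_heads: int = 8):
        super().__init__()
        # Learnable query tokens; their count fixes the output sequence length
        self.queries = nn.Parameter(torch.randn(num_queries, lm_dim) * 0.02)
        # Queries live in lm_dim; keys and values are the raw tile embeddings
        self.attn = nn.MultiheadAttention(
            embed_dim=lm_dim, num_heads=num_heads,
            kdim=vision_dim, vdim=vision_dim, batch_first=True,
        )
        self.norm = nn.LayerNorm(lm_dim)  # assumption: output normalization

    def forward(self, tiles: torch.Tensor,
                padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Broadcast the shared queries over the batch: [batch, num_queries, lm_dim]
        q = self.queries.unsqueeze(0).expand(tiles.size(0), -1, -1)
        # padding_mask: [batch, num_tiles] bool, True marks padded tiles to ignore
        out, _ = self.attn(q, tiles, tiles,
                           key_padding_mask=padding_mask, need_weights=False)
        return self.norm(out)
```

Because the number of queries, not the number of tiles, sets the output length, a slide with 200 tiles and one with 20,000 tiles both yield 256 tokens for the language model.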
Training (Stage 1: Projector Alignment)
- Frozen LLM, trainable projector
- Dataset: Clean pathology reports from HISTAI-Instruct
- 21 epochs, batch size 32, learning rate 1e-4
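Stage 1 updates only the projector while the language model stays frozen. The sketch below shows one common way to set that up in PyTorch; it is not taken from the ANTONI-Alpha training code. The helper name `configure_stage1` is hypothetical, and the choice of AdamW is an assumption (only the learning rate of 1e-4 comes from this card).

```python
import torch
import torch.nn as nn


def configure_stage1(language_model: nn.Module, projector: nn.Module,
                     lr: float = 1e-4) -> torch.optim.Optimizer:
    """Hypothetical helper: freeze the LLM, train only the projector."""
    for p in language_model.parameters():
        p.requires_grad_(False)  # no gradients are stored for frozen LLM weights
    language_model.eval()  # assumption: disable dropout in the frozen LLM
    for p in projector.parameters():
        p.requires_grad_(True)
    # Assumption: AdamW; the optimizer only ever sees projector parameters
    return torch.optim.AdamW(projector.parameters(), lr=lr)
```

Gradients still flow *through* the frozen language model back to the projector, but only the projector's weights change when the optimizer steps.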
Installation
```shell
pip install git+https://github.com/computationalpathologygroup/ANTONI-Alpha.git
```
Optional: Flash Attention 2
For improved performance on compatible hardware, install Flash Attention 2:
```shell
pip install flash-attn==2.8.3 --no-build-isolation
```
The --no-build-isolation flag allows the build process to use your installed PyTorch. Flash Attention 2 requires CUDA-capable hardware and will be used automatically if installed.
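A sketch of how an "use it if installed" check can work is shown below. It is an illustration, not the package's actual selection logic: the helper name `pick_attn_implementation` is hypothetical, and returning the Transformers `attn_implementation` strings `"flash_attention_2"` and `"sdpa"` is an assumption about how the choice would be passed on.

```python
import importlib.util


def pick_attn_implementation() -> str:
    """Hypothetical helper: prefer Flash Attention 2 when the flash_attn
    package is importable and a CUDA device is present, else fall back to SDPA."""
    if importlib.util.find_spec("flash_attn") is None:
        return "sdpa"  # package not installed
    try:
        import torch
    except ImportError:
        return "sdpa"
    # Flash Attention 2 kernels only run on CUDA devices
    return "flash_attention_2" if torch.cuda.is_available() else "sdpa"
```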
How to Use
```python
import torch
from antoni_alpha.models.antoni_pretrained import AntoniAlphaPreTrained

# Load the pretrain (Stage 1) checkpoint from the "pretrain" revision
model = AntoniAlphaPreTrained.from_pretrained(
    "SaltySander/ANTONI-Alpha",
    revision="pretrain",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Load pre-extracted Prism tile embeddings for one slide: [num_tiles, 1280]
slide_features = torch.load("slide_features.pt")
slide_latents = slide_features.unsqueeze(0)  # add batch dim: [1, num_tiles, 1280]

# Match the projector's device and dtype (the model is loaded in bfloat16)
projector_param = next(model.projection_layer.parameters())
slide_latents = slide_latents.to(device=projector_param.device, dtype=projector_param.dtype)

# Greedy decoding on a single-turn conversation
conversation = [{"role": "user", "content": "Describe this tissue."}]
with torch.no_grad():
    output_ids = model.generate(
        slide_latents=slide_latents,
        conversations=[conversation],
        max_new_tokens=200,
        do_sample=False,
    )

response = model.processor.batch_decode(output_ids, skip_special_tokens=True)[0]
print(response)
```
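Since a wrong feature shape is the easiest mistake to make here, a small validation step before `generate` can help. The helper below is a hypothetical convenience, not part of the `antoni_alpha` package; it only checks the `[num_tiles, 1280]` / `[batch, num_tiles, 1280]` shapes described in this card and moves the tensor to a given device and dtype.

```python
import torch


def prepare_slide_latents(features: torch.Tensor, device, dtype) -> torch.Tensor:
    """Hypothetical helper: validate Prism features and add a batch dimension."""
    if features.ndim == 2:
        features = features.unsqueeze(0)  # [num_tiles, 1280] -> [1, num_tiles, 1280]
    if features.ndim != 3 or features.size(-1) != 1280:
        raise ValueError(
            "expected [num_tiles, 1280] or [batch, num_tiles, 1280], "
            f"got {tuple(features.shape)}"
        )
    return features.to(device=device, dtype=dtype)
```

With the model loaded as above, this would be called as `prepare_slide_latents(torch.load("slide_features.pt"), p.device, p.dtype)`, where `p = next(model.projection_layer.parameters())`.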
Input/Output
Input:
- `slide_latents`: `Tensor[batch_size, num_tiles, 1280]` (Prism embeddings)
- `conversations`: list of conversations in OpenAI chat format, each a list of `{"role": ..., "content": ...}` messages
Output:
- Generated text response
Note: This model requires pre-extracted Prism embeddings. It does not process raw WSI images end-to-end.
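The `conversations` argument follows the OpenAI-style message format used in the example above. The helpers below are a hypothetical sketch for building and sanity-checking that input; they are not part of the package. Three rules are assumptions rather than documented requirements: one conversation per slide in the batch, only `user`/`assistant` roles, and a final `user` turn before generation. Whether the Stage 1 checkpoint handles multi-turn conversations well is also not established here.

```python
from typing import Dict, List

Message = Dict[str, str]
ALLOWED_ROLES = {"user", "assistant"}  # assumption: no separate system role


def build_conversation(*turns: str) -> List[Message]:
    """Hypothetical helper: alternate user/assistant turns, starting with the user."""
    roles = ("user", "assistant")
    return [{"role": roles[i % 2], "content": text} for i, text in enumerate(turns)]


def validate_conversations(conversations: List[List[Message]], batch_size: int) -> None:
    """Hypothetical check: one conversation per slide, each ending on a user turn."""
    if len(conversations) != batch_size:
        raise ValueError(f"expected {batch_size} conversations, got {len(conversations)}")
    for conv in conversations:
        for msg in conv:
            if set(msg) != {"role", "content"} or msg["role"] not in ALLOWED_ROLES:
                raise ValueError(f"malformed message: {msg!r}")
        if not conv or conv[-1]["role"] != "user":
            raise ValueError("each conversation should end with a user turn")
```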
Citation
```bibtex
@inproceedings{moonemans2025open,
  title = {Open Instruction Tuning for Whole-Slide Digital Pathology},
  author = {Sander Moonemans and Sebastiaan Ram and Fr{\'e}d{\'e}rique Meeuwsen and Jeroen van der Laak and Geert Litjens and Francesco Ciompi},
  booktitle = {Submitted to Medical Imaging with Deep Learning},
  year = {2025},
  url = {https://openreview.net/forum?id=aGPowreqPi},
  note = {under review}
}
```
License
This model is released under the Health AI Developer Foundations License.