Efficient Long Context Language Model Retrieval with Compression
- Paper: Efficient Long Context Language Model Retrieval with Compression, ACL 2025 Main
- GitHub Repository: CoLoR
- Finetuned from model: Phi-3-mini-4k-instruct
Abstract
We propose a new compression approach tailored for Long Context Language Model (LCLM) retrieval, trained to maximize retrieval performance while minimizing the length of the compressed passages. To accomplish this, we generate synthetic data in which compressed passages are automatically created and labeled as chosen or rejected according to their retrieval success for a given query, and we train the proposed Compression model for Long context Retrieval (CoLoR) on this data via preference optimization, adding a length regularization loss on top to enforce brevity. Through extensive experiments on 9 datasets, we show that CoLoR improves retrieval performance by 6% while compressing the in-context size by a factor of 1.91.
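For intuition only, the sketch below shows one generic way to combine a DPO-style preference loss with a length-regularization term, as the abstract describes. This is not the authors' implementation; the function name and the hyperparameters beta and alpha are hypothetical.

import torch
import torch.nn.functional as F

def preference_loss_with_length_reg(
    logp_chosen, logp_rejected,          # policy log-probs of chosen/rejected compressions
    ref_logp_chosen, ref_logp_rejected,  # reference-model log-probs of the same sequences
    chosen_lengths,                      # token lengths of the chosen compressions
    beta=0.1,                            # hypothetical preference-loss temperature
    alpha=0.01,                          # hypothetical weight of the length penalty
):
    # DPO-style margin: reward compressions that led to retrieval success ("chosen")
    # over those that did not ("rejected").
    margin = beta * ((logp_chosen - ref_logp_chosen) - (logp_rejected - ref_logp_rejected))
    pref_loss = -F.logsigmoid(margin)
    # Length regularization: penalize longer chosen outputs to enforce brevity.
    return (pref_loss + alpha * chosen_lengths.float()).mean()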
Quick Start
Install dependencies
pip install transformers peft torch jsonlines tqdm
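The non-Gemma code path below loads the model with attn_implementation="flash_attention_2", which additionally requires the flash-attn package and a GPU that supports it (remove that option from the script if your setup does not):

pip install flash-attn --no-build-isolation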
Example Input Format
Your input file must be a .jsonl file where each line has a passage_text field:
{
  "id": "id-123",
  "passage_text": "Lorem ipsum dolor sit amet..."
}
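For instance, a valid input file can be written with the jsonlines package installed above (the filename and passage text are placeholders):

import jsonlines

# Write one JSON object per line; only "passage_text" is required by the script.
with jsonlines.open("input.jsonl", mode="w") as writer:
    writer.write({"id": "id-123", "passage_text": "Lorem ipsum dolor sit amet..."})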
Output Format
The script adds a new field, "summary_text", containing the compressed passage.
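To get a rough sense of the compression ratio reported in the abstract, you can compare token counts of the original and compressed passages. A minimal sketch follows; using the base model's tokenizer here is an assumption, and any tokenizer gives a comparable ratio:

import jsonlines
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct", trust_remote_code=True)
orig_tokens, comp_tokens = 0, 0
with jsonlines.open("output.jsonl") as reader:
    for line in reader:
        orig_tokens += len(tokenizer.encode(line["passage_text"]))
        comp_tokens += len(tokenizer.encode(line["summary_text"]))
# A ratio > 1 means the summaries are shorter than the original passages.
print(f"compression ratio: {orig_tokens / max(comp_tokens, 1):.2f}x")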
Full Inference Code (main.py)
from peft import PeftModel
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import jsonlines
from tqdm import tqdm
import json
import copy
import argparse

torch.random.manual_seed(0)


def main(model_id, peft_model, max_length, input_filepath, output_filepath):
    assert torch.cuda.is_available(), "This model needs a GPU to run ..."
    device = 'cuda'
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    # Gemma models are loaded in bfloat16; all other models use their default
    # dtype together with FlashAttention-2.
    if model_id == "google/gemma-2-2b-it" or (model_id.startswith("/") and model_id.split("/")[-1].startswith("gemma")):
        print(f"model_id: {model_id}\t torch_dtype=torch.bfloat16")
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto"
        )
    else:
        print(f"model_id: {model_id}\t torch_dtype=auto, flash_attention_2")
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype="auto",
            trust_remote_code=True,
            attn_implementation="flash_attention_2",
            device_map="auto"
        )
    if len(peft_model) > 0:
        # Load the LoRA adapter and merge it into the base model for faster inference.
        finetuned_model = PeftModel.from_pretrained(
            model,
            peft_model,
            torch_dtype=torch.float16,
            is_trainable=False,
            device_map="auto"
        )
        finetuned_model = finetuned_model.merge_and_unload()
    else:
        finetuned_model = model

    input_lines = []
    with jsonlines.open(input_filepath) as f:
        for line in f.iter():
            input_lines.append(line)

    output_lines = []
    for line in tqdm(input_lines):
        messages = [
            {
                "role": "user",
                "content": f"Summarize the following content.\nContent:\n{line['passage_text']}\nSummary:"
            },
        ]
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
            truncation=True,   # truncate passages longer than max_length tokens
            max_length=max_length
        ).to(device)
        output_json = copy.deepcopy(line)
        try:
            outputs = finetuned_model.generate(
                inputs,
                max_new_tokens=512,
                do_sample=False,   # greedy decoding for deterministic summaries
                num_return_sequences=1,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=False    # disable the KV cache (slower, but uses less memory)
            )
            sample_output = outputs[0]
            # Decode only the newly generated tokens, skipping the prompt.
            decoded_text = tokenizer.decode(sample_output[len(inputs[0]):], skip_special_tokens=True)
            output_json['summary_text'] = decoded_text
        except Exception as e:
            print(e)
            output_json['summary_text'] = ''
        output_lines.append(output_json)

    with open(output_filepath, 'w') as f:
        for line in output_lines:
            json.dump(line, f)
            f.write('\n')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", default="microsoft/Phi-3-mini-4k-instruct")
    parser.add_argument("--peft_model", default="")
    parser.add_argument("--max_length", type=int, default=2048)  # type=int so CLI values are parsed as integers
    parser.add_argument("--input_filepath")
    parser.add_argument("--output_filepath")
    args = parser.parse_args()
    main(args.model_id, args.peft_model, args.max_length, args.input_filepath, args.output_filepath)
Example Command
python main.py \
  --model_id "iaminju/CoLoR-Phi-3-mini-4k-instruct" \
  --max_length 20000 \
  --input_filepath "/path/to/input.jsonl" \
  --output_filepath "/path/to/output.jsonl"
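If you have a separate LoRA adapter rather than a merged checkpoint, the same script can apply it on top of the base model via --peft_model (the adapter path below is a placeholder):

python main.py \
  --model_id "microsoft/Phi-3-mini-4k-instruct" \
  --peft_model "/path/to/color-adapter" \
  --max_length 2048 \
  --input_filepath "/path/to/input.jsonl" \
  --output_filepath "/path/to/output.jsonl"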
Citation
@inproceedings{Seo2025CoLoR,
  author    = {Minju Seo and
               Jinheon Baek and
               Seongyun Lee and
               Sung Ju Hwang},
  editor    = {Wanxiang Che and
               Joyce Nabende and
               Ekaterina Shutova and
               Mohammad Taher Pilehvar},
  title     = {Efficient Long Context Language Model Retrieval with Compression},
  booktitle = {Proceedings of the 63rd Annual Meeting of the Association for Computational
               Linguistics (Volume 1: Long Papers), {ACL} 2025, Vienna, Austria,
               July 27 - August 1, 2025},
  pages     = {15251--15268},
  publisher = {Association for Computational Linguistics},
  year      = {2025},
  url       = {https://aclanthology.org/2025.acl-long.740/}
}