Efficient Long Context Language Model Retrieval with Compression

Abstract

We propose a new compression approach tailored for Long Context Language Model (LCLM) retrieval, trained to maximize retrieval performance while minimizing the length of the compressed passages. To accomplish this, we generate synthetic data in which compressed passages are automatically created and labeled as chosen or rejected according to their retrieval success for a given query. We then train the proposed Compression model for Long context Retrieval (CoLoR) on this data via preference optimization, adding a length regularization loss on top to enforce brevity. Through extensive experiments on 9 datasets, we show that CoLoR improves retrieval performance by 6% while compressing the in-context size by a factor of 1.91.
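
For intuition only, the block below sketches what such a training objective could look like: a DPO-style preference loss over chosen/rejected compressed passages plus a length-regularization term. The function name, arguments, and the weights beta and alpha are hypothetical illustrations, not the released training code.

import torch.nn.functional as F

def color_objective(policy_chosen_logps, policy_rejected_logps,
                    ref_chosen_logps, ref_rejected_logps,
                    chosen_token_lengths, beta=0.1, alpha=0.01):
    # Preference term (DPO-style): favor compressions that led to retrieval success.
    margins = beta * ((policy_chosen_logps - ref_chosen_logps)
                      - (policy_rejected_logps - ref_rejected_logps))
    preference_loss = -F.logsigmoid(margins).mean()
    # Length regularization: penalize longer compressed passages (alpha is a guessed weight).
    length_loss = alpha * chosen_token_lengths.float().mean()
    return preference_loss + length_loss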

🚀 Quick Start

Install dependencies

pip install transformers peft torch jsonlines tqdm
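
Note: when the base model is not a Gemma checkpoint, the inference script below requests attn_implementation="flash_attention_2", which additionally requires the flash-attn package (built against your local CUDA toolkit). Install it, or drop that argument if it is unavailable on your setup:

pip install flash-attn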

🧪 Example Input Format

Your input file must be a .jsonl file in which each line contains a passage_text field:

{
  "id": "id-123",
  "passage_text": "Lorem ipsum dolor sit amet..."
}
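
A compliant input.jsonl can be produced with a few lines of Python; the ids and passage texts below are placeholders:

import json

records = [
    {"id": "id-123", "passage_text": "Lorem ipsum dolor sit amet..."},
    {"id": "id-124", "passage_text": "Another passage to compress..."},
]

with open("input.jsonl", "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")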

📌 Output Format

The script will add a new field: "summary_text"
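
For example, the record shown above would come back with the generated summary appended (the summary text here is purely illustrative):

{
  "id": "id-123",
  "passage_text": "Lorem ipsum dolor sit amet...",
  "summary_text": "A short compressed version of the passage."
}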


βš™οΈ FULL INFERENCE CODE (main.py)

from peft import PeftModel
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import jsonlines
from tqdm import tqdm
import json
import copy
import argparse

torch.random.manual_seed(0)

def main(model_id, peft_model, max_length, input_filepath, output_filepath):

    assert torch.cuda.is_available(), "This model needs a GPU to run ..."
    device = 'cuda'

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    # Gemma checkpoints are loaded in bfloat16; all other models are loaded with FlashAttention-2.
    if model_id == "google/gemma-2-2b-it" or (model_id.startswith("/") and model_id.split("/")[-1].startswith("gemma")):
        print(f"model_id: {model_id}\t torch_dtype=torch.bfloat16")
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto"
        )
    else:
        print(f"model_id: {model_id}\t torch_dtype=torch.auto, flash_attention_2")
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype="auto",
            trust_remote_code=True,
            attn_implementation="flash_attention_2",
            device_map="auto"
        )

    # Optionally load a PEFT (LoRA) adapter and merge it into the base model.
    if len(peft_model) > 0:
        finetuned_model = PeftModel.from_pretrained(
            model,
            peft_model,
            torch_dtype=torch.float16,
            is_trainable=False,
            device_map="auto"
        )
        finetuned_model = finetuned_model.merge_and_unload()
    else:
        finetuned_model = model

    # Read all records from the input .jsonl file.
    input_lines = []
    with jsonlines.open(input_filepath) as f:
        for line in f.iter():
            input_lines.append(line)

    output_lines = []
    for line in tqdm(input_lines):
        messages = [
            {
                "role": "user",
                "content": f"Summarize the following content.\nContent:\n{line['passage_text']}\nSummary:"
            },
        ]

        # Truncation must be enabled for max_length to take effect during tokenization.
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
            truncation=True,
            max_length=max_length
        ).to(device)

        output_json = copy.deepcopy(line)

        try:
            # Greedy decoding of the compressed passage.
            outputs = finetuned_model.generate(
                inputs,
                max_new_tokens=512,
                do_sample=False,
                num_return_sequences=1,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=False
            )

            # Strip the prompt tokens and decode only the newly generated summary.
            sample_output = outputs[0]
            decoded_text = tokenizer.decode(sample_output[len(inputs[0]):], skip_special_tokens=True)
            output_json['summary_text'] = decoded_text

        except Exception as e:
            print(e)
            output_json['summary_text'] = ''

        output_lines.append(output_json)

    # Write results as .jsonl, one record per line, with the new summary_text field.
    with open(output_filepath, 'w') as f:
        for line in output_lines:
            json.dump(line, f)
            f.write('\n')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", default="microsoft/Phi-3-mini-4k-instruct")
    parser.add_argument("--peft_model", default="")
    parser.add_argument("--max_length", default=2048)
    parser.add_argument("--input_filepath")
    parser.add_argument("--output_filepath")

    args = parser.parse_args()
    main(args.model_id, args.peft_model, args.max_length, args.input_filepath, args.output_filepath)

🧠 Example Command

python main.py \
    --model_id "iaminju/CoLoR-Phi-3-mini-4k-instruct" \
    --max_length 20000 \
    --input_filepath "/path/to/input.jsonl" \
    --output_filepath "/path/to/output.jsonl"
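
Once inference completes, a quick sanity check is to compare token counts before and after compression. The snippet below is an illustrative post-processing step (not part of the script); it reuses the model's tokenizer on the output file written above:

import jsonlines
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("iaminju/CoLoR-Phi-3-mini-4k-instruct", trust_remote_code=True)

original_tokens, compressed_tokens = 0, 0
with jsonlines.open("/path/to/output.jsonl") as f:
    for line in f:
        original_tokens += len(tokenizer.encode(line["passage_text"]))
        compressed_tokens += len(tokenizer.encode(line["summary_text"]))

# Values around 1.9x would be in line with the compression factor reported in the paper.
print(f"Compression ratio: {original_tokens / max(compressed_tokens, 1):.2f}x")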

Citation

@inproceedings{Seo2025CoLoR,
  author       = {Minju Seo and
                  Jinheon Baek and
                  Seongyun Lee and
                  Sung Ju Hwang},
  editor       = {Wanxiang Che and
                  Joyce Nabende and
                  Ekaterina Shutova and
                  Mohammad Taher Pilehvar},
  title        = {Efficient Long Context Language Model Retrieval with Compression},
  booktitle    = {Proceedings of the 63rd Annual Meeting of the Association for Computational
                  Linguistics (Volume 1: Long Papers), {ACL} 2025, Vienna, Austria,
                  July 27 - August 1, 2025},
  pages        = {15251--15268},
  publisher    = {Association for Computational Linguistics},
  year         = {2025},
  url          = {https://aclanthology.org/2025.acl-long.740/},
  timestamp    = {Wed, 24 Sep 2025 15:22:07 +0200},
  biburl       = {https://dblp.org/rec/conf/acl/SeoBLH25.bib},
  bibsource    = {dblp computer science bibliography, https://dblp.org}
}