import base64
import hashlib
import io
import json
import os
import tempfile
from collections import OrderedDict as CollectionsOrderedDict
from pathlib import Path
from threading import Thread
from typing import Iterator, Optional, List, Union, OrderedDict

import fitz
import gradio as gr
import requests
import spaces
import torch
from PIL import Image
from colpali_engine import ColPali, ColPaliProcessor
from huggingface_hub import hf_hub_download
from pydantic import BaseModel
from qwen_vl_utils import process_vision_info
from swift.llm import (
    ModelType,
    get_model_tokenizer,
    get_default_template_type,
    get_template,
    inference,
    inference_stream,
)
from tqdm import tqdm
from transformers import (
    Qwen2VLForConditionalGeneration,
    PreTrainedTokenizer,
    Qwen2VLProcessor,
    TextIteratorStreamer,
    AutoTokenizer,
)
from ultralytics import YOLO
from ultralytics.engine.results import Results

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

DESCRIPTION = """\
# M-Longdoc: A Benchmark For Multimodal Super-Long Document Understanding And A Retrieval-Aware Tuning Framework
This Space demonstrates a 7B-parameter multimodal long-document understanding model fine-tuned for text, tables, and figures. Feel free to play with it, or duplicate it to run generations without a queue!
🔎 For more details about the project, check out the [paper](https://arxiv.org/pdf/2411.06176).
"""
| LICENSE = """ | |
| <p/> | |
| --- | |
| As a derivate work of [Llama-3-8b-chat](https://huggingface.co/meta-llama/Meta-Llama-3-8B) by Meta, | |
| this demo is governed by the original [license](https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/LICENSE) and [acceptable use policy](https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/USE_POLICY.md). | |
| """ | |


class MultimodalSample(BaseModel):
    question: str
    answer: str
    category: str
    evidence_pages: List[int] = []
    raw_output: str = ""
    pred: str = ""
    source: str = ""
    annotator: str = ""
    generator: str = ""
    retrieved_pages: List[int] = []


class MultimodalObject(BaseModel):
    id: str = ""
    page: int = 0
    text: str = ""
    image_string: str = ""
    snippet: str = ""
    score: float = 0.0
    source: str = ""
    category: str = ""

    def get_image(self) -> Optional[Image.Image]:
        if self.image_string:
            return convert_text_to_image(self.image_string)

    @classmethod
    def from_image(cls, image: Image.Image, **kwargs):
        return cls(image_string=convert_image_to_text(image), **kwargs)


class ObjectDetector(BaseModel, arbitrary_types_allowed=True):
    def run(self, image: Image.Image) -> List[MultimodalObject]:
        raise NotImplementedError()


class YoloDetector(ObjectDetector):
    repo_id: str = "DILHTWD/documentlayoutsegmentation_YOLOv8_ondoclaynet"
    filename: str = "yolov8x-doclaynet-epoch64-imgsz640-initiallr1e-4-finallr1e-5.pt"
    local_dir: str = "data/yolo"
    client: Optional[YOLO] = None

    def load(self):
        if self.client is None:
            if not Path(self.local_dir, self.filename).exists():
                hf_hub_download(
                    repo_id=self.repo_id,
                    filename=self.filename,
                    local_dir=self.local_dir,
                )
            self.client = YOLO(Path(self.local_dir, self.filename))

    def save_image(self, image: Image.Image) -> str:
        text = convert_image_to_text(image)
        hash_id = hashlib.md5(text.encode()).hexdigest()
        path = Path(self.local_dir, f"{hash_id}.png")
        image.save(path)
        return str(path)

    @staticmethod
    def extract_subimage(image: Image.Image, box: List[float]) -> Image.Image:
        return image.crop((round(box[0]), round(box[1]), round(box[2]), round(box[3])))

    def run(self, image: Image.Image) -> List[MultimodalObject]:
        self.load()
        path = self.save_image(image)
        results: List[Results] = self.client(source=[path])
        assert len(results) == 1

        objects = []
        for i, label_id in enumerate(results[0].boxes.cls):
            label = results[0].names[label_id.item()]
            score = results[0].boxes.conf[i].item()
            box: List[float] = results[0].boxes.xyxy[i].tolist()
            subimage = self.extract_subimage(image, box)
            objects.append(
                MultimodalObject(
                    image_string=convert_image_to_text(subimage),
                    category=label,
                    score=score,
                )
            )
        return objects
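

# Usage sketch for the layout detector (illustrative only; "page.png" is an
# assumed local file, not part of this app):
#   detector = YoloDetector()
#   objects = detector.run(Image.open("page.png").convert("RGB"))
#   tables = [o for o in objects if o.category == "Table"]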


class MultimodalPage(BaseModel):
    number: int
    objects: List[MultimodalObject]
    text: str
    image_string: str
    source: str
    score: float = 0.0

    def get_tables_and_figures(self) -> List[MultimodalObject]:
        return [o for o in self.objects if o.category in ["Table", "Picture"]]

    def get_full_image(self) -> Image.Image:
        return convert_text_to_image(self.image_string)

    @classmethod
    def from_text(cls, text: str):
        return MultimodalPage(
            text=text, number=0, objects=[], image_string="", source=""
        )

    @classmethod
    def from_image(cls, image: Image.Image):
        return MultimodalPage(
            image_string=convert_image_to_text(image),
            number=0,
            objects=[],
            text="",
            source="",
        )


class MultimodalDocument(BaseModel):
    pages: List[MultimodalPage]

    def get_page(self, i: int) -> MultimodalPage:
        pages = [p for p in self.pages if p.number == i]
        assert len(pages) == 1
        return pages[0]

    @classmethod
    def load_from_pdf(
        cls, path: str, dpi: int = 150, detector: Optional[ObjectDetector] = None
    ):
        # Each page as an image (with optional extracted text)
        doc = fitz.open(path)
        pages = []
        for i, page in enumerate(tqdm(doc.pages(), desc=path)):
            text = page.get_text()
            zoom = dpi / 72  # 72 is the default DPI
            matrix = fitz.Matrix(zoom, zoom)
            pix = page.get_pixmap(matrix=matrix)
            image = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)

            objects = []
            if detector:
                objects = detector.run(image)
            for o in objects:
                o.page, o.source = i + 1, path

            pages.append(
                MultimodalPage(
                    number=i + 1,
                    objects=objects,
                    text=text,
                    image_string=convert_image_to_text(image),
                    source=path,
                )
            )
        return cls(pages=pages)

    @classmethod
    def load(cls, path: str):
        pages = []
        with open(path) as f:
            for line in f:
                pages.append(MultimodalPage(**json.loads(line)))
        return cls(pages=pages)

    def save(self, path: str):
        Path(path).parent.mkdir(exist_ok=True, parents=True)
        with open(path, "w") as f:
            for o in self.pages:
                print(o.model_dump_json(), file=f)

    def get_domain(self) -> str:
        filename = Path(self.pages[0].source).name
        if filename.startswith("NYSE"):
            return "Financial<br>Report"
        elif filename[:4].isdigit() and filename[4] == "." and filename[5].isdigit():
            return "Academic<br>Paper"
        else:
            return "Technical<br>Manuals"


class MultimodalRetriever(BaseModel, arbitrary_types_allowed=True):
    def run(self, query: str, doc: MultimodalDocument) -> MultimodalDocument:
        raise NotImplementedError

    @staticmethod
    def get_top_pages(doc: MultimodalDocument, k: int) -> List[int]:
        # Get the top-k pages by score but maintain the original page order
        doc = doc.copy(deep=True)
        pages = sorted(doc.pages, key=lambda x: x.score, reverse=True)
        threshold = pages[:k][-1].score
        return [p.number for p in doc.pages if p.score >= threshold]
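

# Note on get_top_pages: the k-th highest score becomes a threshold, so pages tied
# with it are all kept. For example (hypothetical scores), with page scores
# [0.9, 0.2, 0.7, 0.7] and k=2 the threshold is 0.7, and pages 1, 3, and 4 are
# returned, still in document order.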


class ColpaliRetriever(MultimodalRetriever):
    path: str = "vidore/colpali-v1.2"
    model: Optional[ColPali] = None
    processor: Optional[ColPaliProcessor] = None
    device: str = "cuda"
    cache: OrderedDict[str, torch.Tensor] = CollectionsOrderedDict()

    def load(self):
        if self.model is None:
            self.model = ColPali.from_pretrained(
                self.path, torch_dtype=torch.bfloat16, device_map=self.device
            )
            self.model = self.model.eval()
            self.processor = ColPaliProcessor.from_pretrained(self.path)

    def encode_document(self, doc: MultimodalDocument) -> torch.Tensor:
        hash_id = hashlib.md5(doc.json().encode()).hexdigest()
        if len(self.cache) > 100:
            self.cache.popitem(last=False)

        if hash_id not in self.cache:
            images = [page.get_full_image() for page in doc.pages]
            batch_size = 8
            ds: List[torch.Tensor] = []
            for i in tqdm(range(0, len(images), batch_size), desc="Encoding document"):
                batch = self.processor.process_images(images[i : i + batch_size])
                with torch.no_grad():
                    # noinspection PyTypeChecker
                    ds.append(self.model(**batch.to(self.device)).cpu())

            lengths = [x.shape[1] for x in ds]
            if len(set(lengths)) != 1:
                print("Warning: Inconsistent lengths from colqwen", set(lengths))
                assert "colqwen" in self.path
                for i, x in enumerate(ds):
                    ds[i] = x[:, : min(lengths), :]

            self.cache[hash_id] = torch.cat(ds)
        return self.cache[hash_id]

    def run(self, query: str, doc: MultimodalDocument) -> MultimodalDocument:
        doc = doc.copy(deep=True)
        self.load()
        ds = self.encode_document(doc)
        with torch.no_grad():
            # noinspection PyTypeChecker
            qs = self.model(**self.processor.process_queries([query]).to(self.device))

        # noinspection PyTypeChecker
        scores = self.processor.score_multi_vector(qs.cpu(), ds).squeeze()
        assert len(scores) == len(doc.pages)
        for i, page in enumerate(doc.pages):
            page.score = scores[i].item()
        return doc
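

# Usage sketch for retrieval (illustrative only; assumes a GPU and a document
# already loaded as `doc`):
#   retriever = ColpaliRetriever()
#   scored = retriever.run("What is the reported revenue?", doc)
#   top_pages = MultimodalRetriever.get_top_pages(scored, k=5)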


class DummyRetriever(MultimodalRetriever):
    def run(self, query: str, doc: MultimodalDocument) -> MultimodalDocument:
        doc = doc.copy(deep=True)
        for i, page in enumerate(doc.pages):
            page.score = i
        return doc


def convert_image_to_text(image: Image.Image) -> str:
    # This is also how OpenAI encodes images: https://platform.openai.com/docs/guides/vision
    with io.BytesIO() as output:
        image.save(output, format="PNG")
        data = output.getvalue()
    return base64.b64encode(data).decode("utf-8")


def convert_text_to_image(text: str) -> Image.Image:
    data = base64.b64decode(text.encode("utf-8"))
    return Image.open(io.BytesIO(data))
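

# Round-trip sketch (illustrative only): images are stored inside the pydantic
# models as base64-encoded PNG strings, so they survive JSON serialization.
#   encoded = convert_image_to_text(Image.new("RGB", (64, 64)))
#   restored = convert_text_to_image(encoded)   # a 64x64 PIL image again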


def save_image(image: Image.Image, folder: str) -> str:
    image_hash = hashlib.md5(image.tobytes()).hexdigest()
    path = Path(folder, f"{image_hash}.png")
    path.parent.mkdir(exist_ok=True, parents=True)
    if not path.exists():
        image.save(path)
    return str(path)


def resize_image(image: Image.Image, max_size: int) -> Image.Image:
    # Same as modeling.py resize_image
    width, height = image.size
    if width <= max_size and height <= max_size:
        return image
    if width > height:
        new_width = max_size
        new_height = round(height * max_size / width)
    else:
        new_height = max_size
        new_width = round(width * max_size / height)
    return image.resize((new_width, new_height), Image.LANCZOS)
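

# Example of the resize arithmetic (illustrative only): a 2000x1000 page with
# max_size=768 keeps its aspect ratio and becomes 768x384, since
# new_height = round(1000 * 768 / 2000) = 384. Smaller images pass through unchanged.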


class EvalModel(BaseModel, arbitrary_types_allowed=True):
    engine: str
    timeout: int = 60
    temperature: float = 0.0
    max_output_tokens: int = 512

    def run(self, inputs: List[Union[str, Image.Image]]) -> str:
        raise NotImplementedError

    def run_many(self, inputs: List[Union[str, Image.Image]], num: int) -> List[str]:
        raise NotImplementedError


class SwiftQwenModel(EvalModel):
    # https://github.com/modelscope/ms-swift/blob/main/docs/source_en/Multi-Modal/qwen2-vl-best-practice.md
    path: str = ""
    model: Optional[Qwen2VLForConditionalGeneration] = None
    tokenizer: Optional[PreTrainedTokenizer] = None
    engine: str = ModelType.qwen2_vl_7b_instruct
    image_size: int = 768
    image_dir: str = "data/qwen_images"

    def load(self):
        if self.model is None or self.tokenizer is None:
            self.model, self.tokenizer = get_model_tokenizer(
                self.engine,
                torch.bfloat16,
                model_kwargs={"device_map": "auto"},
                model_id_or_path=self.path or None,
            )

    def run(self, inputs: List[Union[str, Image.Image]]) -> str:
        self.load()
        template_type = get_default_template_type(self.engine)
        self.model.generation_config.max_new_tokens = self.max_output_tokens
        template = get_template(template_type, self.tokenizer)

        text = "\n\n".join([x for x in inputs if isinstance(x, str)])
        content = []
        for x in inputs:
            if isinstance(x, Image.Image):
                path = save_image(resize_image(x, self.image_size), self.image_dir)
                content.append(f"<img>{path}</img>")
        content.append(text)
        query = "".join(content)

        response, history = inference(self.model, template, query)
        return response

    def run_stream(self, inputs: List[Union[str, Image.Image]]) -> Iterator[str]:
        self.load()
        template_type = get_default_template_type(self.engine)
        self.model.generation_config.max_new_tokens = self.max_output_tokens
        template = get_template(template_type, self.tokenizer)

        text = "\n\n".join([x for x in inputs if isinstance(x, str)])
        content = []
        for x in inputs:
            if isinstance(x, Image.Image):
                path = save_image(resize_image(x, self.image_size), self.image_dir)
                content.append(f"<img>{path}</img>")
        content.append(text)
        query = "".join(content)

        generator = inference_stream(self.model, template, query)
        print_idx = 0
        print(f"query: {query}\nresponse: ", end="")
        for response, history in generator:
            delta = response[print_idx:]
            print(delta, end="", flush=True)
            print_idx = len(response)
            yield delta


class QwenModel(EvalModel):
    path: str = "models/qwen"
    engine: str = "Qwen/Qwen2-VL-7B-Instruct"
    model: Optional[Qwen2VLForConditionalGeneration] = None
    processor: Optional[Qwen2VLProcessor] = None
    tokenizer: Optional[AutoTokenizer] = None
    device: str = "cuda"
    image_size: int = 768
    lora_path: str = ""

    def load(self):
        if self.model is None:
            path = self.path if os.path.exists(self.path) else self.engine
            print(dict(load_path=path))
            # noinspection PyTypeChecker
            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                path, torch_dtype="auto", device_map="auto"
            )
            self.tokenizer = AutoTokenizer.from_pretrained(self.engine)
            if self.lora_path:
                print("Loading LORA from", self.lora_path)
                self.model.load_adapter(self.lora_path)
            self.model = self.model.to(self.device).eval()
            self.processor = Qwen2VLProcessor.from_pretrained(self.engine)
        torch.manual_seed(0)
        torch.cuda.manual_seed_all(0)

    def make_messages(self, inputs: List[Union[str, Image.Image]]) -> List[dict]:
        text = "\n\n".join([x for x in inputs if isinstance(x, str)])
        content = [
            dict(
                type="image",
                image=f"data:image;base64,{convert_image_to_text(resize_image(x, self.image_size))}",
            )
            for x in inputs
            if isinstance(x, Image.Image)
        ]
        content.append(dict(type="text", text=text))
        return [dict(role="user", content=content)]

    def run(self, inputs: List[Union[str, Image.Image]]) -> str:
        self.load()
        messages = self.make_messages(inputs)
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        # noinspection PyTypeChecker
        model_inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.device)

        with torch.inference_mode():
            generated_ids = self.model.generate(
                **model_inputs, max_new_tokens=self.max_output_tokens
            )

        generated_ids_trimmed = [
            out_ids[len(in_ids) :]
            for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return output_text[0]
    def run_stream(self, inputs: List[Union[str, Image.Image]]) -> Iterator[str]:
        self.load()
        messages = self.make_messages(inputs)
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        # noinspection PyTypeChecker
        model_inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.device)

        streamer = TextIteratorStreamer(
            self.tokenizer,
            timeout=10.0,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        generate_kwargs = dict(
            **model_inputs,
            streamer=streamer,
            max_new_tokens=self.max_output_tokens,
        )
        t = Thread(target=self.model.generate, kwargs=generate_kwargs)
        t.start()

        outputs = []
        for text in streamer:
            outputs.append(text)
            yield "".join(outputs)


class DummyModel(EvalModel):
    engine: str = ""

    def run(self, inputs: List[Union[str, Image.Image]]) -> str:
        return " ".join(inputs)

    def run_stream(self, inputs: List[Union[str, Image.Image]]) -> Iterator[str]:
        assert self is not None
        text = " ".join([x for x in inputs if isinstance(x, str)])
        num_images = sum(1 for x in inputs if isinstance(x, Image.Image))
        tokens = f"Hello this is your message: {text}, images: {num_images}".split()

        import time

        for i in range(len(tokens)):
            yield " ".join(tokens[: i + 1])
            time.sleep(0.05)


if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

if torch.cuda.is_available():
    model = QwenModel()
    model.load()
    detect_model = YoloDetector()
    detect_model.load()
    retriever = ColpaliRetriever()
    retriever.load()
else:
    model = DummyModel()
    detect_model = None
    retriever = DummyRetriever()


def get_file_path(file: gr.File = None, url: str = None) -> Optional[str]:
    if file is not None:
        # noinspection PyUnresolvedReferences
        return file.name

    if url is not None:
        response = requests.get(url)
        response.raise_for_status()
        save_path = Path(tempfile.mkdtemp(), url.split("/")[-1])
        if "application/pdf" in response.headers.get("Content-Type", ""):
            # Write the downloaded PDF in binary mode
            with open(save_path, "wb") as f:
                f.write(response.content)
            return str(save_path)


@spaces.GPU
def generate(
    query: str, file: gr.File = None, url: str = None, top_k=5
) -> Iterator[str]:
    sample = MultimodalSample(question=query, answer="", category="")
    path = get_file_path(file, url)

    if path is not None:
        doc = MultimodalDocument.load_from_pdf(path, detector=detect_model)
        output = retriever.run(sample.question, doc)
        sorted_pages = sorted(output.pages, key=lambda p: p.score, reverse=True)
        sample.retrieved_pages = sorted([p.number for p in sorted_pages][:top_k])

        context = []
        for p in doc.pages:
            if p.number in sample.retrieved_pages:
                if p.text:
                    context.append(p.text)
                context.extend(o.get_image() for o in p.get_tables_and_figures())

        inputs = [
            "Context:",
            *context,
            f"Answer the following question in 200 words or less: {sample.question}",
        ]
    else:
        inputs = [
            "Context:",
            f"Answer the following question in 200 words or less: {sample.question}",
        ]

    for text in model.run_stream(inputs):
        yield text
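

# Pipeline sketch (illustrative only): generate() parses the PDF, scores every page
# against the question with the retriever, keeps the top_k pages, and streams an
# answer grounded in their text plus any detected tables and figures, e.g.:
#   for partial in generate("What are the key findings?", url="https://example.com/paper.pdf"):
#       print(partial)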


with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use", elem_id="duplicate-button"
    )

    with gr.Row():
        pdf_upload = gr.File(label="Upload PDF (optional)", file_types=[".pdf"])
        with gr.Column():
            url_input = gr.Textbox(label="Enter PDF URL (optional)")
            text_input = gr.Textbox(label="Enter your message", lines=3)

    submit_button = gr.Button("Submit")
    result = gr.Textbox(label="Response", lines=10)
    submit_button.click(
        generate, inputs=[text_input, pdf_upload, url_input], outputs=result
    )
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.queue(max_size=20).launch()