| | |
| | import os |
| | import gc |
| | import tempfile |
| | import gradio as gr |
| | import torch |
| | import numpy as np |
| | import faiss |
| | from typing import Tuple, Dict, Any, Optional |
| | import spaces |
| |
|
| | from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline |
| |
|
| | |
# Keep HF tokenizers from spawning worker threads after fork (avoids
# tokenizers' fork-safety warnings/deadlocks under Gradio/Spaces workers).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Model and runtime configuration.
LLM_MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"  # chat LLM used for Q&A and summarization
EMBED_MODEL_NAME = "BAAI/bge-large-en-v1.5"  # sentence-embedding model used for retrieval
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_PROMPT_LENGTH = 28000  # character (not token) cap applied in invoke_llm before generation

# Retrieval-augmented Q&A prompt; filled with {context} and {question}.
QA_PROMPT_TEMPLATE = (
    "System: You are a helpful assistant. Answer the user's question based *only* on the provided context. "
    "If the answer is not found in the context, state that clearly.\n\n"
    "Context:\n---\n{context}\n---\n\nQuestion: {question}\n\nAnswer:"
)

# Summarization prompts keyed by the UI radio choice; each template takes {text}.
# summarize_document falls back to "Standard" for unknown keys.
SUMMARY_PROMPTS = {
    "Quick": (
        "You are an expert academic summarizer. Provide a single, concise paragraph that summarizes the absolute key takeaway of the following document. "
        "Be brief and direct.\n\nDocument:\n---\n{text}\n---\n\nQuick Summary:"
    ),
    "Standard": (
        "You are an expert academic summarizer. Provide a detailed, well-structured summary of the following document. "
        "Cover the key points, methodology, findings, and conclusions.\n\n"
        "Document:\n---\n{text}\n---\n\nStandard Summary:"
    ),
    "Detailed": (
        "You are an expert academic summarizer. Provide a highly detailed and comprehensive summary of the following document. "
        "Go into depth on the methodology, specific results, limitations, and any mention of future work. Use multiple paragraphs for structure.\n\n"
        "Document:\n---\n{text}\n---\n\nDetailed Summary:"
    )
}
| |
|
| | |
class ModelManager:
    """Keeps at most one heavy model (LLM or embedder) resident at a time.

    Both accessors evict the other model before loading their own, so the
    GPU never holds the 7B LLM and the embedding model simultaneously.
    """

    _llm_pipe = None     # cached transformers text-generation pipeline
    _embed_model = None  # cached LangChain HuggingFaceEmbeddings wrapper

    @classmethod
    def _clear_gpu_memory(cls):
        """Frees GPU memory by dropping model references and clearing the CUDA cache."""
        # Dropping the class-level references is what actually releases the
        # models; the previous `del` over local aliases was a no-op because it
        # only unbound the loop variable, not the class attributes.
        cls._llm_pipe = None
        cls._embed_model = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("[Memory] GPU Memory Cleared.")

    @classmethod
    def get_llm_pipeline(cls):
        """Loads and returns the LLM pipeline, ensuring no other models are loaded.

        Returns:
            The cached text-generation pipeline, or None if loading failed.
        """
        if cls._llm_pipe is None:
            cls._clear_gpu_memory()  # evict the embedder first
            print("[LLM] Loading model...")
            try:
                tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
                model = AutoModelForCausalLM.from_pretrained(
                    LLM_MODEL_NAME,
                    device_map=DEVICE,
                    torch_dtype=torch.bfloat16
                )
                cls._llm_pipe = pipeline(
                    "text-generation",
                    model=model,
                    tokenizer=tokenizer,
                    max_new_tokens=1024,
                    temperature=0.2,
                    top_p=0.95,
                )
                print("[LLM] Model loaded successfully.")
            except Exception as e:
                print(f"[LLM] Failed to load model: {e}")
                return None
        return cls._llm_pipe

    @classmethod
    def get_embedding_model(cls):
        """Loads and returns the embedding model, ensuring the LLM is not loaded.

        Returns:
            The cached HuggingFaceEmbeddings instance, or None if loading failed.
        """
        # Imported lazily so the langchain dependency is only paid when embeddings are used.
        from langchain_huggingface import HuggingFaceEmbeddings
        if cls._embed_model is None:
            cls._clear_gpu_memory()  # evict the LLM first
            print("[Embed] Loading embedding model...")
            try:
                cls._embed_model = HuggingFaceEmbeddings(
                    model_name=EMBED_MODEL_NAME,
                    model_kwargs={"device": DEVICE},
                    encode_kwargs={"normalize_embeddings": True}
                )
                print("[Embed] Embedding model loaded successfully.")
            except Exception as e:
                print(f"[Embed] Failed to load model: {e}")
                return None
        return cls._embed_model
| |
|
| | |
@spaces.GPU
def invoke_llm(prompt_str: str) -> str:
    """Invokes the LLM with a given prompt.

    Args:
        prompt_str: Fully formatted prompt; truncated to MAX_PROMPT_LENGTH characters.

    Returns:
        The model's completion text, or a human-readable error string on failure.
    """
    if len(prompt_str) > MAX_PROMPT_LENGTH:
        prompt_str = prompt_str[:MAX_PROMPT_LENGTH]
        print(f"[invoke_llm] Prompt truncated to {MAX_PROMPT_LENGTH} characters.")

    try:
        pipe = ModelManager.get_llm_pipeline()
        if not pipe:
            return "Error: LLM could not be loaded."

        with torch.no_grad():
            outputs = pipe(prompt_str)

        if isinstance(outputs, list) and outputs and "generated_text" in outputs[0]:
            text = outputs[0]["generated_text"]
            # The pipeline echoes the prompt at the start of generated_text.
            # Strip it as a *prefix* only: the old str.replace() deleted every
            # occurrence of the prompt substring, mangling answers that quote
            # the context or question verbatim.
            if text.startswith(prompt_str):
                text = text[len(prompt_str):]
            return text.strip()
        return "No valid response was generated."

    except Exception as e:
        print(f"[invoke_llm] Error: {e}")
        return f"LLM invocation failed: {e}"
| |
|
@spaces.GPU
def process_pdf_and_index(pdf_path: str) -> Tuple[str, Optional[Dict[str, Any]]]:
    """Processes a PDF, creates embeddings, and builds a FAISS index."""
    # Lazy imports keep langchain off the critical path at module load time.
    from langchain_community.document_loaders import PyMuPDFLoader
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    if not pdf_path:
        return "No file path provided.", None

    try:
        print("[Process] Loading and splitting PDF...")
        loader = PyMuPDFLoader(pdf_path)
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=150)
        pieces = splitter.split_documents(loader.load())

        # Keep only chunks with actual content; blank pages yield empty chunks.
        texts = []
        for piece in pieces:
            if piece.page_content.strip():
                texts.append(piece.page_content)

        if not texts:
            return "No text could be extracted from the PDF.", None
        print(f"[Process] Extracted {len(texts)} text chunks.")

        embed_model = ModelManager.get_embedding_model()
        if not embed_model:
            return "Could not load embedding model.", None

        print(f"[Process] Creating embeddings...")
        emb_np = np.array(embed_model.embed_documents(texts), dtype=np.float32)

        print("[Process] Building and saving FAISS index...")
        dim = emb_np.shape[1]
        index = faiss.IndexFlatL2(dim)
        index.add(emb_np)

        # delete=False: the path must outlive this call so Q&A can reload it later.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".faiss") as f:
            index_path = f.name
            faiss.write_index(index, index_path)

        state_bundle = {"index_path": index_path, "texts": texts}
        return f"Successfully processed and indexed {len(texts)} chunks.", state_bundle

    except Exception as e:
        print(f"[process_pdf] Exception: {e}")
        return f"Error processing PDF: {e}", None
| |
|
@spaces.GPU
def retrieve_and_answer(question: str, state_bundle: Dict[str, Any]) -> Tuple[str, str]:
    """Retrieves context and generates an answer for a given question."""
    # Guard clause: nothing to search until a PDF has been processed.
    if not state_bundle or "index_path" not in state_bundle:
        return "Please upload and process a PDF first.", ""

    try:
        embed_model = ModelManager.get_embedding_model()
        if not embed_model:
            return "Error loading embedding model.", ""

        index = faiss.read_index(state_bundle["index_path"])
        texts = state_bundle.get("texts", [])

        q_arr = np.asarray([embed_model.embed_query(question)], dtype=np.float32)

        # Nearest-neighbor lookup over the chunk embeddings (top 5 hits).
        _, hit_ids = index.search(q_arr, k=5)

        sources = []
        for idx in hit_ids[0]:
            if 0 <= idx < len(texts):  # FAISS pads with -1 when fewer hits exist
                sources.append(texts[idx])
        if not sources:
            return "Could not find relevant information.", ""

        context = "\n\n---\n\n".join(sources)
        sources_preview = "\n\n---\n\n".join(s[:500] + "..." for s in sources)

        prompt = QA_PROMPT_TEMPLATE.format(context=context, question=question)
        return invoke_llm(prompt), sources_preview

    except Exception as e:
        print(f"[retrieve_and_answer] Error: {e}")
        return f"An error occurred: {e}", ""
| |
|
@spaces.GPU
def summarize_document(state_bundle: Dict[str, Any], summary_type: str) -> Tuple[str, Optional[str]]:
    """Generates a summary of the document and saves it to a temporary file.

    Args:
        state_bundle: State dict produced by process_pdf_and_index; must contain "texts".
        summary_type: "Quick", "Standard", or "Detailed"; unknown values fall
            back to the "Standard" prompt.

    Returns:
        A (summary_text, download_file_path) tuple; the path is None when no
        document is available.
    """
    if not (state_bundle and "texts" in state_bundle):
        return "Please upload and process a PDF first.", None

    texts = state_bundle.get("texts", [])
    if not texts:
        return "No text available to summarize.", None

    full_text = "\n\n".join(texts)

    prompt_template = SUMMARY_PROMPTS.get(summary_type, SUMMARY_PROMPTS["Standard"])
    prompt = prompt_template.format(text=full_text)

    print(f"[Summarize] Generating '{summary_type}' summary...")
    final_summary = invoke_llm(prompt)

    # Context manager guarantees the handle is flushed and closed even if
    # write() raises; delete=False keeps the file alive for the download button.
    with tempfile.NamedTemporaryFile(
        delete=False, suffix=".txt", mode="w", encoding="utf-8"
    ) as tmp:
        tmp.write(final_summary)
        summary_path = tmp.name

    return final_summary, summary_path
| |
|
| | |
with gr.Blocks(title="PDF Summarizer & Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📚 PDF Summarizer & Q&A Assistant")
    gr.Markdown("Upload a PDF to generate a summary or ask questions about its content.")

    # Holds the {"index_path": ..., "texts": ...} bundle between events.
    state = gr.State()

    with gr.Row():
        pdf_in = gr.File(label="Upload PDF", file_types=[".pdf"], type="filepath")
        process_btn = gr.Button("Process PDF", variant="primary")

    status_output = gr.Textbox(label="Status", interactive=False)

    with gr.Tabs():
        with gr.TabItem("Summarization"):
            gr.Markdown("### Generate a Summary")
            gr.Markdown("Select the level of detail you want in the summary.")
            summary_type_radio = gr.Radio(
                choices=["Quick", "Standard", "Detailed"],
                value="Standard",
                label="Summary Type"
            )
            summary_btn = gr.Button("Generate Summary", variant="secondary")
            out_summary = gr.Textbox(label="Document Summary", lines=20, max_lines=25)
            download_btn = gr.DownloadButton("Download Summary", visible=False)

        with gr.TabItem("Question & Answer"):
            gr.Markdown("### Ask a Question")
            gr.Markdown("Ask a specific question about the document's content.")
            q_text = gr.Textbox(label="Your Question", placeholder="e.g., What was the main conclusion of the study?")
            q_btn = gr.Button("Get Answer", variant="secondary")
            q_out = gr.Textbox(label="Answer", lines=8)
            q_sources = gr.Textbox(label="Retrieved Sources", lines=8, max_lines=10)

    def handle_process(pdf_file):
        """Wrapper to handle PDF processing and clear old outputs."""
        if pdf_file is None:
            return "Please upload a file first.", None, "", "", "", "", gr.DownloadButton(visible=False)
        # gr.File(type="filepath") delivers a plain path string; calling
        # .name on it raised AttributeError. Still accept object-style
        # values (older Gradio tempfile wrappers) for safety.
        pdf_path = pdf_file if isinstance(pdf_file, str) else pdf_file.name
        status_msg, bundle = process_pdf_and_index(pdf_path)
        # Reset all downstream outputs so results from a previous PDF don't linger.
        return status_msg, bundle, "", "", "", "", gr.DownloadButton(visible=False)

    def handle_summary(bundle, summary_type):
        """Runs summarization and reveals the download button only on success.

        The button is created with visible=False and nothing previously ever
        made it visible, so the download feature was unreachable.
        """
        summary, path = summarize_document(bundle, summary_type)
        if path:
            return summary, gr.DownloadButton(value=path, visible=True)
        return summary, gr.DownloadButton(visible=False)

    process_btn.click(
        fn=handle_process,
        inputs=[pdf_in],
        outputs=[status_output, state, out_summary, q_text, q_out, q_sources, download_btn]
    )

    q_btn.click(
        fn=retrieve_and_answer,
        inputs=[q_text, state],
        outputs=[q_out, q_sources]
    )

    summary_btn.click(
        fn=handle_summary,
        inputs=[state, summary_type_radio],
        outputs=[out_summary, download_btn]
    )

if __name__ == "__main__":
    demo.launch(share=False, show_error=True)
| |
|