# ============================================================================
# FILE: app.py (Hugging Face Transformers + GGUF caching + Vision)
# ============================================================================
import gradio as gr
import torch
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, BitsAndBytesConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from PIL import Image
import fitz  # PyMuPDF
import json
import pandas as pd
import os
import re
import xlsxwriter
import tempfile
from pathlib import Path
import gc
import time
import requests

# ============================================================================
# SYSTEM PROMPT
# ============================================================================
SYSTEM_PROMPT = """You are a data extraction assistant. Extract item details from the provided text.
Provide output as a JSON array of objects with keys: 'Flag', 'Product Code', 'Description',
'Manufacturer', 'Supplier', 'Material', 'Dimensions', 'Product Image'.
If a key's value is not found, provide empty string "".
If no items found, return empty array [].
Include only unique Product Code values.
For Dimensions, format as "Height: X; Width: Y; Depth: Z" (semicolon-separated).
Do not add duplicate or test data."""

# ============================================================================
# GLOBAL MODELS
# ============================================================================
vision_model = None
vision_processor = None
text_model = None
text_tokenizer = None

# ============================================================================
# GGUF CACHING
# ============================================================================
def get_gguf_local(model_id, filename="unsloth.Q4_K_M.gguf"):
    """Download the GGUF file once and cache it locally."""
    cache_dir = Path("/tmp/gguf_cache")
    cache_dir.mkdir(parents=True, exist_ok=True)
    local_file = cache_dir / filename

    if not local_file.exists():
        url = f"https://huggingface.co/{model_id}/resolve/main/{filename}"
        print(f"📥 Downloading GGUF file from Hugging Face: {url}")
        # Stream to disk in chunks; a multi-GB GGUF must not be held in memory
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(local_file, "wb") as f:
                for chunk in r.iter_content(chunk_size=8 * 1024 * 1024):
                    f.write(chunk)
        print("✅ GGUF file downloaded and cached locally.")
    else:
        print("✅ Using cached GGUF file.")

    return str(local_file)

# ============================================================================
# MODEL LOADERS
# ============================================================================
def load_vision_model():
    """Load the LLaVA-NeXT vision model lazily, 4-bit quantized"""
    global vision_model, vision_processor
    if vision_model is None:
        print("📸 Loading vision model...")
        vision_processor = LlavaNextProcessor.from_pretrained(
            "llava-hf/llava-v1.6-mistral-7b-hf"
        )
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
        vision_model = LlavaNextForConditionalGeneration.from_pretrained(
            "llava-hf/llava-v1.6-mistral-7b-hf",
            torch_dtype=torch.float16,
            quantization_config=quantization_config,
            device_map="auto",
            low_cpu_mem_usage=True,
        )
        print("✅ Vision model loaded!")
    return vision_model, vision_processor
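
# ----------------------------------------------------------------------------
# NOTE: analyze_image_with_vision_model() was elided from this listing ("use
# your previous implementation"). The version below is a minimal sketch so the
# file is self-contained: the signature is assumed, and the
# "[INST] <image> ... [/INST]" prompt format follows the
# llava-v1.6-mistral-7b-hf model card. Swap in your original if it differs.
# ----------------------------------------------------------------------------
def analyze_image_with_vision_model(image, question="Describe this product image briefly."):
    """Run one PIL image + question through LLaVA-NeXT; return the text answer."""
    model, processor = load_vision_model()
    prompt = f"[INST] <image>\n{question} [/INST]"
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=256, do_sample=False)
    # Decode only the newly generated tokens, not the echoed prompt
    answer = processor.decode(
        output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )
    del inputs, output
    cleanup_memory()
    return answer.strip()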

def load_text_model():
    """Load the GGUF text model (cached locally) through Transformers"""
    global text_model, text_tokenizer
    if text_model is None:
        model_id = "pragnesh002/Qwen3-4B-Product-Extractor-GGUF-Q4-K-M"
        gguf_filename = "unsloth.Q4_K_M.gguf"
        local_dir = str(Path(get_gguf_local(model_id, gguf_filename)).parent)

        print("📝 Loading GGUF text extraction model...")
        # Transformers reads GGUF checkpoints via the `gguf_file` argument
        # (this needs the `gguf` package); passing a bare .gguf path as the
        # first argument does not work. Tokenizer and weights both come from
        # the cached file.
        text_tokenizer = AutoTokenizer.from_pretrained(local_dir, gguf_file=gguf_filename)
        text_model = AutoModelForCausalLM.from_pretrained(
            local_dir,
            gguf_file=gguf_filename,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        print("✅ GGUF text model loaded via Transformers!")
    return text_model, text_tokenizer

# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
def cleanup_memory():
    """Free Python and CUDA memory between pages."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def extract_pdf_text(pdf_path):
    """Extract text from each page of the PDF"""
    try:
        doc = fitz.open(pdf_path)
        pages_text = []
        for page_num in range(doc.page_count):
            page = doc.load_page(page_num)
            text = page.get_text().strip()
            if len(text) < 50:
                text = f"[Page {page_num + 1} - Low text content]"
            pages_text.append(text)
        doc.close()
        return pages_text
    except Exception as e:
        return [f"Error extracting text: {str(e)}"]


def extract_products_from_text(page_text, page_num):
    """Extract product data from page text using the GGUF LLM"""
    model, tokenizer = load_text_model()

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Text:\n{page_text[:2000]}\n\nOutput JSON:"}
    ]

    try:
        # Use the chat template if the tokenizer provides one; otherwise fall
        # back to a plain prompt that still carries the system instructions.
        if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
            prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            prompt = SYSTEM_PROMPT + "\n\n" + messages[-1]["content"]

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1024,
                do_sample=False,  # greedy decoding; temperature has no effect without sampling
                pad_token_id=tokenizer.eos_token_id,
            )
        output_text = tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True
        )

        # Prefer a fenced ```json block; otherwise grab the first array/object
        json_match = re.search(r'```json\s*(.*?)\s*```', output_text, re.DOTALL)
        if json_match:
            json_str = json_match.group(1)
        else:
            json_match = re.search(r'(\[.*\]|\{.*\})', output_text, re.DOTALL)
            json_str = json_match.group(1) if json_match else output_text

        parsed = json.loads(json_str)
        if isinstance(parsed, dict):
            parsed = [parsed]
        elif not isinstance(parsed, list):
            parsed = []

        del inputs, outputs
        cleanup_memory()
        return parsed
    except Exception as e:
        print(f"Error extracting from page {page_num}: {e}")
        cleanup_memory()
        return []


# The vision-analysis, image-extraction, and Excel-creation helpers were elided
# here ("use your previous implementations"):
# - analyze_image_with_vision_model()  (sketched above)
# - extract_images_from_page()         (sketched below)
# - create_excel_with_images()         (sketched below)
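
# ----------------------------------------------------------------------------
# NOTE: Minimal, assumed sketches of the two remaining elided helpers. The
# signatures are hypothetical (the originals were not shown); replace them
# with your previous implementations if they differ.
# ----------------------------------------------------------------------------
def extract_images_from_page(doc, page_num, img_dir):
    """Save every image embedded on a page into img_dir; return the file paths."""
    paths = []
    page = doc.load_page(page_num)
    for i, img in enumerate(page.get_images(full=True)):
        xref = img[0]  # first tuple entry is the image's xref
        info = doc.extract_image(xref)
        out_path = Path(img_dir) / f"page{page_num + 1}_img{i}.{info['ext']}"
        out_path.write_bytes(info["image"])
        paths.append(str(out_path))
    return paths


def create_excel_with_images(products, output_path):
    """Write product dicts to an .xlsx file, embedding thumbnails when present."""
    columns = ["Flag", "Product Code", "Description", "Manufacturer",
               "Supplier", "Material", "Dimensions", "Product Image"]
    workbook = xlsxwriter.Workbook(output_path)
    worksheet = workbook.add_worksheet("Products")
    bold = workbook.add_format({"bold": True})
    for col, name in enumerate(columns):
        worksheet.write(0, col, name, bold)
    for row, product in enumerate(products, start=1):
        for col, name in enumerate(columns[:-1]):
            worksheet.write(row, col, str(product.get(name, "")))
        image_file = product.get("Product Image File", "")
        if image_file and os.path.exists(image_file):
            worksheet.set_row(row, 80)  # taller row so the thumbnail fits
            worksheet.insert_image(
                row, len(columns) - 1, image_file,
                {"x_scale": 0.5, "y_scale": 0.5},
            )
    workbook.close()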

# ============================================================================
# MAIN PROCESSING FUNCTION
# ============================================================================
def process_pdf(pdf_file, max_pages, progress=gr.Progress()):
    if pdf_file is None:
        return None, "⚠️ Please upload a PDF file first"

    progress(0, desc="Initializing...")
    try:
        load_text_model()
    except Exception as e:
        return None, f"❌ Error loading text model: {str(e)}"

    # mkdtemp instead of TemporaryDirectory: the directory must outlive this
    # function so Gradio can still serve the Excel file for download.
    temp_dir = Path(tempfile.mkdtemp())
    img_dir = temp_dir / "images"
    img_dir.mkdir()

    try:
        # With type="filepath", Gradio hands us the upload as a path string
        pages_text = extract_pdf_text(pdf_file)
        doc = fitz.open(pdf_file)
        all_products = {}
        product_images = {}

        total_pages = min(len(pages_text), doc.page_count, int(max_pages))
        if total_pages > 20:  # defensive; the UI slider already caps at 20
            doc.close()
            return None, f"⚠️ Requested {total_pages} pages. Limit is 20 pages."

        for page_num in range(total_pages):
            progress(
                0.2 + (0.6 * page_num / total_pages),
                desc=f"Processing page {page_num + 1}/{total_pages}...",
            )
            products = extract_products_from_text(pages_text[page_num], page_num)

            # Image extraction & product/image matching (previous code elided;
            # it should populate product_images keyed by product code — see the
            # extract_images_from_page() sketch above).
            # ...

            # Keep only the first occurrence of each product code
            for product in products:
                if not isinstance(product, dict):
                    continue
                # Coerce to str: the model occasionally emits numeric codes
                code = str(product.get('Product Code', '') or '').strip()
                if code and code not in all_products:
                    product['Product Image File'] = product_images.get(code, '')
                    all_products[code] = product

            cleanup_memory()
            time.sleep(0.5)

        doc.close()
        progress(0.95, desc="Creating Excel...")

        if all_products:
            output_excel = temp_dir / "products_with_images.xlsx"
            create_excel_with_images(list(all_products.values()), str(output_excel))

            total_products = len(all_products)
            products_with_images = sum(
                1 for p in all_products.values() if p.get('Product Image File')
            )
            summary = f"""
## ✅ Extraction Complete!

- **Total products found:** {total_products}
- **Products with images:** {products_with_images}
- **Pages processed:** {total_pages}
- **Image match rate:** {(products_with_images / total_products * 100):.1f}%

### Download your Excel file below! 📥
"""
            progress(1.0, desc="✅ Done!")
            return str(output_excel), summary
        else:
            return None, "⚠️ No products found in PDF."

    except Exception as e:
        import traceback
        return None, f"❌ Error: {str(e)}\n```\n{traceback.format_exc()}\n```"


# ============================================================================
# GRADIO INTERFACE
# ============================================================================
with gr.Blocks(title="Product Data Extractor", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 📦 Product Data Extractor with AI Vision

Automatically extract product info and images from PDFs using LLaVA + Qwen GGUF.
""")

    with gr.Row():
        with gr.Column(scale=1):
            pdf_input = gr.File(label="📄 Upload PDF", file_types=[".pdf"], type="filepath")
            max_pages_slider = gr.Slider(
                minimum=1, maximum=20, value=10, step=1, label="Max Pages"
            )
            extract_btn = gr.Button("🚀 Extract Products", variant="primary", size="lg")
        with gr.Column(scale=1):
            summary_output = gr.Markdown(label="Results")
            excel_output = gr.File(label="📥 Download Excel")

    extract_btn.click(
        fn=process_pdf,
        inputs=[pdf_input, max_pages_slider],
        outputs=[excel_output, summary_output],
    )

# Launch
if __name__ == "__main__":
    demo.queue(max_size=5)
    demo.launch()
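
# ----------------------------------------------------------------------------
# Assumed dependencies (not listed in the original file): a Space running this
# app needs roughly the following in requirements.txt —
#   gradio, torch, transformers, gguf, bitsandbytes, accelerate,
#   pymupdf, pillow, pandas, xlsxwriter, requests
# `gguf` backs Transformers' gguf_file loading, `bitsandbytes` the 4-bit
# vision model, and `accelerate` device_map="auto".
# ----------------------------------------------------------------------------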