import os
import sys
import tempfile
import time
import logging
import gc
from dataclasses import dataclass
from typing import Optional, Tuple, List, Any, Dict
from contextlib import contextmanager

import gradio as gr
import torch
import psutil
from dotenv import load_dotenv

load_dotenv()

# Audio preprocessing not available in Hugging Face Spaces deployment
PREPROCESSING_AVAILABLE = False


def get_env_or_secret(key: str, default: Optional[str] = None) -> Optional[str]:
    """Get environment variable or default."""
    return os.environ.get(key, default)


@dataclass
class InferenceMetrics:
    """Track inference performance metrics."""

    processing_time: float
    memory_usage: float
    device_used: str
    dtype_used: str
    model_size_mb: Optional[float] = None


@dataclass
class PreprocessingConfig:
    """Configuration for audio preprocessing pipeline."""

    normalize_format: bool = True
    normalize_volume: bool = True
    reduce_noise: bool = False
    remove_silence: bool = False


def load_asr_pipeline(
    model_id: str,
    base_model_id: str,
    device_pref: str = "auto",
    hf_token: Optional[str] = None,
    dtype_pref: str = "auto",
    chunk_length_s: Optional[int] = None,
    return_timestamps: bool = False,
):
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    logger.info(f"Loading ASR pipeline for model: {model_id}")
    logger.info(
        f"Device preference: {device_pref}, Token provided: {hf_token is not None}"
    )

    import torch
    from transformers import pipeline

    # Pick optimal device for inference
    device_str = "cpu"
    if device_pref == "auto":
        if torch.cuda.is_available():
            device_str = "cuda"
            logger.info(f"Using CUDA: {torch.cuda.get_device_name()}")
        elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
            device_str = "mps"
            logger.info("Using Apple Silicon MPS for inference")
        else:
            device_str = "cpu"
            logger.info("Using CPU for inference")
    else:
        device_str = device_pref

    # Pick dtype - optimized for inference performance
    dtype = None
    if dtype_pref == "auto":
        # For whisper-medium models, use float32 for stability in medical transcription
        if "whisper-medium" in model_id:
            dtype = torch.float32
            logger.info(
                f"Using float32 for {model_id} (medical transcription stability)"
            )
        elif device_str == "cuda":
            dtype = torch.float16  # Use half precision on GPU for speed
            logger.info("Using float16 on CUDA for faster inference")
        else:
            dtype = torch.float32
    else:
        dtype = {"float32": torch.float32, "float16": torch.float16}.get(
            dtype_pref, torch.float32
        )

    logger.info("Pipeline configuration:")
    logger.info(f"  Model: {model_id}")
    logger.info(f"  Base model: {base_model_id}")
    logger.info(f"  Dtype: {dtype}")
    logger.info(f"  Device: {device_str}")
    logger.info(f"  Chunk length: {chunk_length_s}s")
    logger.info(f"  Return timestamps: {return_timestamps}")

    # Use ultra-simplified approach to avoid all compatibility issues
    try:
        logger.info(
            "Setting up ultra-simplified pipeline to avoid forced_decoder_ids conflicts..."
        )
        # Create pipeline with absolute minimal configuration
        asr = pipeline(
            task="automatic-speech-recognition",
            model=model_id,
            torch_dtype=dtype,
            device=0
            if device_str == "cuda"
            else ("mps" if device_str == "mps" else "cpu"),
            token=hf_token,
        )
        # Post-loading cleanup to remove any forced_decoder_ids
        if hasattr(asr.model, "generation_config"):
            if hasattr(asr.model.generation_config, "forced_decoder_ids"):
                logger.info("Removing forced_decoder_ids from model generation config")
                asr.model.generation_config.forced_decoder_ids = None
        # Chunking is not configured on the pipeline here; chunk_length_s is
        # passed per call by transcribe_local, so only record the preference.
        if chunk_length_s:
            logger.info(f"Requested chunk_length_s: {chunk_length_s}")
        logger.info(f"Successfully created ultra-simplified pipeline for: {model_id}")
    except Exception as e:
        logger.error(f"Ultra-simplified pipeline creation failed: {e}")
        logger.info("Falling back to absolute minimal settings...")
        try:
            # Fallback with absolute minimal settings
            fallback_dtype = torch.float32
            asr = pipeline(
                task="automatic-speech-recognition",
                model=model_id,
                torch_dtype=fallback_dtype,
                device="cpu",  # Force CPU for maximum compatibility
                token=hf_token,
            )
            # Post-loading cleanup
            if hasattr(asr.model, "generation_config"):
                if hasattr(asr.model.generation_config, "forced_decoder_ids"):
                    logger.info("Removing forced_decoder_ids from fallback model")
                    asr.model.generation_config.forced_decoder_ids = None
            device_str = "cpu"
            dtype = fallback_dtype
            logger.info(
                f"Minimal fallback pipeline created with dtype: {fallback_dtype}"
            )
        except Exception as fallback_error:
            logger.error(f"Minimal fallback failed: {fallback_error}")
            raise

    return asr, device_str, str(dtype).replace("torch.", "")
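

# Illustrative usage (not part of the app's runtime flow): load_asr_pipeline is
# normally invoked by transcribe_local below. A standalone call could look like
# the sketch here; "sample.wav" is a placeholder path, not a file in this repo.
#
#   asr, device, dtype = load_asr_pipeline(
#       model_id="openai/whisper-medium",
#       base_model_id="openai/whisper-medium",
#   )
#   print(device, dtype, asr("sample.wav")["text"])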


@contextmanager
def memory_monitor():
    """Context manager to monitor memory usage during inference."""
    process = psutil.Process()
    start_memory = process.memory_info().rss / 1024 / 1024  # MB
    try:
        yield
    finally:
        end_memory = process.memory_info().rss / 1024 / 1024  # MB
        logging.getLogger(__name__).info(
            f"Memory delta: {end_memory - start_memory:.1f}MB"
        )
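

# Usage sketch (illustrative): memory_monitor is not wired into transcribe_local,
# which tracks memory directly, but it can wrap any block whose memory delta
# should be logged, e.g.:
#
#   with memory_monitor():
#       result = asr("sample.wav")  # "sample.wav" is a placeholder path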


def transcribe_local(
    audio_path: str,
    model_id: str,
    base_model_id: str,
    language: Optional[str],
    task: str,
    device_pref: str,
    dtype_pref: str,
    hf_token: Optional[str],
    chunk_length_s: Optional[int],
    stride_length_s: Optional[int],
    return_timestamps: bool,
) -> Dict[str, Any]:
    logger = logging.getLogger(__name__)
    logger.info(f"Starting transcription: {os.path.basename(audio_path)}")
    logger.info(f"Model: {model_id}")

    # Validate audio_path
    if audio_path is None:
        raise ValueError("Audio path is None")
    if not isinstance(audio_path, (str, bytes, os.PathLike)):
        raise TypeError(
            f"Audio path must be str, bytes or os.PathLike, got {type(audio_path)}"
        )
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    # Load ASR pipeline with performance monitoring
    start_time = time.time()
    asr, device_str, dtype_str = load_asr_pipeline(
        model_id=model_id,
        base_model_id=base_model_id,
        device_pref=device_pref,
        hf_token=hf_token,
        dtype_pref=dtype_pref,
        chunk_length_s=chunk_length_s,
        return_timestamps=return_timestamps,
    )
    load_time = time.time() - start_time
    logger.info(f"Model loaded in {load_time:.2f}s")

    # Simplified configuration to avoid compatibility issues:
    # let the pipeline handle generation parameters internally.
    logger.info("Using simplified configuration to avoid model compatibility issues")

    # Setup inference parameters with performance monitoring
    try:
        # Start with minimal parameters to avoid conflicts
        asr_kwargs = {}
        # Only add parameters that are safe and supported
        if return_timestamps:
            asr_kwargs["return_timestamps"] = return_timestamps
            logger.info("Timestamps enabled")
        # Apply chunking strategy only if supported
        if chunk_length_s:
            try:
                asr_kwargs["chunk_length_s"] = chunk_length_s
                logger.info(f"Using chunking strategy: {chunk_length_s}s")
            except Exception as chunk_error:
                logger.warning(f"Chunking not supported: {chunk_error}")
        if stride_length_s is not None:
            try:
                asr_kwargs["stride_length_s"] = stride_length_s
                logger.info(f"Using stride: {stride_length_s}s")
            except Exception as stride_error:
                logger.warning(f"Stride not supported: {stride_error}")
        logger.info(f"Inference parameters configured: {list(asr_kwargs.keys())}")

        # Run inference with performance monitoring
        inference_start = time.time()
        memory_before = psutil.Process().memory_info().rss / 1024 / 1024  # MB
        try:
            # Primary inference attempt with safe parameters
            if asr_kwargs:
                result = asr(audio_path, **asr_kwargs)
            else:
                # Fallback to no parameters if all failed
                result = asr(audio_path)
            inference_time = time.time() - inference_start
            memory_after = psutil.Process().memory_info().rss / 1024 / 1024  # MB
            memory_used = memory_after - memory_before
            logger.info(f"Inference completed successfully in {inference_time:.2f}s")
            logger.info(f"Memory used: {memory_used:.1f}MB")
        except Exception as e:
            error_msg = str(e)
            logger.warning(f"Inference failed with parameters: {error_msg}")
            # Try with absolutely minimal parameters
            if "forced_decoder_ids" in error_msg:
                logger.info(
                    "Detected forced_decoder_ids error, trying with no parameters..."
                )
            elif (
                "probability tensor contains either inf, nan or element < 0"
                in error_msg
            ):
                logger.info(
                    "Detected numerical instability, trying with no parameters..."
                )
            else:
                logger.info("Unknown error, trying with no parameters...")
            try:
                inference_start = time.time()
                result = asr(audio_path)  # No parameters at all
                inference_time = time.time() - inference_start
                memory_used = 0  # Reset memory tracking
                logger.info(f"Minimal inference completed in {inference_time:.2f}s")
            except Exception as final_error:
                logger.error(f"All inference attempts failed: {final_error}")
                raise
    except Exception as e:
        logger.error(f"Inference failed: {e}")
        raise

    # Cleanup GPU memory after inference
    if device_str == "cuda":
        torch.cuda.empty_cache()
        gc.collect()

    # Return results with performance metrics
    meta = {
        "device": device_str,
        "dtype": dtype_str,
        "inference_time": inference_time,
        "memory_used_mb": memory_used,
        "model_type": "original" if model_id == base_model_id else "fine-tuned",
    }
    return {"result": result, "meta": meta}
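

# For reference, transcribe_local returns a dict shaped roughly like
#   {"result": <pipeline output: a string or a dict with "text"/"chunks">,
#    "meta": {"device", "dtype", "inference_time", "memory_used_mb", "model_type"}}
# which transcribe_comparison below unpacks to fill the two output columns.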


def handle_whisper_problematic_output(text: str, model_name: str = "Whisper") -> dict:
    """Handle problematic Whisper outputs such as '!', '.', empty strings, etc."""
    if not text:
        return {
            "text": "[WHISPER ISSUE: Output vuoto - Audio troppo corto o silenzioso]",
            "is_problematic": True,
            "original": text,
            "issue_type": "empty",
        }

    text_stripped = text.strip()

    # Common problematic outputs
    problematic_outputs = {
        "!": "Audio troppo corto/silenzioso",
        ".": "Audio di bassa qualità",
        "?": "Audio incomprensibile",
        "...": "Audio troppo lungo senza parlato",
        "--": "Audio distorto",
        "—": "Audio con troppo rumore",
        " per!": "Audio parzialmente comprensibile",
        "per!": "Audio parzialmente comprensibile",
    }

    if text_stripped in problematic_outputs:
        return {
            "text": f"[WHISPER ISSUE: '{text_stripped}' - {problematic_outputs[text_stripped]}]",
            "is_problematic": True,
            "original": text,
            "issue_type": text_stripped,
            "suggestion": problematic_outputs[text_stripped],
        }

    # Text too short (two characters or fewer and not alphabetic)
    if len(text_stripped) <= 2 and not text_stripped.isalpha():
        return {
            "text": f"[WHISPER ISSUE: '{text_stripped}' - Output troppo corto/simbolico]",
            "is_problematic": True,
            "original": text,
            "issue_type": "short_symbolic",
        }

    return {"text": text, "is_problematic": False, "original": text}
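

# Example (from the mapping above): a bare "!" is flagged instead of shown raw.
#   handle_whisper_problematic_output("!")
#   # -> {"text": "[WHISPER ISSUE: '!' - Audio troppo corto/silenzioso]",
#   #     "is_problematic": True, "original": "!", "issue_type": "!",
#   #     "suggestion": "Audio troppo corto/silenzioso"}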


def transcribe_comparison(audio_file):
    """Main function for Gradio interface."""
    if audio_file is None:
        return "❌ Nessun file audio fornito", "❌ Nessun file audio fornito"

    # Model configuration
    model_id = get_env_or_secret("HF_MODEL_ID")
    base_model_id = get_env_or_secret("BASE_WHISPER_MODEL_ID")
    hf_token = get_env_or_secret("HF_TOKEN") or get_env_or_secret(
        "HUGGINGFACEHUB_API_TOKEN"
    )
    if not model_id or not base_model_id:
        error_msg = "❌ Modelli non configurati. Impostare HF_MODEL_ID e BASE_WHISPER_MODEL_ID nelle variabili d'ambiente"
        return error_msg, error_msg

    # Preprocessing is always enabled (hidden from the user).
    # It is no longer used in this code but may be useful for future implementations.

    # Fixed settings optimized for medical transcription
    language = "it"  # Always Italian for ScribeAId
    task = "transcribe"
    return_ts = True  # Timestamps for medical report segments
    device_pref = "auto"  # Auto-detect best device
    dtype_pref = "auto"  # Auto-select optimal precision
    chunk_len = 7  # 7-second chunks for better context
    stride_len = 1  # Minimal stride for accuracy

    try:
        # Use the audio file path directly from Gradio
        tmp_path = audio_file
        original_result = None
        finetuned_result = None
        original_text = ""
        finetuned_text = ""

        try:
            # Transcribe with original model
            original_result = transcribe_local(
                audio_path=tmp_path,
                model_id=base_model_id,
                base_model_id=base_model_id,
                language=language,
                task=task,
                device_pref=device_pref,
                dtype_pref=dtype_pref,
                hf_token=None,  # Base model doesn't need token
                chunk_length_s=int(chunk_len) if chunk_len else None,
                stride_length_s=int(stride_len) if stride_len else None,
                return_timestamps=return_ts,
            )
            # Extract text from result
            if isinstance(original_result["result"], dict):
                original_text = original_result["result"].get(
                    "text"
                ) or original_result["result"].get("transcription")
            elif isinstance(original_result["result"], str):
                original_text = original_result["result"]

            if original_text:
                result = handle_whisper_problematic_output(
                    original_text, "Original Whisper"
                )
                if result["is_problematic"]:
                    original_text = f"⚠️ {result['text']}\n\n💡 Suggerimenti:\n• Registra almeno 5-10 secondi di audio\n• Parla chiaramente e ad alto volume\n• Avvicinati al microfono\n• Evita rumori di fondo"
                else:
                    original_text = result["text"]
            else:
                original_text = "❌ Nessun testo restituito dal modello originale"
        except Exception as e:
            original_text = f"❌ Errore modello originale: {str(e)}"

        try:
            # Transcribe with fine-tuned model
            finetuned_result = transcribe_local(
                audio_path=tmp_path,
                model_id=model_id,
                base_model_id=base_model_id,
                language=language,
                task=task,
                device_pref=device_pref,
                dtype_pref=dtype_pref,
                hf_token=hf_token or None,
                chunk_length_s=int(chunk_len) if chunk_len else None,
                stride_length_s=int(stride_len) if stride_len else None,
                return_timestamps=return_ts,
            )
            # Extract text from result
            if isinstance(finetuned_result["result"], dict):
                finetuned_text = finetuned_result["result"].get(
                    "text"
                ) or finetuned_result["result"].get("transcription")
            elif isinstance(finetuned_result["result"], str):
                finetuned_text = finetuned_result["result"]

            if finetuned_text:
                result = handle_whisper_problematic_output(
                    finetuned_text, "Fine-tuned Model"
                )
                if result["is_problematic"]:
                    finetuned_text = f"⚠️ {result['text']}\n\n💡 Suggerimenti:\n• Registra almeno 5-10 secondi di audio\n• Parla chiaramente e ad alto volume\n• Avvicinati al microfono\n• Evita rumori di fondo"
                else:
                    finetuned_text = result["text"]
            else:
                finetuned_text = "❌ Nessun testo restituito dal modello fine-tuned"
        except Exception as e:
            finetuned_text = f"❌ Errore modello fine-tuned: {str(e)}"

        # GPU memory cleanup
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()

        return original_text, finetuned_text
    except Exception as e:
        error_msg = f"❌ Errore generale: {str(e)}"
        return error_msg, error_msg
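

# Note: transcribe_comparison always returns an (original_text, finetuned_text)
# pair of strings that maps onto the two output Textboxes wired up in
# create_interface below; errors are converted into user-facing messages rather
# than raised, so the Gradio UI never sees an exception from this function.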


# Gradio interface
def create_interface():
    """Create and configure the Gradio interface."""
    model_id = get_env_or_secret("HF_MODEL_ID", "ReportAId/whisper-medium-it-finetuned")
    base_model_id = get_env_or_secret("BASE_WHISPER_MODEL_ID", "openai/whisper-medium")

    # Load the SVG logo inline so it is displayed even without file routing
    logo_html = None
    try:
        logo_path = os.path.join(os.path.dirname(__file__), "assets", "ScribeAId.svg")
        with open(logo_path, "r", encoding="utf-8") as f:
            svg_content = f.read()
        # Wrap the SVG in a centered container
        logo_html = f"""
        <div style="text-align: center; margin: 16px 0 8px;">
            <div style="display:inline-block; height:60px;">{svg_content}</div>
        </div>
        """
    except Exception:
        # Fall back to the file= path if for some reason the file cannot be read
        logo_html = """
        <div style="text-align: center; margin: 16px 0 8px;">
            <img src="file=assets/ScribeAId.svg" alt="ScribeAId" style="height: 60px; margin-bottom: 8px;">
        </div>
        """

    with gr.Blocks(
        title="ScribeAId - Medical Transcription",
        theme=gr.themes.Default(primary_hue="blue"),
        css=".gradio-container{max-width: 900px !important; margin: 0 auto !important;} .center-col{display:flex;flex-direction:column;align-items:center;} .center-col .wrap{width:100%;}",
    ) as demo:
        # Header with the ScribeAId logo (simple, black/white)
        gr.HTML(logo_html)
        gr.Markdown("""
        Quest’applicazione confronta un Whisper V3 di base con il modello open-source fine-tuned pubblicato da ReportAId su dati ambulatoriali italiani. È progettato per mitigare errori noti e migliorare le performance. Carica un audio o registra la voce: noterai trascrizioni più accurate di termini clinici come “Holter delle 24 ore”, “fibrillazione atriale” o “pressione sistolica”.
        """)

        with gr.Row():
            with gr.Column():
                gr.Markdown(f"""
                **⚙️ Impostazioni**
                - Modello originale: `{base_model_id}`
                - Modello fine-tuned: `{model_id}`
                - Lingua: Italiano (it)
                - Preprocessing audio: ottimizzato per registrazioni mediche
                """)

        gr.Markdown("---")
        # Input section title
        gr.Markdown("## Input")
        # Audio input and button, left-aligned
        audio_input = gr.Audio(
            label="📥 Registra dal microfono o carica un file",
            type="filepath",
            sources=["microphone", "upload"],
            format="wav",
            streaming=False,
            interactive=True,
        )
        transcribe_btn = gr.Button("🚀 Trascrivi e Confronta", variant="primary")

        gr.Markdown("---")
        gr.Markdown("## Output")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Modello base (Whisper V3)")
                original_output = gr.Textbox(
                    label="Transcription",
                    lines=12,
                    interactive=False,
                    show_copy_button=True,
                )
            with gr.Column():
                gr.Markdown("### Modello fine-tuned ReportAId")
                finetuned_output = gr.Textbox(
                    label="Transcription",
                    lines=12,
                    interactive=False,
                    show_copy_button=True,
                )

        # Click event
        transcribe_btn.click(
            fn=transcribe_comparison,
            inputs=[audio_input],
            outputs=[original_output, finetuned_output],
            show_progress=True,
        )

    return demo


if __name__ == "__main__":
    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    demo = create_interface()

    # Launch configuration for Hugging Face Spaces
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        inbrowser=False,
        quiet=False,
    )