import os import sys import tempfile import time import logging import gc from dataclasses import dataclass from typing import Optional, Tuple, List, Any, Dict from contextlib import contextmanager import gradio as gr import torch import psutil from dotenv import load_dotenv load_dotenv() # Audio preprocessing not available in Hugging Face Spaces deployment PREPROCESSING_AVAILABLE = False def get_env_or_secret(key: str, default: Optional[str] = None) -> Optional[str]: """Get environment variable or default.""" return os.environ.get(key, default) @dataclass class InferenceMetrics: """Track inference performance metrics.""" processing_time: float memory_usage: float device_used: str dtype_used: str model_size_mb: Optional[float] = None @dataclass class PreprocessingConfig: """Configuration for audio preprocessing pipeline.""" normalize_format: bool = True normalize_volume: bool = True reduce_noise: bool = False remove_silence: bool = False def load_asr_pipeline( model_id: str, base_model_id: str, device_pref: str = "auto", hf_token: Optional[str] = None, dtype_pref: str = "auto", chunk_length_s: Optional[int] = None, return_timestamps: bool = False, ): logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) logger.info(f"Loading ASR pipeline for model: {model_id}") logger.info( f"Device preference: {device_pref}, Token provided: {hf_token is not None}" ) import torch from transformers import pipeline # Pick optimal device for inference device_str = "cpu" if device_pref == "auto": if torch.cuda.is_available(): device_str = "cuda" logger.info(f"Using CUDA: {torch.cuda.get_device_name()}") elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available(): device_str = "mps" logger.info("Using Apple Silicon MPS for inference") else: device_str = "cpu" logger.info("Using CPU for inference") else: device_str = device_pref # Pick dtype - optimized for inference performance dtype = None if dtype_pref == "auto": # For whisper-medium models, use float32 for stability in medical transcription if "whisper-medium" in model_id: dtype = torch.float32 logger.info( f"Using float32 for {model_id} (medical transcription stability)" ) elif device_str == "cuda": dtype = torch.float16 # Use half precision on GPU for speed logger.info("Using float16 on CUDA for faster inference") else: dtype = torch.float32 else: dtype = {"float32": torch.float32, "float16": torch.float16}.get( dtype_pref, torch.float32 ) logger.info("Pipeline configuration:") logger.info(f" Model: {model_id}") logger.info(f" Base model: {base_model_id}") logger.info(f" Dtype: {dtype}") logger.info(f" Device: {device_str}") logger.info(f" Chunk length: {chunk_length_s}s") logger.info(f" Return timestamps: {return_timestamps}") # Use ultra-simplified approach to avoid all compatibility issues try: logger.info( "Setting up ultra-simplified pipeline to avoid forced_decoder_ids conflicts..." ) # Create pipeline with absolute minimal configuration asr = pipeline( task="automatic-speech-recognition", model=model_id, torch_dtype=dtype, device=0 if device_str == "cuda" else ("mps" if device_str == "mps" else "cpu"), token=hf_token, ) # Post-loading cleanup to remove any forced_decoder_ids if hasattr(asr.model, "generation_config"): if hasattr(asr.model.generation_config, "forced_decoder_ids"): logger.info("Removing forced_decoder_ids from model generation config") asr.model.generation_config.forced_decoder_ids = None # Set basic parameters after loading if chunk_length_s: logger.info(f"Setting chunk_length_s to {chunk_length_s}") logger.info(f"Successfully created ultra-simplified pipeline for: {model_id}") except Exception as e: logger.error(f"Ultra-simplified pipeline creation failed: {e}") logger.info("Falling back to absolute minimal settings...") try: # Fallback with absolute minimal settings fallback_dtype = torch.float32 asr = pipeline( task="automatic-speech-recognition", model=model_id, torch_dtype=fallback_dtype, device="cpu", # Force CPU for maximum compatibility token=hf_token, ) # Post-loading cleanup if hasattr(asr.model, "generation_config"): if hasattr(asr.model.generation_config, "forced_decoder_ids"): logger.info("Removing forced_decoder_ids from fallback model") asr.model.generation_config.forced_decoder_ids = None device_str = "cpu" dtype = fallback_dtype logger.info( f"Minimal fallback pipeline created with dtype: {fallback_dtype}" ) except Exception as fallback_error: logger.error(f"Minimal fallback failed: {fallback_error}") raise return asr, device_str, str(dtype).replace("torch.", "") @contextmanager def memory_monitor(): """Context manager to monitor memory usage during inference.""" process = psutil.Process() start_memory = process.memory_info().rss / 1024 / 1024 # MB yield end_memory = process.memory_info().rss / 1024 / 1024 # MB return end_memory - start_memory def transcribe_local( audio_path: str, model_id: str, base_model_id: str, language: Optional[str], task: str, device_pref: str, dtype_pref: str, hf_token: Optional[str], chunk_length_s: Optional[int], stride_length_s: Optional[int], return_timestamps: bool, ) -> Dict[str, Any]: logger = logging.getLogger(__name__) logger.info(f"Starting transcription: {os.path.basename(audio_path)}") logger.info(f"Model: {model_id}") # Validate audio_path if audio_path is None: raise ValueError("Audio path is None") if not isinstance(audio_path, (str, bytes, os.PathLike)): raise TypeError( f"Audio path must be str, bytes or os.PathLike, got {type(audio_path)}" ) if not os.path.exists(audio_path): raise FileNotFoundError(f"Audio file not found: {audio_path}") # Load ASR pipeline with performance monitoring start_time = time.time() asr, device_str, dtype_str = load_asr_pipeline( model_id=model_id, base_model_id=base_model_id, device_pref=device_pref, hf_token=hf_token, dtype_pref=dtype_pref, chunk_length_s=chunk_length_s, return_timestamps=return_timestamps, ) load_time = time.time() - start_time logger.info(f"Model loaded in {load_time:.2f}s") # Simplified configuration to avoid compatibility issues # Let the pipeline handle generation parameters internally logger.info("Using simplified configuration to avoid model compatibility issues") # Setup inference parameters with performance monitoring try: # Start with minimal parameters to avoid conflicts asr_kwargs = {} # Only add parameters that are safe and supported if return_timestamps: asr_kwargs["return_timestamps"] = return_timestamps logger.info("Timestamps enabled") # Apply chunking strategy only if supported if chunk_length_s: try: asr_kwargs["chunk_length_s"] = chunk_length_s logger.info(f"Using chunking strategy: {chunk_length_s}s") except Exception as chunk_error: logger.warning(f"Chunking not supported: {chunk_error}") if stride_length_s is not None: try: asr_kwargs["stride_length_s"] = stride_length_s logger.info(f"Using stride: {stride_length_s}s") except Exception as stride_error: logger.warning(f"Stride not supported: {stride_error}") logger.info(f"Inference parameters configured: {list(asr_kwargs.keys())}") # Run inference with performance monitoring inference_start = time.time() memory_before = psutil.Process().memory_info().rss / 1024 / 1024 # MB try: # Primary inference attempt with safe parameters if asr_kwargs: result = asr(audio_path, **asr_kwargs) else: # Fallback to no parameters if all failed result = asr(audio_path) inference_time = time.time() - inference_start memory_after = psutil.Process().memory_info().rss / 1024 / 1024 # MB memory_used = memory_after - memory_before logger.info(f"Inference completed successfully in {inference_time:.2f}s") logger.info(f"Memory used: {memory_used:.1f}MB") except Exception as e: error_msg = str(e) logger.warning(f"Inference failed with parameters: {error_msg}") # Try with absolutely minimal parameters if "forced_decoder_ids" in error_msg: logger.info( "Detected forced_decoder_ids error, trying with no parameters..." ) elif ( "probability tensor contains either inf, nan or element < 0" in error_msg ): logger.info( "Detected numerical instability, trying with no parameters..." ) else: logger.info("Unknown error, trying with no parameters...") try: inference_start = time.time() result = asr(audio_path) # No parameters at all inference_time = time.time() - inference_start memory_used = 0 # Reset memory tracking logger.info(f"Minimal inference completed in {inference_time:.2f}s") except Exception as final_error: logger.error(f"All inference attempts failed: {final_error}") raise except Exception as e: logger.error(f"Inference failed: {e}") raise # Cleanup GPU memory after inference if device_str == "cuda": torch.cuda.empty_cache() gc.collect() # Return results with performance metrics meta = { "device": device_str, "dtype": dtype_str, "inference_time": inference_time, "memory_used_mb": memory_used, "model_type": "original" if model_id == base_model_id else "fine-tuned", } return {"result": result, "meta": meta} def handle_whisper_problematic_output(text: str, model_name: str = "Whisper") -> dict: """Gestisce gli output problematici di Whisper come '!', '.', stringhe vuote, ecc.""" if not text: return { "text": "[WHISPER ISSUE: Output vuoto - Audio troppo corto o silenzioso]", "is_problematic": True, "original": text, "issue_type": "empty", } text_stripped = text.strip() # Casi problematici comuni problematic_outputs = { "!": "Audio troppo corto/silenzioso", ".": "Audio di bassa qualità", "?": "Audio incomprensibile", "...": "Audio troppo lungo senza parlato", "--": "Audio distorto", "—": "Audio con troppo rumore", " per!": "Audio parzialmente comprensibile", "per!": "Audio parzialmente comprensibile", } if text_stripped in problematic_outputs: return { "text": f"[WHISPER ISSUE: '{text_stripped}' - {problematic_outputs[text_stripped]}]", "is_problematic": True, "original": text, "issue_type": text_stripped, "suggestion": problematic_outputs[text_stripped], } # Testo troppo corto (meno di 3 caratteri e non alfabetico) if len(text_stripped) <= 2 and not text_stripped.isalpha(): return { "text": f"[WHISPER ISSUE: '{text_stripped}' - Output troppo corto/simbolico]", "is_problematic": True, "original": text, "issue_type": "short_symbolic", } return {"text": text, "is_problematic": False, "original": text} def transcribe_comparison(audio_file): """Main function for Gradio interface.""" if audio_file is None: return "❌ Nessun file audio fornito", "❌ Nessun file audio fornito" # Model configuration model_id = get_env_or_secret("HF_MODEL_ID") base_model_id = get_env_or_secret("BASE_WHISPER_MODEL_ID") hf_token = get_env_or_secret("HF_TOKEN") or get_env_or_secret( "HUGGINGFACEHUB_API_TOKEN" ) if not model_id or not base_model_id: error_msg = "❌ Modelli non configurati. Impostare HF_MODEL_ID e BASE_WHISPER_MODEL_ID nelle variabili d'ambiente" return error_msg, error_msg # Preprocessing sempre attivo (nascosto all'utente) # Non viene più utilizzato nel codice ma potrebbe servire per future implementazioni # Fixed settings optimized for medical transcription language = "it" # Always Italian for ScribeAId task = "transcribe" return_ts = True # Timestamps for medical report segments device_pref = "auto" # Auto-detect best device dtype_pref = "auto" # Auto-select optimal precision chunk_len = 7 # 7-second chunks for better context stride_len = 1 # Minimal stride for accuracy try: # Use the audio file path directly from Gradio tmp_path = audio_file original_result = None finetuned_result = None original_text = "" finetuned_text = "" try: # Transcribe with original model original_result = transcribe_local( audio_path=tmp_path, model_id=base_model_id, base_model_id=base_model_id, language=language, task=task, device_pref=device_pref, dtype_pref=dtype_pref, hf_token=None, # Base model doesn't need token chunk_length_s=int(chunk_len) if chunk_len else None, stride_length_s=int(stride_len) if stride_len else None, return_timestamps=return_ts, ) # Extract text from result if isinstance(original_result["result"], dict): original_text = original_result["result"].get( "text" ) or original_result["result"].get("transcription") elif isinstance(original_result["result"], str): original_text = original_result["result"] if original_text: result = handle_whisper_problematic_output( original_text, "Original Whisper" ) if result["is_problematic"]: original_text = f"⚠️ {result['text']}\n\n💡 Suggerimenti:\n• Registra almeno 5-10 secondi di audio\n• Parla chiaramente e ad alto volume\n• Avvicinati al microfono\n• Evita rumori di fondo" else: original_text = result["text"] else: original_text = "❌ Nessun testo restituito dal modello originale" except Exception as e: original_text = f"❌ Errore modello originale: {str(e)}" try: # Transcribe with fine-tuned model finetuned_result = transcribe_local( audio_path=tmp_path, model_id=model_id, base_model_id=base_model_id, language=language, task=task, device_pref=device_pref, dtype_pref=dtype_pref, hf_token=hf_token or None, chunk_length_s=int(chunk_len) if chunk_len else None, stride_length_s=int(stride_len) if stride_len else None, return_timestamps=return_ts, ) # Extract text from result if isinstance(finetuned_result["result"], dict): finetuned_text = finetuned_result["result"].get( "text" ) or finetuned_result["result"].get("transcription") elif isinstance(finetuned_result["result"], str): finetuned_text = finetuned_result["result"] if finetuned_text: result = handle_whisper_problematic_output( finetuned_text, "Fine-tuned Model" ) if result["is_problematic"]: finetuned_text = f"⚠️ {result['text']}\n\n💡 Suggerimenti:\n• Registra almeno 5-10 secondi di audio\n• Parla chiaramente e ad alto volume\n• Avvicinati al microfono\n• Evita rumori di fondo" else: finetuned_text = result["text"] else: finetuned_text = "❌ Nessun testo restituito dal modello fine-tuned" except Exception as e: finetuned_text = f"❌ Errore modello fine-tuned: {str(e)}" # GPU memory cleanup if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() return original_text, finetuned_text except Exception as e: error_msg = f"❌ Errore generale: {str(e)}" return error_msg, error_msg # Gradio interface def create_interface(): """Create and configure the Gradio interface.""" model_id = get_env_or_secret("HF_MODEL_ID", "ReportAId/whisper-medium-it-finetuned") base_model_id = get_env_or_secret("BASE_WHISPER_MODEL_ID", "openai/whisper-medium") # Carica il logo SVG inline per garantirne la visualizzazione anche senza routing file logo_html = None try: logo_path = os.path.join(os.path.dirname(__file__), "assets", "ScribeAId.svg") with open(logo_path, "r", encoding="utf-8") as f: svg_content = f.read() # Wrappa lo svg in un contenitore centrato logo_html = f"""