# Uploaded by calettippo — commit eebc859, "Add gradio app" (22.2 kB).
import os
import sys
import tempfile
import time
import logging
import gc
from dataclasses import dataclass
from typing import Optional, Tuple, List, Any, Dict
from contextlib import contextmanager
import gradio as gr
import torch
import psutil
from dotenv import load_dotenv
# Populate os.environ from a local .env file, if one exists, so that later
# os.environ lookups (e.g. via get_env_or_secret) can see those values.
load_dotenv()
# Audio preprocessing is not available in the Hugging Face Spaces deployment.
# NOTE(review): no visible code reads this flag — confirm it is still needed.
PREPROCESSING_AVAILABLE = False
def get_env_or_secret(key: str, default: Optional[str] = None) -> Optional[str]:
    """Look up ``key`` in the process environment.

    Returns the variable's value when it is set (even if it is an empty
    string), otherwise ``default``.
    """
    value = os.environ.get(key)
    if value is None:
        return default
    return value
@dataclass
class InferenceMetrics:
    """Plain record of performance figures for one inference run.

    No code visible in this module constructs it; the units noted below are
    assumptions to confirm against whichever code fills it in.
    """

    # Elapsed processing time (presumably seconds — TODO confirm).
    processing_time: float
    # Memory attributed to the run (presumably MB — TODO confirm).
    memory_usage: float
    # Label of the device the run used (assumed e.g. "cpu", "cuda", "mps").
    device_used: str
    # Label of the dtype the run used (assumed e.g. "float32", "float16").
    dtype_used: str
    # Size of the loaded model in MB; None when not recorded.
    model_size_mb: Optional[float] = None
@dataclass
class PreprocessingConfig:
    """Feature switches for the audio preprocessing pipeline.

    By default format and volume normalisation are on, while noise reduction
    and silence removal are off. The semantics of each switch are defined by
    whatever consumes this config, which is not visible here.
    """

    normalize_format: bool = True
    normalize_volume: bool = True
    reduce_noise: bool = False
    remove_silence: bool = False
def _clear_forced_decoder_ids(asr: Any, logger: logging.Logger, message: str) -> None:
    """Set ``forced_decoder_ids`` to None on the pipeline model's generation config, if present."""
    generation_config = getattr(asr.model, "generation_config", None)
    if generation_config is not None and hasattr(generation_config, "forced_decoder_ids"):
        logger.info(message)
        generation_config.forced_decoder_ids = None


def _pipeline_device_arg(device_str: Any) -> Any:
    """Translate a device label into the ``device`` argument for ``transformers.pipeline``."""
    if device_str == "cuda":
        return 0  # unchanged: a bare "cuda" selects GPU index 0
    if isinstance(device_str, str) and device_str:
        # "cpu", "mps", "cuda:1", ... are passed through. Previously anything
        # other than "cuda"/"mps" was silently replaced by "cpu" even though the
        # caller was told the requested device was in use.
        return device_str
    return "cpu"


def load_asr_pipeline(
    model_id: str,
    base_model_id: str,
    device_pref: str = "auto",
    hf_token: Optional[str] = None,
    dtype_pref: str = "auto",
    chunk_length_s: Optional[int] = None,
    return_timestamps: bool = False,
):
    """Build a Hugging Face automatic-speech-recognition pipeline for ``model_id``.

    Args:
        model_id: Hub id (or local path) of the model to load.
        base_model_id: Only logged here; not used to build the pipeline.
        device_pref: "auto" picks CUDA, then Apple MPS, then CPU. Any other
            value is used as the device label (e.g. "cpu", "cuda", "cuda:1", "mps").
        hf_token: Optional Hub token forwarded as ``pipeline(token=...)``.
        dtype_pref: "auto", "float32" or "float16". Unknown values fall back to
            float32. In "auto" mode, models whose id contains "whisper-medium"
            always use float32.
        chunk_length_s: Only logged. Chunking is not configured on the pipeline
            object and must be requested when the pipeline is called.
        return_timestamps: Only logged, for the same reason.

    Returns:
        ``(asr, device_str, dtype_str)``: the pipeline, the device label in use
        and the dtype name without the "torch." prefix (e.g. "float32"). If the
        first load fails, a CPU/float32 load is attempted and the returned
        device/dtype describe that fallback.

    Raises:
        Whatever ``transformers.pipeline`` raises when the CPU fallback also fails.
    """
    # No-op when the root logger is already configured (e.g. by __main__).
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    logger.info(f"Loading ASR pipeline for model: {model_id}")
    logger.info(
        f"Device preference: {device_pref}, Token provided: {hf_token is not None}"
    )
    # Local import: transformers is only loaded when a pipeline is actually built.
    from transformers import pipeline

    # Pick the device used for inference.
    if device_pref == "auto":
        if torch.cuda.is_available():
            device_str = "cuda"
            logger.info(f"Using CUDA: {torch.cuda.get_device_name()}")
        elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
            device_str = "mps"
            logger.info("Using Apple Silicon MPS for inference")
        else:
            device_str = "cpu"
            logger.info("Using CPU for inference")
    else:
        device_str = device_pref

    on_cuda = isinstance(device_str, str) and device_str.startswith("cuda")

    # Pick the dtype.
    if dtype_pref == "auto":
        if "whisper-medium" in model_id:
            # whisper-medium models are kept in float32 for stability.
            dtype = torch.float32
            logger.info(
                f"Using float32 for {model_id} (medical transcription stability)"
            )
        elif on_cuda:
            dtype = torch.float16  # half precision on GPU for speed
            logger.info("Using float16 on CUDA for faster inference")
        else:
            dtype = torch.float32
    else:
        dtype = {"float32": torch.float32, "float16": torch.float16}.get(
            dtype_pref, torch.float32
        )

    logger.info("Pipeline configuration:")
    logger.info(f" Model: {model_id}")
    logger.info(f" Base model: {base_model_id}")
    logger.info(f" Dtype: {dtype}")
    logger.info(f" Device: {device_str}")
    logger.info(f" Chunk length: {chunk_length_s}s")
    logger.info(f" Return timestamps: {return_timestamps}")

    try:
        logger.info(
            "Setting up ultra-simplified pipeline to avoid forced_decoder_ids conflicts..."
        )
        # Minimal constructor arguments; per-call options are passed at inference time.
        asr = pipeline(
            task="automatic-speech-recognition",
            model=model_id,
            torch_dtype=dtype,
            device=_pipeline_device_arg(device_str),
            token=hf_token,
        )
        _clear_forced_decoder_ids(
            asr, logger, "Removing forced_decoder_ids from model generation config"
        )
        logger.info(f"Successfully created ultra-simplified pipeline for: {model_id}")
    except Exception as e:
        logger.error(f"Ultra-simplified pipeline creation failed: {e}")
        logger.info("Falling back to absolute minimal settings...")
        fallback_dtype = torch.float32
        try:
            # Force CPU + float32 for maximum compatibility.
            asr = pipeline(
                task="automatic-speech-recognition",
                model=model_id,
                torch_dtype=fallback_dtype,
                device="cpu",
                token=hf_token,
            )
            _clear_forced_decoder_ids(
                asr, logger, "Removing forced_decoder_ids from fallback model"
            )
            device_str = "cpu"
            dtype = fallback_dtype
            logger.info(
                f"Minimal fallback pipeline created with dtype: {fallback_dtype}"
            )
        except Exception as fallback_error:
            logger.error(f"Minimal fallback failed: {fallback_error}")
            raise

    return asr, device_str, str(dtype).replace("torch.", "")
@contextmanager
def memory_monitor():
    """Measure this process's resident memory (RSS) across a ``with`` block.

    Yields a dict that is filled in when the block exits, whether normally or
    via an exception, with ``start_mb``, ``end_mb`` and ``delta_mb`` (RSS in
    units of 1024 * 1024 bytes).

    The original version ``return``-ed the delta from the generator. A
    ``@contextmanager`` discards that value, so callers could never see it.
    Plain ``with memory_monitor():`` usage keeps working unchanged.

    Example::

        with memory_monitor() as stats:
            run_inference()
        print(stats["delta_mb"])
    """
    process = psutil.Process()
    stats: Dict[str, float] = {"start_mb": process.memory_info().rss / 1024 / 1024}
    try:
        yield stats
    finally:
        stats["end_mb"] = process.memory_info().rss / 1024 / 1024
        stats["delta_mb"] = stats["end_mb"] - stats["start_mb"]
def _validate_audio_path(audio_path: Any) -> None:
    """Raise if ``audio_path`` is None, not path-like, or does not exist on disk."""
    if audio_path is None:
        raise ValueError("Audio path is None")
    if not isinstance(audio_path, (str, bytes, os.PathLike)):
        raise TypeError(
            f"Audio path must be str, bytes or os.PathLike, got {type(audio_path)}"
        )
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")


def _build_asr_call_kwargs(
    language: Optional[str],
    task: Optional[str],
    chunk_length_s: Optional[int],
    stride_length_s: Optional[int],
    return_timestamps: bool,
    logger: logging.Logger,
) -> Dict[str, Any]:
    """Collect the per-call keyword arguments forwarded to the ASR pipeline."""
    kwargs: Dict[str, Any] = {}
    if return_timestamps:
        kwargs["return_timestamps"] = return_timestamps
        logger.info("Timestamps enabled")
    if chunk_length_s:
        kwargs["chunk_length_s"] = chunk_length_s
        logger.info(f"Using chunking strategy: {chunk_length_s}s")
    if stride_length_s is not None:
        kwargs["stride_length_s"] = stride_length_s
        logger.info(f"Using stride: {stride_length_s}s")
    # language/task used to be accepted but never forwarded, so the model chose
    # the language itself. They are passed through generate_kwargs; if the model
    # rejects them, _run_with_fallbacks retries without them.
    generate_kwargs = {
        name: value for name, value in (("language", language), ("task", task)) if value
    }
    if generate_kwargs:
        kwargs["generate_kwargs"] = generate_kwargs
        logger.info(f"Generation options: {generate_kwargs}")
    logger.info(f"Inference parameters configured: {list(kwargs.keys())}")
    return kwargs


def _log_inference_failure(error_msg: str, logger: logging.Logger) -> None:
    """Log a failed inference attempt, recognising the failure modes handled before."""
    logger.warning(f"Inference failed with parameters: {error_msg}")
    if "forced_decoder_ids" in error_msg:
        logger.info("Detected forced_decoder_ids error, retrying with fewer parameters...")
    elif "probability tensor contains either inf, nan or element < 0" in error_msg:
        logger.info("Detected numerical instability, retrying with fewer parameters...")
    else:
        logger.info("Unknown error, retrying with fewer parameters...")


def _run_with_fallbacks(
    asr: Any, audio_path: Any, asr_kwargs: Dict[str, Any], logger: logging.Logger
) -> Tuple[Any, float, float]:
    """Call ``asr`` with progressively fewer options until one attempt succeeds.

    The attempts are: all options; all options except ``generate_kwargs`` (only
    when present); then no options at all, as in the original fallback.

    Returns:
        ``(result, inference_time_s, memory_used_mb)`` for the successful attempt.
    """
    attempts: List[Dict[str, Any]] = [asr_kwargs]
    if "generate_kwargs" in asr_kwargs:
        attempts.append({k: v for k, v in asr_kwargs.items() if k != "generate_kwargs"})
    attempts.append({})

    process = psutil.Process()
    for index, kwargs in enumerate(attempts):
        start = time.perf_counter()
        memory_before = process.memory_info().rss / 1024 / 1024  # MB
        try:
            result = asr(audio_path, **kwargs)
        except Exception as error:
            if index == len(attempts) - 1:
                logger.error(f"All inference attempts failed: {error}")
                raise
            _log_inference_failure(str(error), logger)
            continue
        inference_time = time.perf_counter() - start
        memory_used = process.memory_info().rss / 1024 / 1024 - memory_before
        if index == 0:
            logger.info(f"Inference completed successfully in {inference_time:.2f}s")
        else:
            logger.info(
                f"Fallback inference completed in {inference_time:.2f}s "
                f"(options: {sorted(kwargs)})"
            )
        logger.info(f"Memory used: {memory_used:.1f}MB")
        return result, inference_time, memory_used
    raise RuntimeError("No inference attempt was made")  # unreachable: attempts is non-empty


def transcribe_local(
    audio_path: str,
    model_id: str,
    base_model_id: str,
    language: Optional[str],
    task: str,
    device_pref: str,
    dtype_pref: str,
    hf_token: Optional[str],
    chunk_length_s: Optional[int],
    stride_length_s: Optional[int],
    return_timestamps: bool,
) -> Dict[str, Any]:
    """Transcribe one audio file with ``model_id`` and report basic metrics.

    The pipeline is built afresh on every call via ``load_asr_pipeline``.

    Args:
        audio_path: Path to an existing audio file.
        model_id: Model to run.
        base_model_id: Forwarded to ``load_asr_pipeline``. Also used to label
            ``meta["model_type"]``: "original" when it equals ``model_id``,
            otherwise "fine-tuned".
        language, task: Forwarded to generation via ``generate_kwargs`` when truthy.
        device_pref, dtype_pref, hf_token: Forwarded to ``load_asr_pipeline``.
        chunk_length_s, stride_length_s, return_timestamps: Per-call pipeline options.

    Returns:
        ``{"result": <raw pipeline output>, "meta": {...}}``. ``meta`` holds
        ``device``, ``dtype``, ``inference_time`` (seconds), ``memory_used_mb``
        (RSS delta of the successful attempt) and ``model_type``.

    Raises:
        ValueError, TypeError or FileNotFoundError for a bad ``audio_path``.
        Otherwise, whatever the pipeline raises when every attempt fails.
    """
    logger = logging.getLogger(__name__)
    # Validate first: os.path.basename(None) would raise a TypeError and mask
    # the intended ValueError.
    _validate_audio_path(audio_path)
    logger.info(f"Starting transcription: {os.path.basename(audio_path)}")
    logger.info(f"Model: {model_id}")

    load_start = time.perf_counter()
    asr, device_str, dtype_str = load_asr_pipeline(
        model_id=model_id,
        base_model_id=base_model_id,
        device_pref=device_pref,
        hf_token=hf_token,
        dtype_pref=dtype_pref,
        chunk_length_s=chunk_length_s,
        return_timestamps=return_timestamps,
    )
    logger.info(f"Model loaded in {time.perf_counter() - load_start:.2f}s")

    asr_kwargs = _build_asr_call_kwargs(
        language, task, chunk_length_s, stride_length_s, return_timestamps, logger
    )
    try:
        result, inference_time, memory_used = _run_with_fallbacks(
            asr, audio_path, asr_kwargs, logger
        )
    finally:
        # Release our reference and collect before emptying the CUDA cache, so
        # blocks freed by the collection can actually be returned. This runs on
        # failure too. startswith() also covers labels like "cuda:1".
        del asr
        gc.collect()
        if str(device_str).startswith("cuda"):
            torch.cuda.empty_cache()

    return {
        "result": result,
        "meta": {
            "device": device_str,
            "dtype": dtype_str,
            "inference_time": inference_time,
            "memory_used_mb": memory_used,
            "model_type": "original" if model_id == base_model_id else "fine-tuned",
        },
    }
def handle_whisper_problematic_output(text: Optional[str], model_name: str = "Whisper") -> dict:
    """Handle problematic Whisper outputs such as '!', '.', empty strings, etc.

    Args:
        text: Raw transcription; may be None or empty.
        model_name: Currently unused; kept for API compatibility.

    Returns:
        A dict with ``text`` (user-facing text, an Italian "[WHISPER ISSUE: ...]"
        message when problematic), ``is_problematic`` and ``original`` (the
        input unchanged). Problematic results also carry ``issue_type``:
        "empty", "short_symbolic" or the offending stripped string. Results for
        known outputs additionally carry ``suggestion``.
    """
    # Whitespace-only output is as empty as "". Previously it fell through and
    # was misreported as a short symbolic string ('').
    if not text or not text.strip():
        return {
            "text": "[WHISPER ISSUE: Output vuoto - Audio troppo corto o silenzioso]",
            "is_problematic": True,
            "original": text,
            "issue_type": "empty",
        }
    text_stripped = text.strip()
    # Common problematic outputs mapped to a likely cause (user-facing, Italian).
    # Keys are matched against the stripped text, so none may have surrounding
    # whitespace (the former " per!" key could never match).
    problematic_outputs = {
        "!": "Audio troppo corto/silenzioso",
        ".": "Audio di bassa qualità",
        "?": "Audio incomprensibile",
        "...": "Audio troppo lungo senza parlato",
        "--": "Audio distorto",
        "—": "Audio con troppo rumore",
        "per!": "Audio parzialmente comprensibile",
    }
    suggestion = problematic_outputs.get(text_stripped)
    if suggestion is not None:
        return {
            "text": f"[WHISPER ISSUE: '{text_stripped}' - {suggestion}]",
            "is_problematic": True,
            "original": text,
            "issue_type": text_stripped,
            "suggestion": suggestion,
        }
    # Very short output (at most 2 characters) that is not purely alphabetic.
    if len(text_stripped) <= 2 and not text_stripped.isalpha():
        return {
            "text": f"[WHISPER ISSUE: '{text_stripped}' - Output troppo corto/simbolico]",
            "is_problematic": True,
            "original": text,
            "issue_type": "short_symbolic",
        }
    return {"text": text, "is_problematic": False, "original": text}
# Tips appended (in Italian, user-facing) when an output looks like a symptom of poor audio.
_AUDIO_QUALITY_TIPS = (
    "\n\n💡 Suggerimenti:"
    "\n• Registra almeno 5-10 secondi di audio"
    "\n• Parla chiaramente e ad alto volume"
    "\n• Avvicinati al microfono"
    "\n• Evita rumori di fondo"
)


def _extract_display_text(raw_result: Any, model_label: str, empty_message: str) -> str:
    """Turn one raw pipeline result into the text shown in the UI."""
    text = None
    if isinstance(raw_result, dict):
        text = raw_result.get("text") or raw_result.get("transcription")
    elif isinstance(raw_result, str):
        text = raw_result
    if not text:
        return empty_message
    checked = handle_whisper_problematic_output(text, model_label)
    if checked["is_problematic"]:
        return f"⚠️ {checked['text']}{_AUDIO_QUALITY_TIPS}"
    return checked["text"]


def _transcribe_for_display(
    audio_path: str,
    model_id: str,
    base_model_id: str,
    hf_token: Optional[str],
    model_label: str,
    empty_message: str,
    error_prefix: str,
    options: Dict[str, Any],
) -> str:
    """Run one model on ``audio_path`` and return UI text.

    Any failure is returned as ``error_prefix`` followed by the error message,
    instead of being raised.
    """
    try:
        outcome = transcribe_local(
            audio_path=audio_path,
            model_id=model_id,
            base_model_id=base_model_id,
            hf_token=hf_token,
            **options,
        )
        return _extract_display_text(outcome["result"], model_label, empty_message)
    except Exception as e:  # surfaced in the textbox rather than crashing the callback
        return f"{error_prefix}{e}"


def transcribe_comparison(audio_file):
    """Gradio click handler: transcribe ``audio_file`` with the base and the fine-tuned model.

    Args:
        audio_file: Path to the audio file from the Gradio input, or None.

    Returns:
        ``(original_text, finetuned_text)``. Each is either a transcription or
        a user-facing (Italian) status/error message. Transcription errors are
        returned as messages, not raised.
    """
    if audio_file is None:
        return "❌ Nessun file audio fornito", "❌ Nessun file audio fornito"

    # Same defaults the UI displays, so the button works without explicit
    # configuration instead of reporting that the models are not configured.
    model_id = get_env_or_secret("HF_MODEL_ID", "ReportAId/whisper-medium-it-finetuned")
    base_model_id = get_env_or_secret("BASE_WHISPER_MODEL_ID", "openai/whisper-medium")
    hf_token = get_env_or_secret("HF_TOKEN") or get_env_or_secret(
        "HUGGINGFACEHUB_API_TOKEN"
    )
    # Still reachable when a variable is explicitly set to an empty string.
    if not model_id or not base_model_id:
        error_msg = "❌ Modelli non configurati. Impostare HF_MODEL_ID e BASE_WHISPER_MODEL_ID nelle variabili d'ambiente"
        return error_msg, error_msg

    # Fixed settings for Italian medical transcription. No audio preprocessing
    # is applied here.
    chunk_len = 7  # seconds per chunk
    stride_len = 1  # stride in seconds
    options: Dict[str, Any] = {
        "language": "it",  # always Italian for ScribeAId
        "task": "transcribe",
        "device_pref": "auto",  # auto-detect best device
        "dtype_pref": "auto",  # auto-select precision
        "chunk_length_s": int(chunk_len) if chunk_len else None,
        "stride_length_s": int(stride_len) if stride_len else None,
        "return_timestamps": True,
    }

    try:
        original_text = _transcribe_for_display(
            audio_file,
            base_model_id,
            base_model_id,
            None,  # the base model is loaded without a token
            "Original Whisper",
            "❌ Nessun testo restituito dal modello originale",
            "❌ Errore modello originale: ",
            options,
        )
        finetuned_text = _transcribe_for_display(
            audio_file,
            model_id,
            base_model_id,
            hf_token or None,
            "Fine-tuned Model",
            "❌ Nessun testo restituito dal modello fine-tuned",
            "❌ Errore modello fine-tuned: ",
            options,
        )
        # Collect first so that freed CUDA blocks can be returned by empty_cache().
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return original_text, finetuned_text
    except Exception as e:
        error_msg = f"❌ Errore generale: {str(e)}"
        return error_msg, error_msg
# Gradio interface
def create_interface():
    """Build (but do not launch) the ScribeAId comparison UI.

    ``HF_MODEL_ID`` and ``BASE_WHISPER_MODEL_ID`` are read (with the defaults
    below) only to display them in the settings panel. The click handler
    receives just the audio input, not these ids.

    Returns:
        The ``gr.Blocks`` app object.
    """
    model_id = get_env_or_secret("HF_MODEL_ID", "ReportAId/whisper-medium-it-finetuned")
    base_model_id = get_env_or_secret("BASE_WHISPER_MODEL_ID", "openai/whisper-medium")
    # Inline the SVG logo so it displays even without file routing.
    logo_html = None
    try:
        logo_path = os.path.join(os.path.dirname(__file__), "assets", "ScribeAId.svg")
        with open(logo_path, "r", encoding="utf-8") as f:
            svg_content = f.read()
        # Wrap the SVG in a centered container.
        logo_html = f"""
<div style=\"text-align: center; margin: 16px 0 8px;\">
<div style=\"display:inline-block; height:60px;\">{svg_content}</div>
</div>
"""
    except Exception:
        # Best effort: if the file cannot be read for any reason, fall back to
        # referencing it through a file= path.
        logo_html = """
<div style=\"text-align: center; margin: 16px 0 8px;\">
<img src=\"file=assets/ScribeAId.svg\" alt=\"ScribeAId\" style=\"height: 60px; margin-bottom: 8px;\">
</div>
"""
    # Statement order inside this Blocks context determines the page layout.
    # NOTE(review): no component below uses the .center-col CSS class.
    with gr.Blocks(
        title="ScribeAId - Medical Transcription",
        theme=gr.themes.Default(primary_hue="blue"),
        css=".gradio-container{max-width: 900px !important; margin: 0 auto !important;} .center-col{display:flex;flex-direction:column;align-items:center;} .center-col .wrap{width:100%;}",
    ) as demo:
        # Header with the ScribeAId logo (simple, black/white).
        gr.HTML(logo_html)
        gr.Markdown("""
Quest’applicazione confronta un Whisper V3 di base con il modello open-source fine-tuned pubblicato da ReportAId su dati ambulatoriali italiani. È progettato per mitigare errori noti e migliorare le performance. Carica un audio o registra la voce: noterai trascrizioni più accurate di termini clinici come “Holter delle 24 ore”, “fibrillazione atriale” o “pressione sistolica”.
""")
        with gr.Row():
            with gr.Column():
                # NOTE(review): the "Preprocessing audio" line advertises
                # preprocessing, but PREPROCESSING_AVAILABLE is False at module
                # level — confirm the UI text.
                gr.Markdown(f"""
**⚙️ Impostazioni**
- Modello originale: `{base_model_id}`
- Modello fine-tuned: `{model_id}`
- Lingua: Italiano (it)
- Preprocessing audio: ottimizzato per registrazioni mediche
""")
        gr.Markdown("---")
        # Input section title.
        gr.Markdown("## Input")
        # Audio input and button, left-aligned. type="filepath" hands the
        # click handler a path string (or None when nothing was provided).
        audio_input = gr.Audio(
            label="📥 Registra dal microfono o carica un file",
            type="filepath",
            sources=["microphone", "upload"],
            format="wav",
            streaming=False,
            interactive=True,
        )
        transcribe_btn = gr.Button("🚀 Trascrivi e Confronta", variant="primary")
        gr.Markdown("---")
        gr.Markdown("## Output")
        with gr.Row():
            with gr.Column():
                # NOTE(review): heading says "Whisper V3" while the default base
                # model id is openai/whisper-medium — confirm against deployment.
                gr.Markdown("### Modello base (Whisper V3)")
                original_output = gr.Textbox(
                    label="Transcription",
                    lines=12,
                    interactive=False,
                    show_copy_button=True,
                )
            with gr.Column():
                gr.Markdown("### Modello fine-tuned ReportAId")
                finetuned_output = gr.Textbox(
                    label="Transcription",
                    lines=12,
                    interactive=False,
                    show_copy_button=True,
                )
        # Click event: the audio path goes in, the two texts come out (in this order).
        transcribe_btn.click(
            fn=transcribe_comparison,
            inputs=[audio_input],
            outputs=[original_output, finetuned_output],
            show_progress=True,
        )
    return demo
if __name__ == "__main__":
    # Timestamped log format for the whole app.
    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)

    demo = create_interface()

    # Launch settings for Hugging Face Spaces: all interfaces, port 7860,
    # no public share link, errors shown in the UI.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
        "inbrowser": False,
        "quiet": False,
    }
    demo.launch(**launch_options)