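"""ScribeAId demo: Gradio app that compares a base Whisper model with the
ReportAId fine-tuned model on Italian medical audio transcription."""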
import os
import sys
import tempfile
import time
import logging
import gc
from dataclasses import dataclass
from typing import Optional, Tuple, List, Any, Dict
from contextlib import contextmanager
import gradio as gr
import torch
import psutil
from dotenv import load_dotenv
load_dotenv()
# Audio preprocessing not available in Hugging Face Spaces deployment
PREPROCESSING_AVAILABLE = False
def get_env_or_secret(key: str, default: Optional[str] = None) -> Optional[str]:
"""Get environment variable or default."""
return os.environ.get(key, default)
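# Example: get_env_or_secret("BASE_WHISPER_MODEL_ID", "openai/whisper-medium")
# returns the environment value when set, otherwise the given default.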
@dataclass
class InferenceMetrics:
"""Track inference performance metrics."""
processing_time: float
memory_usage: float
device_used: str
dtype_used: str
model_size_mb: Optional[float] = None
@dataclass
class PreprocessingConfig:
"""Configuration for audio preprocessing pipeline."""
normalize_format: bool = True
normalize_volume: bool = True
reduce_noise: bool = False
remove_silence: bool = False
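# Note: InferenceMetrics and PreprocessingConfig are not referenced elsewhere in
# this module; they appear to be placeholders for future metrics/preprocessing work.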
def load_asr_pipeline(
model_id: str,
base_model_id: str,
device_pref: str = "auto",
hf_token: Optional[str] = None,
dtype_pref: str = "auto",
chunk_length_s: Optional[int] = None,
return_timestamps: bool = False,
):
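    """Create a Hugging Face ASR pipeline for the given model.

    Resolves the inference device (CUDA, MPS, or CPU) and dtype from the
    preferences, builds the pipeline with a minimal configuration, strips any
    forced_decoder_ids from the generation config, and falls back to a
    CPU/float32 pipeline if the first attempt fails.

    Returns a tuple of (pipeline, device string, dtype string without the
    "torch." prefix).
    """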
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info(f"Loading ASR pipeline for model: {model_id}")
logger.info(
f"Device preference: {device_pref}, Token provided: {hf_token is not None}"
)
import torch
from transformers import pipeline
# Pick optimal device for inference
device_str = "cpu"
if device_pref == "auto":
if torch.cuda.is_available():
device_str = "cuda"
logger.info(f"Using CUDA: {torch.cuda.get_device_name()}")
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
device_str = "mps"
logger.info("Using Apple Silicon MPS for inference")
else:
device_str = "cpu"
logger.info("Using CPU for inference")
else:
device_str = device_pref
# Pick dtype - optimized for inference performance
dtype = None
if dtype_pref == "auto":
# For whisper-medium models, use float32 for stability in medical transcription
if "whisper-medium" in model_id:
dtype = torch.float32
logger.info(
f"Using float32 for {model_id} (medical transcription stability)"
)
elif device_str == "cuda":
dtype = torch.float16 # Use half precision on GPU for speed
logger.info("Using float16 on CUDA for faster inference")
else:
dtype = torch.float32
else:
dtype = {"float32": torch.float32, "float16": torch.float16}.get(
dtype_pref, torch.float32
)
logger.info("Pipeline configuration:")
logger.info(f" Model: {model_id}")
logger.info(f" Base model: {base_model_id}")
logger.info(f" Dtype: {dtype}")
logger.info(f" Device: {device_str}")
logger.info(f" Chunk length: {chunk_length_s}s")
logger.info(f" Return timestamps: {return_timestamps}")
# Use ultra-simplified approach to avoid all compatibility issues
try:
logger.info(
"Setting up ultra-simplified pipeline to avoid forced_decoder_ids conflicts..."
)
# Create pipeline with absolute minimal configuration
asr = pipeline(
task="automatic-speech-recognition",
model=model_id,
torch_dtype=dtype,
device=0
if device_str == "cuda"
else ("mps" if device_str == "mps" else "cpu"),
token=hf_token,
)
# Post-loading cleanup to remove any forced_decoder_ids
if hasattr(asr.model, "generation_config"):
if hasattr(asr.model.generation_config, "forced_decoder_ids"):
logger.info("Removing forced_decoder_ids from model generation config")
asr.model.generation_config.forced_decoder_ids = None
        # Note: chunk_length_s is not applied to the pipeline object here; it is
        # passed per call by transcribe_local, so only the requested value is logged.
        if chunk_length_s:
            logger.info(f"chunk_length_s={chunk_length_s} will be applied at inference time")
logger.info(f"Successfully created ultra-simplified pipeline for: {model_id}")
except Exception as e:
logger.error(f"Ultra-simplified pipeline creation failed: {e}")
logger.info("Falling back to absolute minimal settings...")
try:
# Fallback with absolute minimal settings
fallback_dtype = torch.float32
asr = pipeline(
task="automatic-speech-recognition",
model=model_id,
torch_dtype=fallback_dtype,
device="cpu", # Force CPU for maximum compatibility
token=hf_token,
)
# Post-loading cleanup
if hasattr(asr.model, "generation_config"):
if hasattr(asr.model.generation_config, "forced_decoder_ids"):
logger.info("Removing forced_decoder_ids from fallback model")
asr.model.generation_config.forced_decoder_ids = None
device_str = "cpu"
dtype = fallback_dtype
logger.info(
f"Minimal fallback pipeline created with dtype: {fallback_dtype}"
)
except Exception as fallback_error:
logger.error(f"Minimal fallback failed: {fallback_error}")
raise
return asr, device_str, str(dtype).replace("torch.", "")
@contextmanager
def memory_monitor():
    """Yield a dict whose "delta_mb" entry records memory use (MB) across the block."""
    process = psutil.Process()
    stats = {"delta_mb": 0.0}
    start_memory = process.memory_info().rss / 1024 / 1024  # MB
    try:
        yield stats
    finally:
        # A plain return value would be discarded by contextlib.contextmanager,
        # so report the delta through the yielded dict instead.
        end_memory = process.memory_info().rss / 1024 / 1024  # MB
        stats["delta_mb"] = end_memory - start_memory
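# memory_monitor is defined but not currently used by the transcription path.
# A minimal usage sketch (hypothetical):
#     with memory_monitor() as stats:
#         asr(audio_path)
#     logging.info(f"Memory delta: {stats['delta_mb']:.1f}MB")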
def transcribe_local(
audio_path: str,
model_id: str,
base_model_id: str,
language: Optional[str],
task: str,
device_pref: str,
dtype_pref: str,
hf_token: Optional[str],
chunk_length_s: Optional[int],
stride_length_s: Optional[int],
return_timestamps: bool,
) -> Dict[str, Any]:
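    """Transcribe a single audio file with the requested model.

    Validates the audio path, loads the ASR pipeline, runs inference with a
    conservative parameter set (retrying with no parameters on failure), and
    returns {"result": <pipeline output>, "meta": <device/dtype/timing/memory info>}.

    Note: the language and task arguments are accepted for interface
    compatibility but are not currently forwarded to the pipeline; the
    simplified configuration lets Whisper infer them to avoid
    forced_decoder_ids conflicts.
    """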
logger = logging.getLogger(__name__)
logger.info(f"Starting transcription: {os.path.basename(audio_path)}")
logger.info(f"Model: {model_id}")
# Validate audio_path
if audio_path is None:
raise ValueError("Audio path is None")
if not isinstance(audio_path, (str, bytes, os.PathLike)):
raise TypeError(
f"Audio path must be str, bytes or os.PathLike, got {type(audio_path)}"
)
if not os.path.exists(audio_path):
raise FileNotFoundError(f"Audio file not found: {audio_path}")
# Load ASR pipeline with performance monitoring
start_time = time.time()
asr, device_str, dtype_str = load_asr_pipeline(
model_id=model_id,
base_model_id=base_model_id,
device_pref=device_pref,
hf_token=hf_token,
dtype_pref=dtype_pref,
chunk_length_s=chunk_length_s,
return_timestamps=return_timestamps,
)
load_time = time.time() - start_time
logger.info(f"Model loaded in {load_time:.2f}s")
# Simplified configuration to avoid compatibility issues
# Let the pipeline handle generation parameters internally
logger.info("Using simplified configuration to avoid model compatibility issues")
# Setup inference parameters with performance monitoring
try:
# Start with minimal parameters to avoid conflicts
asr_kwargs = {}
# Only add parameters that are safe and supported
if return_timestamps:
asr_kwargs["return_timestamps"] = return_timestamps
logger.info("Timestamps enabled")
        # Apply chunking and stride settings; these are plain dict entries, so no
        # exception handling is needed here. Any incompatibility surfaces at the
        # pipeline call below.
        if chunk_length_s:
            asr_kwargs["chunk_length_s"] = chunk_length_s
            logger.info(f"Using chunking strategy: {chunk_length_s}s")
        if stride_length_s is not None:
            asr_kwargs["stride_length_s"] = stride_length_s
            logger.info(f"Using stride: {stride_length_s}s")
logger.info(f"Inference parameters configured: {list(asr_kwargs.keys())}")
# Run inference with performance monitoring
inference_start = time.time()
memory_before = psutil.Process().memory_info().rss / 1024 / 1024 # MB
try:
# Primary inference attempt with safe parameters
if asr_kwargs:
result = asr(audio_path, **asr_kwargs)
else:
# Fallback to no parameters if all failed
result = asr(audio_path)
inference_time = time.time() - inference_start
memory_after = psutil.Process().memory_info().rss / 1024 / 1024 # MB
memory_used = memory_after - memory_before
logger.info(f"Inference completed successfully in {inference_time:.2f}s")
logger.info(f"Memory used: {memory_used:.1f}MB")
except Exception as e:
error_msg = str(e)
logger.warning(f"Inference failed with parameters: {error_msg}")
# Try with absolutely minimal parameters
if "forced_decoder_ids" in error_msg:
logger.info(
"Detected forced_decoder_ids error, trying with no parameters..."
)
elif (
"probability tensor contains either inf, nan or element < 0"
in error_msg
):
logger.info(
"Detected numerical instability, trying with no parameters..."
)
else:
logger.info("Unknown error, trying with no parameters...")
try:
inference_start = time.time()
result = asr(audio_path) # No parameters at all
inference_time = time.time() - inference_start
memory_used = 0 # Reset memory tracking
logger.info(f"Minimal inference completed in {inference_time:.2f}s")
except Exception as final_error:
logger.error(f"All inference attempts failed: {final_error}")
raise
except Exception as e:
logger.error(f"Inference failed: {e}")
raise
# Cleanup GPU memory after inference
if device_str == "cuda":
torch.cuda.empty_cache()
gc.collect()
# Return results with performance metrics
meta = {
"device": device_str,
"dtype": dtype_str,
"inference_time": inference_time,
"memory_used_mb": memory_used,
"model_type": "original" if model_id == base_model_id else "fine-tuned",
}
return {"result": result, "meta": meta}
def handle_whisper_problematic_output(text: str, model_name: str = "Whisper") -> dict:
"""Gestisce gli output problematici di Whisper come '!', '.', stringhe vuote, ecc."""
if not text:
return {
"text": "[WHISPER ISSUE: Output vuoto - Audio troppo corto o silenzioso]",
"is_problematic": True,
"original": text,
"issue_type": "empty",
}
text_stripped = text.strip()
    # Common problematic outputs
problematic_outputs = {
"!": "Audio troppo corto/silenzioso",
".": "Audio di bassa qualità",
"?": "Audio incomprensibile",
"...": "Audio troppo lungo senza parlato",
"--": "Audio distorto",
"—": "Audio con troppo rumore",
" per!": "Audio parzialmente comprensibile",
"per!": "Audio parzialmente comprensibile",
}
if text_stripped in problematic_outputs:
return {
"text": f"[WHISPER ISSUE: '{text_stripped}' - {problematic_outputs[text_stripped]}]",
"is_problematic": True,
"original": text,
"issue_type": text_stripped,
"suggestion": problematic_outputs[text_stripped],
}
    # Text too short (fewer than 3 characters and not alphabetic)
if len(text_stripped) <= 2 and not text_stripped.isalpha():
return {
"text": f"[WHISPER ISSUE: '{text_stripped}' - Output troppo corto/simbolico]",
"is_problematic": True,
"original": text,
"issue_type": "short_symbolic",
}
return {"text": text, "is_problematic": False, "original": text}
def transcribe_comparison(audio_file):
"""Main function for Gradio interface."""
if audio_file is None:
return "❌ Nessun file audio fornito", "❌ Nessun file audio fornito"
# Model configuration
model_id = get_env_or_secret("HF_MODEL_ID")
base_model_id = get_env_or_secret("BASE_WHISPER_MODEL_ID")
hf_token = get_env_or_secret("HF_TOKEN") or get_env_or_secret(
"HUGGINGFACEHUB_API_TOKEN"
)
if not model_id or not base_model_id:
error_msg = "❌ Modelli non configurati. Impostare HF_MODEL_ID e BASE_WHISPER_MODEL_ID nelle variabili d'ambiente"
return error_msg, error_msg
    # Preprocessing is always enabled (hidden from the user).
    # It is no longer used in this code but may be useful for future implementations.
# Fixed settings optimized for medical transcription
language = "it" # Always Italian for ScribeAId
task = "transcribe"
return_ts = True # Timestamps for medical report segments
device_pref = "auto" # Auto-detect best device
dtype_pref = "auto" # Auto-select optimal precision
chunk_len = 7 # 7-second chunks for better context
stride_len = 1 # Minimal stride for accuracy
try:
# Use the audio file path directly from Gradio
tmp_path = audio_file
original_result = None
finetuned_result = None
original_text = ""
finetuned_text = ""
try:
# Transcribe with original model
original_result = transcribe_local(
audio_path=tmp_path,
model_id=base_model_id,
base_model_id=base_model_id,
language=language,
task=task,
device_pref=device_pref,
dtype_pref=dtype_pref,
hf_token=None, # Base model doesn't need token
chunk_length_s=int(chunk_len) if chunk_len else None,
stride_length_s=int(stride_len) if stride_len else None,
return_timestamps=return_ts,
)
# Extract text from result
if isinstance(original_result["result"], dict):
original_text = original_result["result"].get(
"text"
) or original_result["result"].get("transcription")
elif isinstance(original_result["result"], str):
original_text = original_result["result"]
if original_text:
result = handle_whisper_problematic_output(
original_text, "Original Whisper"
)
if result["is_problematic"]:
original_text = f"⚠️ {result['text']}\n\n💡 Suggerimenti:\n• Registra almeno 5-10 secondi di audio\n• Parla chiaramente e ad alto volume\n• Avvicinati al microfono\n• Evita rumori di fondo"
else:
original_text = result["text"]
else:
original_text = "❌ Nessun testo restituito dal modello originale"
except Exception as e:
original_text = f"❌ Errore modello originale: {str(e)}"
try:
# Transcribe with fine-tuned model
finetuned_result = transcribe_local(
audio_path=tmp_path,
model_id=model_id,
base_model_id=base_model_id,
language=language,
task=task,
device_pref=device_pref,
dtype_pref=dtype_pref,
hf_token=hf_token or None,
chunk_length_s=int(chunk_len) if chunk_len else None,
stride_length_s=int(stride_len) if stride_len else None,
return_timestamps=return_ts,
)
# Extract text from result
if isinstance(finetuned_result["result"], dict):
finetuned_text = finetuned_result["result"].get(
"text"
) or finetuned_result["result"].get("transcription")
elif isinstance(finetuned_result["result"], str):
finetuned_text = finetuned_result["result"]
if finetuned_text:
result = handle_whisper_problematic_output(
finetuned_text, "Fine-tuned Model"
)
if result["is_problematic"]:
finetuned_text = f"⚠️ {result['text']}\n\n💡 Suggerimenti:\n• Registra almeno 5-10 secondi di audio\n• Parla chiaramente e ad alto volume\n• Avvicinati al microfono\n• Evita rumori di fondo"
else:
finetuned_text = result["text"]
else:
finetuned_text = "❌ Nessun testo restituito dal modello fine-tuned"
except Exception as e:
finetuned_text = f"❌ Errore modello fine-tuned: {str(e)}"
# GPU memory cleanup
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
return original_text, finetuned_text
except Exception as e:
error_msg = f"❌ Errore generale: {str(e)}"
return error_msg, error_msg
# Gradio interface
def create_interface():
"""Create and configure the Gradio interface."""
model_id = get_env_or_secret("HF_MODEL_ID", "ReportAId/whisper-medium-it-finetuned")
base_model_id = get_env_or_secret("BASE_WHISPER_MODEL_ID", "openai/whisper-medium")
    # Load the SVG logo inline so it renders even without file routing
logo_html = None
try:
logo_path = os.path.join(os.path.dirname(__file__), "assets", "ScribeAId.svg")
with open(logo_path, "r", encoding="utf-8") as f:
svg_content = f.read()
        # Wrap the SVG in a centered container
logo_html = f"""
<div style=\"text-align: center; margin: 16px 0 8px;\">
<div style=\"display:inline-block; height:60px;\">{svg_content}</div>
</div>
"""
except Exception:
        # Fall back to the file= path if the file cannot be read for any reason
logo_html = """
<div style=\"text-align: center; margin: 16px 0 8px;\">
<img src=\"file=assets/ScribeAId.svg\" alt=\"ScribeAId\" style=\"height: 60px; margin-bottom: 8px;\">
</div>
"""
with gr.Blocks(
title="ScribeAId - Medical Transcription",
theme=gr.themes.Default(primary_hue="blue"),
css=".gradio-container{max-width: 900px !important; margin: 0 auto !important;} .center-col{display:flex;flex-direction:column;align-items:center;} .center-col .wrap{width:100%;}",
) as demo:
        # Header with the ScribeAId logo (simple, black and white)
gr.HTML(logo_html)
gr.Markdown("""
Quest’applicazione confronta il modello Whisper di base con il modello open-source fine-tuned pubblicato da ReportAId su dati ambulatoriali italiani. È progettato per mitigare errori noti e migliorare le performance. Carica un audio o registra la voce: noterai trascrizioni più accurate di termini clinici come “Holter delle 24 ore”, “fibrillazione atriale” o “pressione sistolica”.
""")
with gr.Row():
with gr.Column():
gr.Markdown(f"""
**⚙️ Impostazioni**
- Modello originale: `{base_model_id}`
- Modello fine-tuned: `{model_id}`
- Lingua: Italiano (it)
- Preprocessing audio: ottimizzato per registrazioni mediche
""")
gr.Markdown("---")
        # Input section title
gr.Markdown("## Input")
        # Audio input and button aligned to the left
audio_input = gr.Audio(
label="📥 Registra dal microfono o carica un file",
type="filepath",
sources=["microphone", "upload"],
format="wav",
streaming=False,
interactive=True,
)
transcribe_btn = gr.Button("🚀 Trascrivi e Confronta", variant="primary")
gr.Markdown("---")
gr.Markdown("## Output")
with gr.Row():
with gr.Column():
gr.Markdown("### Modello base (Whisper V3)")
original_output = gr.Textbox(
label="Transcription",
lines=12,
interactive=False,
show_copy_button=True,
)
with gr.Column():
gr.Markdown("### Modello fine-tuned ReportAId")
finetuned_output = gr.Textbox(
label="Transcription",
lines=12,
interactive=False,
show_copy_button=True,
)
# Click event
transcribe_btn.click(
fn=transcribe_comparison,
inputs=[audio_input],
outputs=[original_output, finetuned_output],
show_progress=True,
)
return demo
if __name__ == "__main__":
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
demo = create_interface()
# Launch configuration for Hugging Face Spaces
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
show_error=True,
inbrowser=False,
quiet=False,
)
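    # When launched outside Hugging Face Spaces, the interface is served at
    # http://localhost:7860 (the fixed port configured above).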