Commit edf497c
Parent(s): 65b0afc

Improve Whisper pipeline caching

Files changed:
- .gitignore (+1, -0)
- app.py (+127, -33)
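
In short: app.py now keeps ASR pipelines in a process-wide dict keyed by (model id, device, dtype name), and downloads each model snapshot once into a local directory (hf_models/ by default, hence the new .gitignore entry). A minimal sketch of the intended effect, assuming the keyword names visible in the diff (model_id, device_str, dtype); the repo id is a stand-in:

    # Sketch only, not part of the commit: signature inferred from the diff.
    from app import load_asr_pipeline

    asr, device, dtype_name = load_asr_pipeline(
        model_id="openai/whisper-small",  # illustrative repo id
        device_str="cpu",
        dtype=None,  # stored under the "auto" dtype key
    )
    asr_again, _, _ = load_asr_pipeline(
        model_id="openai/whisper-small",
        device_str="cpu",
        dtype=None,
    )
    # The second call is served from PIPELINE_CACHE: same object, no reload.
    assert asr is asr_again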
.gitignore CHANGED

@@ -4,3 +4,4 @@ __pycache__/
 *.py[cod]
 .DS_Store
 *.log
+hf_models/
app.py CHANGED

@@ -5,6 +5,7 @@ import time
 import logging
 import gc
 import io
+import threading
 from dataclasses import dataclass
 from typing import Optional, Tuple, List, Any, Dict
 from contextlib import contextmanager
@@ -18,6 +19,7 @@ from pydub import AudioSegment
 from pydub.silence import split_on_silence
 import soundfile as sf
 import noisereduce
+from huggingface_hub import snapshot_download
 
 load_dotenv()
 
@@ -25,6 +27,13 @@ load_dotenv()
 PREPROCESSING_AVAILABLE = True
 
 
+# Shared caches to keep models/pipelines in memory across requests
+PIPELINE_CACHE: Dict[Tuple[str, str, str], Tuple[Any, str, str]] = {}
+PIPELINE_CACHE_LOCK = threading.Lock()
+MODEL_PATH_CACHE: Dict[str, str] = {}
+MODEL_PATH_CACHE_LOCK = threading.Lock()
+
+
 def get_env_or_secret(key: str, default: Optional[str] = None) -> Optional[str]:
     """Get environment variable or default."""
     return os.environ.get(key, default)
@@ -51,6 +60,51 @@ class PreprocessingConfig:
     remove_silence: bool = True
 
 
+def ensure_local_model(model_id: str, hf_token: Optional[str] = None) -> str:
+    """Ensure a model snapshot is available locally and return its path."""
+
+    if os.path.isdir(model_id):
+        return model_id
+
+    with MODEL_PATH_CACHE_LOCK:
+        cached_path = MODEL_PATH_CACHE.get(model_id)
+        if cached_path and os.path.isdir(cached_path):
+            return cached_path
+
+    logger = logging.getLogger(__name__)
+
+    cache_root = get_env_or_secret("HF_MODEL_CACHE_DIR")
+    if not cache_root:
+        cache_root = os.path.join(os.path.dirname(__file__), "hf_models")
+
+    os.makedirs(cache_root, exist_ok=True)
+    local_dir = os.path.join(cache_root, model_id.replace("/", "__"))
+
+    try:
+        snapshot_download(
+            repo_id=model_id,
+            token=hf_token,
+            local_dir=local_dir,
+            local_dir_use_symlinks=False,
+            resume_download=True,
+        )
+    except Exception as download_error:
+        # If download fails but we already have weights, continue with local copy
+        if os.path.isdir(local_dir) and os.listdir(local_dir):
+            logger.warning(
+                "Unable to refresh model %s from hub (%s), using existing files",
+                model_id,
+                download_error,
+            )
+        else:
+            raise
+
+    with MODEL_PATH_CACHE_LOCK:
+        MODEL_PATH_CACHE[model_id] = local_dir
+
+    return local_dir
+
+
 def normalize_audio(audio_bytes: bytes) -> bytes:
     """
     Converts an audio chunk in bytes to the standard format for Whisper.
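
Before the load_asr_pipeline hunks below, a usage sketch for the ensure_local_model helper added above (the repo id is illustrative): the first call downloads a snapshot into the local cache directory, the second returns the memoized path without contacting the hub.

    import os
    from app import ensure_local_model

    local_path = ensure_local_model("openai/whisper-small")
    # Lands under <app dir>/hf_models/openai__whisper-small ("/" mapped to "__"),
    # or under HF_MODEL_CACHE_DIR when that variable is set.
    assert os.path.isdir(local_path)
    # A repeat call short-circuits via MODEL_PATH_CACHE: same path, no download.
    assert ensure_local_model("openai/whisper-small") == local_path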
@@ -248,59 +302,95 @@ def load_asr_pipeline(
     logger.info(f" Chunk length: {chunk_length_s}s")
     logger.info(f" Return timestamps: {return_timestamps}")
 
+    dtype_name = str(dtype).replace("torch.", "") if dtype is not None else "auto"
+    cache_key = (model_id, device_str, dtype_name)
+
+    with PIPELINE_CACHE_LOCK:
+        cached_pipeline = PIPELINE_CACHE.get(cache_key)
+        if cached_pipeline:
+            logger.info(
+                "Reusing cached pipeline for %s on %s (%s)",
+                model_id,
+                device_str,
+                dtype_name,
+            )
+            return cached_pipeline
+
+    model_source = ensure_local_model(model_id, hf_token=hf_token)
+    logger.info(f"Using local model files from: {model_source}")
+
+    device_argument: Any = 0 if device_str == "cuda" else device_str
+
+    pipeline_kwargs = {
+        "task": "automatic-speech-recognition",
+        "model": model_source,
+        "device": device_argument,
+    }
+    if dtype is not None:
+        pipeline_kwargs["torch_dtype"] = dtype
+
     # Use ultra-simplified approach to avoid all compatibility issues
     try:
         logger.info(
             "Setting up ultra-simplified pipeline to avoid forced_decoder_ids conflicts..."
         )
 
-
-        asr = pipeline(
-            task="automatic-speech-recognition",
-            model=model_id,
-            torch_dtype=dtype,
-            device=0
-            if device_str == "cuda"
-            else ("mps" if device_str == "mps" else "cpu"),
-            token=hf_token,
-        )
+        asr = pipeline(**pipeline_kwargs)
 
         # Post-loading cleanup to remove any forced_decoder_ids
-        if hasattr(asr.model, "generation_config")
-
-
-
+        if hasattr(asr.model, "generation_config") and hasattr(
+            asr.model.generation_config, "forced_decoder_ids"
+        ):
+            logger.info("Removing forced_decoder_ids from model generation config")
+            asr.model.generation_config.forced_decoder_ids = None
 
-        # Set basic parameters after loading
         if chunk_length_s:
             logger.info(f"Setting chunk_length_s to {chunk_length_s}")
 
+        final_device = device_str
+        final_dtype = dtype
+        final_dtype_name = dtype_name
+
        logger.info(f"Successfully created ultra-simplified pipeline for: {model_id}")
 
     except Exception as e:
         logger.error(f"Ultra-simplified pipeline creation failed: {e}")
         logger.info("Falling back to absolute minimal settings...")
 
-
-
-
-
-
-
-
-
-
-
+        fallback_device = "cpu"
+        fallback_dtype = torch.float32
+        fallback_dtype_name = str(fallback_dtype).replace("torch.", "")
+        fallback_key = (model_id, fallback_device, fallback_dtype_name)
+
+        with PIPELINE_CACHE_LOCK:
+            cached_pipeline = PIPELINE_CACHE.get(fallback_key)
+            if cached_pipeline:
+                logger.info(
+                    "Reusing cached fallback pipeline for %s (%s)",
+                    model_id,
+                    fallback_dtype_name,
                 )
+                return cached_pipeline
 
-
-
-
-
-
+        fallback_kwargs = {
+            "task": "automatic-speech-recognition",
+            "model": model_source,
+            "device": fallback_device,
+            "torch_dtype": fallback_dtype,
+        }
 
-
-
+        try:
+            asr = pipeline(**fallback_kwargs)
+
+            if hasattr(asr.model, "generation_config") and hasattr(
+                asr.model.generation_config, "forced_decoder_ids"
+            ):
+                logger.info("Removing forced_decoder_ids from fallback model")
+                asr.model.generation_config.forced_decoder_ids = None
+
+            final_device = fallback_device
+            final_dtype = fallback_dtype
+            final_dtype_name = fallback_dtype_name
             logger.info(
                 f"Minimal fallback pipeline created with dtype: {fallback_dtype}"
             )
@@ -309,7 +399,11 @@ def load_asr_pipeline(
            logger.error(f"Minimal fallback failed: {fallback_error}")
            raise
 
-
+    cache_key = (model_id, final_device, final_dtype_name)
+    with PIPELINE_CACHE_LOCK:
+        PIPELINE_CACHE[cache_key] = (asr, final_device, final_dtype_name)
+
+    return asr, final_device, final_dtype_name
 
 
 @contextmanager
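
Two properties of the new cache follow from the hunks above: the pipeline is constructed outside PIPELINE_CACHE_LOCK, so two concurrent first requests may each load the model once (the lock only keeps the dict itself consistent); and a pipeline built by the fallback path is cached under the ("cpu", "float32") key, so a repeated request retries the primary load first and only then picks up the cached fallback. A sketch of the key derivation as written in the diff, with an illustrative repo id:

    import torch

    def dtype_key(dtype):
        # Mirrors the diff's str(dtype).replace("torch.", ""); None means "auto".
        return "auto" if dtype is None else str(dtype).replace("torch.", "")

    assert dtype_key(torch.float16) == "float16"
    assert dtype_key(None) == "auto"
    # A successful fallback is cached as ("openai/whisper-small", "cpu", "float32").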