calettippo committed
Commit edf497c · 1 Parent(s): 65b0afc

Improve Whisper pipeline caching

Files changed (2)
  1. .gitignore +1 -0
  2. app.py +127 -33
.gitignore CHANGED
@@ -4,3 +4,4 @@ __pycache__/
 *.py[cod]
 .DS_Store
 *.log
+hf_models/
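
Note: hf_models/ is ignored because app.py (below) now stores downloaded model snapshots there whenever the HF_MODEL_CACHE_DIR environment variable is not set. A minimal sketch of how that default path is resolved, using a hypothetical model id that is not part of this commit:

import os

# Assumption for illustration only: any Hub model id would do.
model_id = "openai/whisper-large-v3"
cache_root = os.environ.get("HF_MODEL_CACHE_DIR") or os.path.join(
    os.path.dirname(__file__), "hf_models"
)
local_dir = os.path.join(cache_root, model_id.replace("/", "__"))
# e.g. <app dir>/hf_models/openai__whisper-large-v3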
app.py CHANGED
@@ -5,6 +5,7 @@ import time
 import logging
 import gc
 import io
+import threading
 from dataclasses import dataclass
 from typing import Optional, Tuple, List, Any, Dict
 from contextlib import contextmanager
@@ -18,6 +19,7 @@ from pydub import AudioSegment
 from pydub.silence import split_on_silence
 import soundfile as sf
 import noisereduce
+from huggingface_hub import snapshot_download

 load_dotenv()

@@ -25,6 +27,13 @@ load_dotenv()
 PREPROCESSING_AVAILABLE = True


+# Shared caches to keep models/pipelines in memory across requests
+PIPELINE_CACHE: Dict[Tuple[str, str, str], Tuple[Any, str, str]] = {}
+PIPELINE_CACHE_LOCK = threading.Lock()
+MODEL_PATH_CACHE: Dict[str, str] = {}
+MODEL_PATH_CACHE_LOCK = threading.Lock()
+
+
 def get_env_or_secret(key: str, default: Optional[str] = None) -> Optional[str]:
     """Get environment variable or default."""
     return os.environ.get(key, default)
@@ -51,6 +60,51 @@ class PreprocessingConfig:
     remove_silence: bool = True


+def ensure_local_model(model_id: str, hf_token: Optional[str] = None) -> str:
+    """Ensure a model snapshot is available locally and return its path."""
+
+    if os.path.isdir(model_id):
+        return model_id
+
+    with MODEL_PATH_CACHE_LOCK:
+        cached_path = MODEL_PATH_CACHE.get(model_id)
+        if cached_path and os.path.isdir(cached_path):
+            return cached_path
+
+    logger = logging.getLogger(__name__)
+
+    cache_root = get_env_or_secret("HF_MODEL_CACHE_DIR")
+    if not cache_root:
+        cache_root = os.path.join(os.path.dirname(__file__), "hf_models")
+
+    os.makedirs(cache_root, exist_ok=True)
+    local_dir = os.path.join(cache_root, model_id.replace("/", "__"))
+
+    try:
+        snapshot_download(
+            repo_id=model_id,
+            token=hf_token,
+            local_dir=local_dir,
+            local_dir_use_symlinks=False,
+            resume_download=True,
+        )
+    except Exception as download_error:
+        # If download fails but we already have weights, continue with local copy
+        if os.path.isdir(local_dir) and os.listdir(local_dir):
+            logger.warning(
+                "Unable to refresh model %s from hub (%s), using existing files",
+                model_id,
+                download_error,
+            )
+        else:
+            raise
+
+    with MODEL_PATH_CACHE_LOCK:
+        MODEL_PATH_CACHE[model_id] = local_dir
+
+    return local_dir
+
+
 def normalize_audio(audio_bytes: bytes) -> bytes:
     """
     Convert an audio chunk in bytes to the standard format for Whisper.
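
The ensure_local_model helper added above is what makes repeated loads cheap: the first call downloads (or resumes) a snapshot into the shared cache directory, later calls return the cached path from MODEL_PATH_CACHE without touching the network, and a failed refresh falls back to whatever files already exist on disk. A minimal usage sketch, assuming app.py's names are importable as a module; the model id is only an example:

from app import MODEL_PATH_CACHE, ensure_local_model

model_id = "openai/whisper-large-v3"  # assumption: any Hub ASR model id

first_path = ensure_local_model(model_id)   # downloads or resumes the snapshot
second_path = ensure_local_model(model_id)  # served from MODEL_PATH_CACHE, no network call

assert first_path == second_path
assert MODEL_PATH_CACHE[model_id] == first_path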
 
@@ -248,59 +302,95 @@ def load_asr_pipeline(
     logger.info(f" Chunk length: {chunk_length_s}s")
     logger.info(f" Return timestamps: {return_timestamps}")

+    dtype_name = str(dtype).replace("torch.", "") if dtype is not None else "auto"
+    cache_key = (model_id, device_str, dtype_name)
+
+    with PIPELINE_CACHE_LOCK:
+        cached_pipeline = PIPELINE_CACHE.get(cache_key)
+        if cached_pipeline:
+            logger.info(
+                "Reusing cached pipeline for %s on %s (%s)",
+                model_id,
+                device_str,
+                dtype_name,
+            )
+            return cached_pipeline
+
+    model_source = ensure_local_model(model_id, hf_token=hf_token)
+    logger.info(f"Using local model files from: {model_source}")
+
+    device_argument: Any = 0 if device_str == "cuda" else device_str
+
+    pipeline_kwargs = {
+        "task": "automatic-speech-recognition",
+        "model": model_source,
+        "device": device_argument,
+    }
+    if dtype is not None:
+        pipeline_kwargs["torch_dtype"] = dtype
+
     # Use ultra-simplified approach to avoid all compatibility issues
     try:
         logger.info(
             "Setting up ultra-simplified pipeline to avoid forced_decoder_ids conflicts..."
         )

-        # Create pipeline with absolute minimal configuration
-        asr = pipeline(
-            task="automatic-speech-recognition",
-            model=model_id,
-            torch_dtype=dtype,
-            device=0
-            if device_str == "cuda"
-            else ("mps" if device_str == "mps" else "cpu"),
-            token=hf_token,
-        )
+        asr = pipeline(**pipeline_kwargs)

         # Post-loading cleanup to remove any forced_decoder_ids
-        if hasattr(asr.model, "generation_config"):
-            if hasattr(asr.model.generation_config, "forced_decoder_ids"):
-                logger.info("Removing forced_decoder_ids from model generation config")
-                asr.model.generation_config.forced_decoder_ids = None
+        if hasattr(asr.model, "generation_config") and hasattr(
+            asr.model.generation_config, "forced_decoder_ids"
+        ):
+            logger.info("Removing forced_decoder_ids from model generation config")
+            asr.model.generation_config.forced_decoder_ids = None

-        # Set basic parameters after loading
         if chunk_length_s:
             logger.info(f"Setting chunk_length_s to {chunk_length_s}")

+        final_device = device_str
+        final_dtype = dtype
+        final_dtype_name = dtype_name
+
         logger.info(f"Successfully created ultra-simplified pipeline for: {model_id}")

     except Exception as e:
         logger.error(f"Ultra-simplified pipeline creation failed: {e}")
         logger.info("Falling back to absolute minimal settings...")

-        try:
-            # Fallback with absolute minimal settings
-            fallback_dtype = torch.float32
-
-            asr = pipeline(
-                task="automatic-speech-recognition",
-                model=model_id,
-                torch_dtype=fallback_dtype,
-                device="cpu",  # Force CPU for maximum compatibility
-                token=hf_token,
+        fallback_device = "cpu"
+        fallback_dtype = torch.float32
+        fallback_dtype_name = str(fallback_dtype).replace("torch.", "")
+        fallback_key = (model_id, fallback_device, fallback_dtype_name)
+
+        with PIPELINE_CACHE_LOCK:
+            cached_pipeline = PIPELINE_CACHE.get(fallback_key)
+            if cached_pipeline:
+                logger.info(
+                    "Reusing cached fallback pipeline for %s (%s)",
+                    model_id,
+                    fallback_dtype_name,
                )
+                return cached_pipeline

-            # Post-loading cleanup
-            if hasattr(asr.model, "generation_config"):
-                if hasattr(asr.model.generation_config, "forced_decoder_ids"):
-                    logger.info("Removing forced_decoder_ids from fallback model")
-                    asr.model.generation_config.forced_decoder_ids = None
+        fallback_kwargs = {
+            "task": "automatic-speech-recognition",
+            "model": model_source,
+            "device": fallback_device,
+            "torch_dtype": fallback_dtype,
+        }

-            device_str = "cpu"
-            dtype = fallback_dtype
+        try:
+            asr = pipeline(**fallback_kwargs)
+
+            if hasattr(asr.model, "generation_config") and hasattr(
+                asr.model.generation_config, "forced_decoder_ids"
+            ):
+                logger.info("Removing forced_decoder_ids from fallback model")
+                asr.model.generation_config.forced_decoder_ids = None
+
+            final_device = fallback_device
+            final_dtype = fallback_dtype
+            final_dtype_name = fallback_dtype_name
             logger.info(
                 f"Minimal fallback pipeline created with dtype: {fallback_dtype}"
             )
@@ -309,7 +399,11 @@ def load_asr_pipeline(
         logger.error(f"Minimal fallback failed: {fallback_error}")
         raise

-    return asr, device_str, str(dtype).replace("torch.", "")
+    cache_key = (model_id, final_device, final_dtype_name)
+    with PIPELINE_CACHE_LOCK:
+        PIPELINE_CACHE[cache_key] = (asr, final_device, final_dtype_name)
+
+    return asr, final_device, final_dtype_name


 @contextmanager
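
With PIPELINE_CACHE keyed on (model_id, device, dtype), a second request for the same configuration reuses the already-constructed transformers pipeline instead of reloading weights. A rough sketch of the expected behaviour; load_asr_pipeline's full signature is not shown in this diff, so the call below assumes it accepts the model id as its first argument:

from app import PIPELINE_CACHE, load_asr_pipeline

asr_first, device, dtype_name = load_asr_pipeline("openai/whisper-small")
asr_second, _, _ = load_asr_pipeline("openai/whisper-small")

# Same cache key, so the exact same pipeline object comes back the second time.
assert asr_first is asr_second
assert ("openai/whisper-small", device, dtype_name) in PIPELINE_CACHE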