# backend/data_manager.py
import tempfile
from typing import List, Optional

import numpy as np
from datasets import load_dataset

from .config import AUDIO_DATASET_ID
from .models import Clip

# soundfile is optional; if it is missing, audio clips are skipped with a warning.
try:
    import soundfile as sf
except ImportError:
    sf = None
class DataManager:
    """Handles loading and processing data from Hugging Face."""

    def __init__(self, dataset_id: str = AUDIO_DATASET_ID):
        self.dataset_id = dataset_id
        self._clips: Optional[List[Clip]] = None
        self._loading = False

    def _get_audio_data(self, audio_val) -> Optional[str]:
        """
        Handle decoded audio from a Hugging Face dataset row (LFS-backed files).
        Returns the path of a temporary WAV file that Gradio can play, or None
        if the audio could not be processed.
        """
        try:
            array = None
            sr = None
            # Decoded audio columns usually arrive as a dict with
            # "array" and "sampling_rate" keys.
            if isinstance(audio_val, dict):
                array = audio_val.get("array")
                sr = audio_val.get("sampling_rate")
            if array is None or sr is None:
                # Fall back to item access, then attribute access, for
                # audio objects that are not plain dicts.
                try:
                    array = audio_val["array"]
                    sr = audio_val["sampling_rate"]
                except Exception:
                    array = getattr(audio_val, "array", None)
                    sr = getattr(audio_val, "sampling_rate", None)
            if array is not None and sr is not None and sf is not None:
                # Write the samples to a temporary WAV file that Gradio can serve.
                # delete=False: the file must outlive this call, so cleanup is
                # left to the OS temp directory / the caller.
                tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
                tmp.close()
                sf.write(tmp.name, np.array(array), int(sr))
                return tmp.name
        except Exception as e:
            print(f"[WARN] Failed to process audio data: {e}")
        print("[WARN] Could not process audio data for this example")
        return None
    def load_clips(self) -> List[Clip]:
        if self._clips is not None:
            return self._clips
        if self._loading:
            print("Dataset loading already in progress...")
            return []
        self._loading = True
        try:
            print(f"Loading dataset {self.dataset_id}...")
            dataset = load_dataset(self.dataset_id, split="train")
            clips: List[Clip] = []
            for row in dataset:
                audio_val = row.get("audio")
                audio_data = self._get_audio_data(audio_val)
                if audio_data is None:
                    print(f"[WARN] Skipping clip {row.get('exercise_id')} – could not process audio data")
                    continue
                clip = Clip(
                    id=f"{row['model']}_{row['speaker']}_{row['exercise_id']}",
                    model=row["model"],
                    speaker=row["speaker"],
                    exercise=row["exercise"],
                    exercise_id=row["exercise_id"],
                    transcript=row["rt"],
                    audio_url=audio_data,  # temp WAV path for the Gradio Audio component
                )
                clips.append(clip)
            self._clips = clips
            print(f"Loaded {len(clips)} clips")
            return clips
        finally:
            # Reset the flag even if loading fails, so the load can be retried.
            self._loading = False
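
# Minimal, illustrative usage sketch. It assumes the backend/ package layout
# implied by the imports above (config.AUDIO_DATASET_ID, models.Clip) and a
# dataset whose rows provide the "audio", "model", "speaker", "exercise",
# "exercise_id" and "rt" columns used in load_clips(). Because of the relative
# imports, run it as a module:
#   python -m backend.data_manager
if __name__ == "__main__":
    manager = DataManager()
    clips = manager.load_clips()
    for clip in clips[:3]:
        print(clip.id, clip.audio_url)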