atc-tts-mos / backend /data_manager.py
aether-raider
removed back button
02c7c7b
# backend/data_manager.py
import os
import base64
import io
from typing import List, Optional, Any, Dict
import numpy as np
from datasets import load_dataset
from .config import AUDIO_DATASET_ID
from .models import Clip
try:
import soundfile as sf
except ImportError:
sf = None
class DataManager:
    """Load audio clips from a Hugging Face dataset and cache them in memory.

    The dataset is fetched once, on the first call to :meth:`load_clips`;
    later calls return the cached list.
    """

    def __init__(self, dataset_id: str = AUDIO_DATASET_ID):
        """Remember which dataset to load; nothing is downloaded here."""
        self.dataset_id = dataset_id
        # Cache of converted clips; None means "not loaded yet".
        self._clips: Optional[List[Clip]] = None
        # Guard flag so a call made while a load is running returns early.
        self._loading = False

    def _get_audio_data(self, audio_val) -> Optional[str]:
        """Write one decoded audio value to a temporary WAV file.

        ``audio_val`` is expected to expose ``array`` and ``sampling_rate``
        either as dict keys, via ``__getitem__``, or as attributes; all three
        access styles are tried in that order.

        Returns:
            Path of the ``.wav`` file that was written, or ``None`` when the
            samples / sampling rate could not be read, when ``soundfile`` is
            not installed, or when writing failed. The file is created with
            ``delete=False`` and is not removed by this class on success.
        """
        try:
            array = None
            sr = None
            if isinstance(audio_val, dict):
                array = audio_val.get("array")
                sr = audio_val.get("sampling_rate")
            if array is None or sr is None:
                # Not a plain dict (or keys missing): try subscript access,
                # then fall back to attribute access.
                try:
                    array = audio_val["array"]
                    sr = audio_val["sampling_rate"]
                except Exception:
                    array = getattr(audio_val, "array", None)
                    sr = getattr(audio_val, "sampling_rate", None)

            if array is not None and sr is not None and sf is not None:
                import tempfile

                # Reserve a unique path and close the handle *before* writing:
                # reopening a still-open NamedTemporaryFile by name fails on
                # Windows.
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                    wav_path = tmp_file.name
                try:
                    sf.write(wav_path, np.asarray(array), int(sr))
                except Exception:
                    # delete=False means nobody else will clean this up;
                    # do not leave an empty/partial file behind.
                    try:
                        os.remove(wav_path)
                    except OSError:
                        pass
                    raise
                return wav_path
        except Exception as e:
            print(f"[WARN] Failed to process audio data: {e}")

        print("[WARN] Could not process audio data for this example")
        return None

    def load_clips(self) -> List[Clip]:
        """Return every usable clip of the dataset's ``train`` split.

        Rows whose audio cannot be converted are skipped with a warning.
        The result is cached, so the dataset is only loaded once.

        Returns:
            The list of clips, or an empty list if another load is already
            in progress on this instance.

        Raises:
            Whatever ``load_dataset`` or row access raises; the in-progress
            flag is cleared first so a later call can retry.
        """
        if self._clips is not None:
            return self._clips
        if self._loading:
            print("Dataset loading already in progress...")
            return []

        self._loading = True
        try:
            print(f"Loading dataset {self.dataset_id}...")
            dataset = load_dataset(self.dataset_id, split="train")
            clips: List[Clip] = []
            for row in dataset:
                audio_val = row.get("audio")
                audio_data = self._get_audio_data(audio_val)
                if audio_data is None:
                    print(f"[WARN] Skipping clip {row.get('exercise_id')} – could not process audio data")
                    continue
                clip = Clip(
                    id=f"{row['model']}_{row['speaker']}_{row['exercise_id']}",
                    model=row["model"],
                    speaker=row["speaker"],
                    exercise=row["exercise"],
                    exercise_id=row["exercise_id"],
                    transcript=row["rt"],
                    audio_url=audio_data,  # path of the temporary WAV file
                )
                clips.append(clip)
            self._clips = clips
        finally:
            # Always clear the flag; otherwise one failed load would make
            # every later call return [] forever.
            self._loading = False

        print(f"Loaded {len(clips)} clips")
        return clips