| | |
"""
SAM Audio ONNX Runtime Inference Example

This script demonstrates how to use the exported ONNX models for audio source
separation inference. It shows the complete pipeline from text input to
separated audio output.

Usage:
    python onnx_inference.py --audio input.wav --text "a person speaking"
"""
| |
|
import argparse
import json
import os
from typing import Optional

import numpy as np
| |
|
| |
|
def load_audio(path: str, target_sr: int = 48000) -> np.ndarray:
    """Load an audio file as a mono float32 waveform at ``target_sr``.

    torchaudio is tried first; if it is missing or fails for any reason,
    librosa is used as a fallback. (The original note says video containers
    are supported through these backends; that depends on how the installed
    backends were built.)

    Args:
        path: Path to the media file to read.
        target_sr: Sample rate of the returned waveform, in Hz.

    Returns:
        1-D ``float32`` array of samples.

    Raises:
        ImportError: If the torchaudio path failed and librosa is not importable.
        RuntimeError: If both backends failed to read the file.
    """
    try:
        import torchaudio

        wav, sr = torchaudio.load(path)
        # Down-mix multi-channel audio to a single channel.
        if wav.shape[0] > 1:
            wav = wav.mean(0, keepdim=True)
        if sr != target_sr:
            wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
        # reshape(-1) rather than squeeze(): a one-sample clip must stay 1-D.
        return wav.reshape(-1).numpy().astype(np.float32)
    except Exception as torchaudio_error:
        # Deliberately broad: any torchaudio problem (not installed, unsupported
        # container, decode error) should fall through to librosa. The error is
        # kept so it can be reported if the fallback fails as well.
        try:
            import librosa

            audio, _ = librosa.load(path, sr=target_sr, mono=True)
            return audio.astype(np.float32)
        except ImportError as import_error:
            raise ImportError(
                "Please install torchaudio or librosa: pip install torchaudio librosa"
            ) from import_error
        except Exception as librosa_error:
            raise RuntimeError(
                f"Failed to load audio from {path}: {librosa_error} "
                f"(torchaudio also failed: {torchaudio_error})"
            ) from librosa_error
| |
|
| |
|
def save_audio(audio: np.ndarray, path: str, sample_rate: int = 48000):
    """Write ``audio`` to ``path`` with soundfile and report where it went.

    Arrays with more than one dimension are flattened to 1-D before writing.

    Args:
        audio: Waveform samples.
        path: Destination file path.
        sample_rate: Sample rate to record in the file, in Hz.

    Raises:
        ImportError: If soundfile is not installed.
    """
    try:
        import soundfile as sf

        samples = audio.flatten() if audio.ndim > 1 else audio
        sf.write(path, samples, sample_rate)
        print(f"Saved audio to {path}")
    except ImportError:
        raise ImportError("Please install soundfile: pip install soundfile")
| |
|
| |
|
def save_video_with_audio(frames: np.ndarray, audio: np.ndarray, path: str, sample_rate: int = 48000, fps: float = 24.0):
    """Save masked video frames and separated audio to a movie file.

    This is best-effort: any failure (missing torch/torchvision, codec
    problems, bad shapes) is reported as a warning on stdout and the function
    returns normally.

    Args:
        frames: Frames normalized as ``(pixel / 255 - 0.5) / 0.5``; the axis
            permutation below treats them as (N, C, H, W).
        audio: Waveform; a 1-D array gets a leading channel axis added.
        path: Output movie path.
        sample_rate: Audio sample rate passed to the writer, in Hz.
        fps: Video frame rate passed to the writer.
    """
    try:
        import torch
        import torchvision

        # Undo the [-1, 1] normalization. Clip before the integer cast: the
        # frames were resized with bicubic interpolation upstream, which can
        # overshoot the nominal range, and out-of-range floats would wrap
        # around when converted to uint8.
        pixels = np.clip((frames * 0.5 + 0.5) * 255, 0, 255)
        frames_uint8 = pixels.astype(np.uint8)

        # (N, C, H, W) -> (N, H, W, C) for the video writer.
        video_tensor = torch.from_numpy(frames_uint8).permute(0, 2, 3, 1)

        if audio.ndim == 1:
            audio = audio[None, :]
        audio_tensor = torch.from_numpy(audio)

        print(f"Saving merged video to {path}...")
        torchvision.io.write_video(
            path,
            video_tensor,
            fps=fps,
            video_codec="libx264",
            audio_array=audio_tensor,
            audio_fps=sample_rate,
            audio_codec="aac",
        )
        print(f" ✓ Video saved to {path}")
    except Exception as e:
        # Top-level best-effort boundary: the audio outputs are already on
        # disk, so a video failure should not abort the run.
        print(f"Warning: Failed to save video: {e}")
| |
|
| |
|
class SAMAudioONNXPipeline:
    """
    ONNX-based SAM Audio inference pipeline.

    This class orchestrates all the ONNX models to perform audio source
    separation: a DACVAE encoder/decoder pair, a T5 text encoder, a single-step
    DiT denoiser and, when ``vision_encoder.onnx`` exists in the model
    directory, a vision encoder.
    """

    # Audio sample rate (Hz) used to map latent steps to video timestamps.
    SAMPLE_RATE = 48000
    # Waveform samples per latent step. decode_audio trims decoder output to
    # this many samples per step, so video alignment must use the same value
    # (25 steps * 1920 samples = 1 second at 48 kHz).
    HOP_LENGTH = 1920
    # The DACVAE decoder was exported with this fixed number of time steps.
    DECODER_CHUNK_STEPS = 25
    # Frames are resized to IMAGE_SIZE x IMAGE_SIZE for the vision encoder.
    IMAGE_SIZE = 336
    # Channel count of the all-zero placeholder used when no video is given.
    VISION_DIM = 1024

    def __init__(
        self,
        model_dir: str = "onnx_models",
        device: str = "cpu",
        num_ode_steps: int = 16,
    ):
        """Load every ONNX session and the tokenizer from ``model_dir``.

        Args:
            model_dir: Directory containing the exported ``.onnx`` files and
                the SentencePiece model.
            device: ``"cuda"`` requests the CUDA execution provider with CPU
                as fallback; any other value uses the CPU provider only.
            num_ode_steps: Number of ODE solver steps used by ``separate``.

        Raises:
            ValueError: If ``num_ode_steps`` is not positive.
        """
        # Validate before the (slow) model loading. Zero would divide by zero
        # below, and a negative count would skip the solver entirely.
        if num_ode_steps < 1:
            raise ValueError(f"num_ode_steps must be a positive integer, got {num_ode_steps}")

        import onnxruntime as ort

        self.model_dir = model_dir
        self.num_ode_steps = num_ode_steps
        self.step_size = 1.0 / num_ode_steps

        if device == "cuda":
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]

        def load_session(filename):
            return ort.InferenceSession(
                os.path.join(model_dir, filename),
                providers=providers,
            )

        print("Loading ONNX models...")

        self.dacvae_encoder = load_session("dacvae_encoder.onnx")
        print(" ✓ DACVAE encoder loaded")

        self.dacvae_decoder = load_session("dacvae_decoder.onnx")
        print(" ✓ DACVAE decoder loaded")

        self.t5_encoder = load_session("t5_encoder.onnx")
        print(" ✓ T5 encoder loaded")

        self.dit = load_session("dit_single_step.onnx")
        print(" ✓ DiT denoiser loaded")

        # The vision encoder is optional; it is used only if the file exists.
        self.vision_encoder = None
        if os.path.exists(os.path.join(model_dir, "vision_encoder.onnx")):
            self.vision_encoder = load_session("vision_encoder.onnx")
            print(" ✓ Vision encoder loaded")

        self._load_tokenizer()
        print(" ✓ Tokenizer loaded")

        print("All models loaded!")

    def _load_tokenizer(self):
        """
        Load the T5 tokenizer using SentencePiece.
        This avoids the dependency on the 'transformers' library.

        ``spiece.model`` is looked up in ``<model_dir>/tokenizer/`` first and
        then directly in ``<model_dir>``.

        Raises:
            ImportError: If sentencepiece is not installed.
            FileNotFoundError: If no SentencePiece model file is found.
        """
        try:
            import sentencepiece as spm
        except ImportError as exc:
            raise ImportError("Please install sentencepiece: pip install sentencepiece") from exc

        sp_path = os.path.join(self.model_dir, "tokenizer", "spiece.model")
        if not os.path.exists(sp_path):
            sp_path = os.path.join(self.model_dir, "spiece.model")

        if not os.path.exists(sp_path):
            raise FileNotFoundError(f"SentencePiece model not found at {sp_path}")

        class T5ONNXTokenizer:
            """Minimal tokenizer wrapper around a SentencePiece processor."""

            def __init__(self, sp_path):
                self.sp = spm.SentencePieceProcessor()
                self.sp.load(sp_path)

            def encode(self, text: str) -> np.ndarray:
                """Return token ids as int64 with shape (1, n), ending in id 1."""
                ids = self.sp.encode(text)
                # Id 1 is appended as a terminator unless already present.
                # NOTE(review): presumably the T5 end-of-sequence id - confirm.
                if len(ids) > 0 and ids[-1] != 1:
                    ids.append(1)
                elif len(ids) == 0:
                    ids = [1]
                return np.array(ids, dtype=np.int64).reshape(1, -1)

            def decode(self, tokens: np.ndarray) -> str:
                """Turn an array of token ids back into text."""
                if tokens.ndim > 1:
                    tokens = tokens.flatten()
                return self.sp.decode(tokens.tolist())

        self.tokenizer = T5ONNXTokenizer(sp_path)

    @staticmethod
    def _nearest_frame_indices(frame_times: np.ndarray, step_times: np.ndarray) -> np.ndarray:
        """For each step time, return the index of the closest frame time."""
        gaps = np.abs(frame_times[:, None] - step_times[None, :])
        return np.argmin(gaps, axis=0)

    def load_video_frames(self, path: str, num_steps: int, mask_path: Optional[str] = None) -> tuple[np.ndarray, np.ndarray, float]:
        """
        Load video frames and align them to audio latent steps.
        Optionally applies a binary mask for visual prompting.

        For every latent step the frame whose presentation timestamp is
        closest to ``step * HOP_LENGTH / SAMPLE_RATE`` seconds is selected.

        Args:
            path: Path to the video file.
            num_steps: Number of audio latent steps to align frames to.
            mask_path: Optional mask video. Pixels where the mask's channel
                mean exceeds 128 are zeroed in the selected frames.

        Returns:
            (normalized_frames, visual_frames, fps). Both frame arrays hold
            the same resized frames scaled as ``(pixel / 255 - 0.5) / 0.5``;
            ``fps`` is the video's average frame rate, or 24.0 if unknown.

        Raises:
            ImportError: If torchcodec or torch is not installed.
        """
        try:
            from torchcodec.decoders import VideoDecoder
            import torch.nn.functional as F
        except ImportError as exc:
            raise ImportError("Please install torchcodec and torch: pip install torchcodec torch") from exc

        decoder = VideoDecoder(path, dimension_order="NCHW")
        all_data = decoder.get_frames_in_range(0, len(decoder))

        # Timestamp (seconds) of each latent step. The hop must match the one
        # decode_audio assumes, otherwise the frames drift against the audio.
        step_timestamps = np.arange(num_steps) * self.HOP_LENGTH / self.SAMPLE_RATE

        metadata = decoder.metadata
        fps = metadata.average_fps if metadata.average_fps is not None else 24.0

        frame_idxs = self._nearest_frame_indices(all_data.pts_seconds.numpy(), step_timestamps)
        frames = all_data.data[frame_idxs]

        if mask_path:
            print(f" Applying mask from {mask_path}...")
            mask_decoder = VideoDecoder(mask_path, dimension_order="NCHW")
            mask_data = mask_decoder.get_frames_in_range(0, len(mask_decoder))

            # The mask is aligned to the same step timestamps as the video.
            m_frame_idxs = self._nearest_frame_indices(mask_data.pts_seconds.numpy(), step_timestamps)
            masks = mask_data.data[m_frame_idxs]

            # Collapse channels and threshold to a 0/1 mask, then blank out
            # the masked pixels.
            # NOTE(review): assumes the mask has the video's resolution.
            binary_mask = (masks.float().mean(dim=1, keepdim=True) > 128).float()
            frames = frames.float() * (1.0 - binary_mask)

        frames_resized = F.interpolate(
            frames.float(),
            size=(self.IMAGE_SIZE, self.IMAGE_SIZE),
            mode="bicubic",
        )
        frames_norm = (frames_resized / 255.0 - 0.5) / 0.5

        frames_np = frames_norm.numpy()
        return frames_np, frames_np, fps

    def encode_video(self, frames: np.ndarray) -> np.ndarray:
        """Run the vision encoder on frames, one frame per call.

        Args:
            frames: Frames indexed along the first axis.

        Returns:
            Features stacked over frames, transposed, with a leading batch
            axis added. NOTE(review): the transpose assumes the encoder
            returns a (1, dim) array per frame, giving (1, dim, frames).

        Raises:
            RuntimeError: If no vision encoder was loaded.
            ValueError: If ``frames`` is empty.
        """
        if self.vision_encoder is None:
            raise RuntimeError("Vision encoder model not loaded")
        if len(frames) == 0:
            raise ValueError("No video frames to encode")

        # Slicing with i:i + 1 keeps the leading axis, so the encoder always
        # receives a batch of exactly one frame.
        per_frame = [
            self.vision_encoder.run(
                ["vision_features"],
                {"video_frames": frames[i:i + 1]},
            )[0]
            for i in range(len(frames))
        ]
        features = np.concatenate(per_frame, axis=0)

        return features.transpose(1, 0)[None, :, :]

    def encode_audio(self, audio: np.ndarray) -> np.ndarray:
        """
        Encode audio waveform to latent features.

        Args:
            audio: Audio waveform, shape (samples,) or (1, 1, samples)

        Returns:
            Latent features, shape (1, latent_dim, time_steps)
        """
        # Bring 1-D and 2-D input up to three dimensions.
        if audio.ndim == 1:
            audio = audio.reshape(1, 1, -1)
        elif audio.ndim == 2:
            audio = audio.reshape(1, *audio.shape)

        outputs = self.dacvae_encoder.run(
            ["latent_features"],
            {"audio": audio.astype(np.float32)},
        )
        return outputs[0]

    def decode_audio(self, latent: np.ndarray) -> np.ndarray:
        """
        Decode latent features to audio waveform.

        Uses chunked decoding since the DACVAE decoder was exported with
        fixed 25 time steps. Processes in chunks and concatenates. A short
        final chunk is zero-padded for the decoder and its output is trimmed
        back to the real number of steps.

        Args:
            latent: Latent features, shape (1, latent_dim, time_steps)

        Returns:
            Audio waveform, shape (samples,); empty if ``time_steps`` is 0.
        """
        chunk_size = self.DECODER_CHUNK_STEPS
        hop_length = self.HOP_LENGTH

        _, _, time_steps = latent.shape
        if time_steps == 0:
            # Nothing to decode; np.concatenate would fail on an empty list.
            return np.zeros(0, dtype=np.float32)

        audio_chunks = []
        for start_idx in range(0, time_steps, chunk_size):
            chunk = latent[:, :, start_idx:start_idx + chunk_size]

            actual_size = chunk.shape[2]
            is_partial = actual_size < chunk_size
            if is_partial:
                pad_size = chunk_size - actual_size
                chunk = np.pad(chunk, ((0, 0), (0, 0), (0, pad_size)), mode="constant")

            chunk_audio = self.dacvae_decoder.run(
                ["waveform"],
                {"latent_features": chunk.astype(np.float32)},
            )[0]

            # Drop the samples that correspond to the zero padding.
            if is_partial:
                chunk_audio = chunk_audio[:, :, :actual_size * hop_length]

            audio_chunks.append(chunk_audio)

        full_audio = np.concatenate(audio_chunks, axis=2)
        return full_audio.squeeze()

    def encode_text(self, text: str) -> tuple[np.ndarray, np.ndarray]:
        """
        Encode text prompt to features.

        Args:
            text: Text description of the audio to separate

        Returns:
            Tuple of (hidden_states, attention_mask); the mask is all ones
            with the same shape as the token ids.
        """
        input_ids = self.tokenizer.encode(text)
        attention_mask = np.ones_like(input_ids)

        outputs = self.t5_encoder.run(
            ["hidden_states"],
            {
                "input_ids": input_ids.astype(np.int64),
                "attention_mask": attention_mask.astype(np.int64),
            },
        )

        return outputs[0], attention_mask

    def dit_step(
        self,
        noisy_audio: np.ndarray,
        time: float,
        audio_features: np.ndarray,
        text_features: np.ndarray,
        text_mask: np.ndarray,
        masked_video_features: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """Run a single DiT denoiser step.

        Args:
            noisy_audio: Current solver state; axis 0 is the batch and axis 1
                the sequence length.
            time: Solver time passed to the model.
            audio_features: Conditioning features from the input mixture.
            text_features: Text encoder hidden states.
            text_mask: Attention mask for ``text_features``.
            masked_video_features: Optional vision features; an all-zero
                array of shape (batch, VISION_DIM, seq_len) is used if None.

        Returns:
            The first output of the DiT session.
        """
        batch_size = noisy_audio.shape[0]
        seq_len = noisy_audio.shape[1]

        # Match the floating-point precision the model was exported with.
        first_input = self.dit.get_inputs()[0]
        use_fp16 = first_input.type == "tensor(float16)"
        float_dtype = np.float16 if use_fp16 else np.float32

        # Fixed anchor inputs: ids are [0, 3] per batch item and the
        # alignment is all zeros.
        # NOTE(review): the meaning of anchor id 3 is not visible here.
        anchor_ids = np.zeros((batch_size, 2), dtype=np.int64)
        anchor_ids[:, 1] = 3

        anchor_alignment = np.zeros((batch_size, seq_len), dtype=np.int64)

        # Every position is marked True in the pad mask.
        audio_pad_mask = np.ones((batch_size, seq_len), dtype=np.bool_)

        if masked_video_features is None:
            masked_video_features = np.zeros(
                (batch_size, self.VISION_DIM, seq_len), dtype=float_dtype
            )

        inputs = {
            "noisy_audio": noisy_audio.astype(float_dtype),
            "time": np.array([time], dtype=float_dtype),
            "audio_features": audio_features.astype(float_dtype),
            "text_features": text_features.astype(float_dtype),
            "text_mask": text_mask.astype(np.bool_),
            "masked_video_features": masked_video_features.astype(float_dtype),
            "anchor_ids": anchor_ids.astype(np.int64),
            "anchor_alignment": anchor_alignment.astype(np.int64),
            "audio_pad_mask": audio_pad_mask.astype(np.bool_),
        }

        outputs = self.dit.run(None, inputs)
        return outputs[0]

    def separate(
        self,
        audio: np.ndarray,
        text: str,
        video_path: Optional[str] = None,
        mask_path: Optional[str] = None,
    ) -> tuple[np.ndarray, np.ndarray, Optional[np.ndarray], float]:
        """
        Perform the full separation pipeline.

        Args:
            audio: Input mixture waveform
            text: Text description of the target source
            video_path: Optional path to a video for visual conditioning
            mask_path: Optional path to a video/image mask for visual prompting

        Returns:
            Tuple of (target audio, residual audio, masked video frames if any, fps)
            - target: The separated sound matching the text/visual prompt
            - residual: Everything else in the audio (the remainder)
            - fps: The video's frame rate, or 24.0 when no video was used
        """
        print("1. Encoding audio...")
        latent_features = self.encode_audio(audio)
        # (batch, latent_dim, steps) -> (batch, steps, latent_dim)
        latent_features = latent_features.transpose(0, 2, 1)
        latent_dim = latent_features.shape[2]

        # The mixture latent is duplicated along the channel axis, so the
        # solver state has 2 * latent_dim channels.
        audio_features = np.concatenate([latent_features, latent_features], axis=2)
        print(f" Audio latent shape: {latent_features.shape}")

        print("2. Encoding text...")
        text_features, text_mask = self.encode_text(text)
        print(f" Text features shape: {text_features.shape}")

        masked_video_features = None
        visual_frames = None
        fps = 24.0
        if video_path and self.vision_encoder is not None:
            print("3a. Loading and encoding video...")
            norm_frames, visual_frames, fps = self.load_video_frames(
                video_path, latent_features.shape[1], mask_path
            )
            masked_video_features = self.encode_video(norm_frames)
            print(f" Video features shape: {masked_video_features.shape}")
        elif video_path:
            # Make the fallback visible instead of silently dropping the video.
            print("Warning: vision encoder not loaded; ignoring video conditioning")

        print("3. Running ODE solver...")

        # Start from Gaussian noise with the same shape as the conditioning.
        B, T, C = audio_features.shape
        x = np.random.randn(B, T, C).astype(np.float32)

        steps = self.num_ode_steps
        dt = 1.0 / steps

        # Midpoint method: evaluate at t, take a half step, evaluate again at
        # the midpoint and use that second estimate for the full step.
        for i in range(steps):
            t = i * dt
            print(f" ODE step {i+1}/{steps}", end="\r")

            k1 = self.dit_step(x, t, audio_features, text_features, text_mask, masked_video_features)
            x_mid = x + k1 * (dt / 2.0)
            k2 = self.dit_step(x_mid, t + dt / 2.0, audio_features, text_features, text_mask, masked_video_features)

            x = x + k2 * dt

        # First latent_dim channels are the target, the rest the residual.
        # The split follows the encoder's latent width rather than a fixed
        # channel count.
        target_latent = x[:, :, :latent_dim].transpose(0, 2, 1)
        residual_latent = x[:, :, latent_dim:].transpose(0, 2, 1)
        print(f"\n Target latent shape: {target_latent.shape}")
        print(f" Residual latent shape: {residual_latent.shape}")

        print("4. Decoding target audio...")
        target_audio = self.decode_audio(target_latent)
        print(f" Target audio shape: {target_audio.shape}")

        print("5. Decoding residual audio...")
        residual_audio = self.decode_audio(residual_latent)
        print(f" Residual audio shape: {residual_audio.shape}")

        return target_audio, residual_audio, visual_frames, fps
| |
|
| |
|
def main():
    """Command-line entry point: parse arguments, run separation, save outputs.

    Arguments are validated before the ONNX models are loaded, so usage
    errors are reported immediately. If separation or saving fails, the error
    and traceback are printed and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(
        description="SAM Audio ONNX Runtime Inference"
    )
    parser.add_argument(
        "--audio",
        type=str,
        help="Path to input audio file (optional if --video is provided)",
    )
    parser.add_argument("--text", type=str, default="", help="Text description of the target source (optional if --video is provided)")
    parser.add_argument("--video", type=str, help="Optional path to video file for conditional separation")
    parser.add_argument("--mask", type=str, help="Optional path to mask file (visual prompting)")
    parser.add_argument("--output", type=str, default="target.wav", help="Output WAV file path for target (separated) audio")
    parser.add_argument("--output-residual", type=str, default="residual.wav", help="Output WAV file path for residual audio")
    parser.add_argument("--output-video", type=str, help="Optional path to save masked video with separated audio")
    parser.add_argument("--model-dir", type=str, default="onnx_models", help="Directory containing ONNX models")
    parser.add_argument("--steps", type=int, default=16, help="Number of ODE solver steps")
    parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda"], help="Inference device")

    args = parser.parse_args()

    # Check the cheap preconditions first; parser.error() exits with status 2.
    if not args.audio and not args.video:
        parser.error("At least one of --audio or --video must be provided.")

    if not args.text and not args.video:
        parser.error("--text is required for audio-only separation.")

    pipeline = SAMAudioONNXPipeline(
        model_dir=args.model_dir,
        device=args.device,
        num_ode_steps=args.steps,
    )

    # Without --audio, the soundtrack is read from the video file itself.
    audio_path = args.audio if args.audio else args.video

    print(f"\nLoading audio from: {audio_path}")
    audio = load_audio(audio_path, target_sr=48000)
    print(f"Audio duration: {len(audio)/48000:.2f} seconds")

    try:
        target_audio, residual_audio, masked_frames, fps = pipeline.separate(
            audio,
            args.text,
            video_path=args.video or None,
            mask_path=args.mask,
        )

        save_audio(target_audio, args.output, sample_rate=48000)
        save_audio(residual_audio, args.output_residual, sample_rate=48000)

        # The video is written only if requested and frames were produced.
        if args.output_video and masked_frames is not None:
            save_video_with_audio(masked_frames, target_audio, args.output_video, sample_rate=48000, fps=fps)

        print("\n✓ Done!")
        print(f" Target audio saved to: {args.output}")
        print(f" Residual audio saved to: {args.output_residual}")

    except Exception as e:
        # Top-level boundary: report the failure in full, then signal it to
        # the caller through a non-zero exit status.
        print(f"\nError during separation: {e}")
        import traceback
        traceback.print_exc()
        raise SystemExit(1)


if __name__ == "__main__":
    main()
| |
|