import json
import os
import shutil
import subprocess
import tempfile
from datetime import datetime

import gradio as gr
import soundfile as sf
import torch
from huggingface_hub import snapshot_download
# Detect whether we are running on the shared Space (uploads get trimmed there).
# Use .get() so the script also runs locally, where SPACE_ID is unset.
is_shared_ui = "fffiloni/Meigen-MultiTalk" in os.environ.get("SPACE_ID", "")
def trim_audio_to_5s_temp(audio_path, sample_rate=16000):
    """Trim audio to at most 5 seconds and write it to a unique temp WAV file."""
    max_duration_sec = 5
    audio, sr = sf.read(audio_path)
    if sr != sample_rate:
        raise ValueError(f"Expected sample rate {sample_rate}, but got {sr}")

    max_samples = max_duration_sec * sample_rate
    if len(audio) > max_samples:
        audio = audio[:max_samples]

    # Timestamped filename so concurrent requests never collide.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
    base_name = os.path.splitext(os.path.basename(audio_path))[0]
    temp_filename = f"{base_name}_trimmed_{timestamp}.wav"
    temp_path = os.path.join(tempfile.gettempdir(), temp_filename)

    sf.write(temp_path, audio, samplerate=sample_rate)
    return temp_path
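# Usage sketch (hypothetical path; assumes a 16 kHz WAV like the demo assets):
#   trimmed = trim_audio_to_5s_temp("examples/single/1.wav")
#   # -> /tmp/1_trimmed_<timestamp>.wav, at most 5 seconds long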
num_gpus = torch.cuda.device_count()
print(f"GPU AVAILABLE: {num_gpus}")
# Download all required models via `snapshot_download`.

# Wan2.1-I2V-14B-480P base image-to-video model
wan_model_path = snapshot_download(
    repo_id="Wan-AI/Wan2.1-I2V-14B-480P",
    local_dir="./weights/Wan2.1-I2V-14B-480P",
)

# Chinese wav2vec2 audio encoder
wav2vec_path = snapshot_download(
    repo_id="TencentGameMate/chinese-wav2vec2-base",
    local_dir="./weights/chinese-wav2vec2-base",
)

# MeiGen MultiTalk weights
multitalk_path = snapshot_download(
    repo_id="MeiGen-AI/MeiGen-MultiTalk",
    local_dir="./weights/MeiGen-MultiTalk",
)
# Overlay the MultiTalk weights onto the base model directory.
base_model_dir = "./weights/Wan2.1-I2V-14B-480P"
multitalk_dir = "./weights/MeiGen-MultiTalk"

# Back up the original safetensors index before replacing it.
original_index = os.path.join(base_model_dir, "diffusion_pytorch_model.safetensors.index.json")
backup_index = os.path.join(base_model_dir, "diffusion_pytorch_model.safetensors.index.json_old")

if os.path.exists(original_index):
    os.rename(original_index, backup_index)
    print("Renamed original index file to .json_old")

# Copy the updated index file from MultiTalk.
shutil.copy2(
    os.path.join(multitalk_dir, "diffusion_pytorch_model.safetensors.index.json"),
    base_model_dir,
)

# Copy the MultiTalk model weights.
shutil.copy2(
    os.path.join(multitalk_dir, "multitalk.safetensors"),
    base_model_dir,
)
print("Copied MultiTalk files into base model directory.")
# Require a CUDA GPU; A100-class is preferred, L4 is the supported minimum.
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(torch.cuda.current_device())
    print(f"Current GPU: {gpu_name}")

    # Enforce the GPU requirement.
    if "A100" not in gpu_name and "L4" not in gpu_name:
        raise RuntimeError(f"This Space requires an A100 or L4 GPU. Found: {gpu_name}")
    elif "L4" in gpu_name and "L40S" not in gpu_name:
        # "L4" is a substring of "L40S", so exclude L40S from the slow-GPU warning.
        print("Warning: L4 is supported, but A100 is recommended for faster inference.")
else:
    raise RuntimeError("No CUDA-compatible GPU found. An A100 or L4 GPU is required.")
# Parameter budget (count of DiT parameters kept resident) per GPU model.
GPU_TO_VRAM_PARAMS = {
    "NVIDIA A100": 11_000_000_000,
    "NVIDIA A100-SXM4-40GB": 11_000_000_000,
    "NVIDIA A100-SXM4-80GB": 22_000_000_000,
    "NVIDIA L4": 5_000_000_000,
    "NVIDIA L40S": 22_000_000_000,
}
# Fall back to the conservative A100-40GB budget for device-name variants not
# listed above (an assumption; tune per card if needed).
USED_VRAM_PARAMS = GPU_TO_VRAM_PARAMS.get(gpu_name, 11_000_000_000)
print("Using", USED_VRAM_PARAMS, "for num_persistent_param_in_dit")
def create_temp_input_json(prompt: str, cond_image_path: str, cond_audio_path: str) -> str:
    """
    Create a temporary JSON file with the user-provided prompt, image, and audio paths.
    Returns the path to the temporary JSON file.
    """
    # Input structure expected by generate_multitalk.py (single-speaker case).
    data = {
        "prompt": prompt,
        "cond_image": cond_image_path,
        "cond_audio": {
            "person1": cond_audio_path
        }
    }

    # Write to a named temp file that outlives this function (delete=False).
    temp_json = tempfile.NamedTemporaryFile(delete=False, suffix=".json", mode="w", encoding="utf-8")
    json.dump(data, temp_json, indent=4)
    temp_json_path = temp_json.name
    temp_json.close()

    print(f"Temporary input JSON saved to: {temp_json_path}")
    return temp_json_path
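# The resulting file looks like this (paths are illustrative):
# {
#     "prompt": "A woman sings passionately in a dimly lit studio.",
#     "cond_image": "examples/single/single1.png",
#     "cond_audio": {"person1": "examples/single/1.wav"}
# }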
def infer(prompt, cond_image_path, cond_audio_path):
    # On the shared Space, trim uploads to 5 s to keep queue times reasonable.
    # Initialize to None so the cleanup in `finally` is safe on private Spaces.
    trimmed_audio_path = None
    if is_shared_ui:
        trimmed_audio_path = trim_audio_to_5s_temp(cond_audio_path)
        cond_audio_path = trimmed_audio_path

    # Prepare the input JSON consumed by generate_multitalk.py.
    input_json_path = create_temp_input_json(prompt, cond_image_path, cond_audio_path)
    # Arguments shared by the single- and multi-GPU launch paths.
    common_args = [
        "--ckpt_dir", "weights/Wan2.1-I2V-14B-480P",
        "--wav2vec_dir", "weights/chinese-wav2vec2-base",
        "--input_json", input_json_path,
        "--sample_steps", "6",
        # Wire the VRAM budget computed above into the run (the print next to
        # USED_VRAM_PARAMS implies this flag; assumed to match the MultiTalk CLI).
        "--num_persistent_param_in_dit", str(USED_VRAM_PARAMS),
        "--mode", "streaming",
        "--use_teacache",
        "--save_file", "multi_long_multigpu_exp",
    ]
    if num_gpus > 1:
        # Multi-GPU: shard DiT and T5 with FSDP and split sequences via Ulysses.
        cmd = [
            "torchrun",
            f"--nproc_per_node={num_gpus}",
            "--standalone",
            "generate_multitalk.py",
            "--dit_fsdp", "--t5_fsdp",
            "--ulysses_size", str(num_gpus),
        ] + common_args
    else:
        cmd = ["python3", "generate_multitalk.py"] + common_args
    try:
        # Stream inference output to the console and to inference.log.
        with open("inference.log", "w") as log_file:
            process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                bufsize=1,
            )
            for line in process.stdout:
                print(line, end="")
                log_file.write(line)
            process.wait()

        if process.returncode != 0:
            raise RuntimeError("Inference failed. Check inference.log for details.")

        # generate_multitalk.py writes <save_file>.mp4 in the working directory.
        return "multi_long_multigpu_exp.mp4"
    finally:
        # Clean up temp files whether inference succeeded or failed.
        if trimmed_audio_path and os.path.exists(trimmed_audio_path):
            os.remove(trimmed_audio_path)
        if os.path.exists(input_json_path):
            os.remove(input_json_path)
with gr.Blocks(title="MultiTalk Inference") as demo:
    gr.Markdown("## 🎤 MultiTalk Inference Demo")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Text Prompt",
                placeholder="Describe the scene...",
                lines=4,
            )
            image_input = gr.Image(
                type="filepath",
                label="Conditioning Image",
            )
            audio_input = gr.Audio(
                type="filepath",
                label="Conditioning Audio (.wav)",
            )
            submit_btn = gr.Button("Generate")

            gr.Examples(
                examples=[
                    ["A woman sings passionately in a dimly lit studio.", "examples/single/single1.png", "examples/single/1.wav"],
                ],
                inputs=[prompt_input, image_input, audio_input],
            )

        with gr.Column():
            output_video = gr.Video(label="Generated Video")

    submit_btn.click(
        fn=infer,
        inputs=[prompt_input, image_input, audio_input],
        outputs=output_video,
    )

demo.launch()