marcosremar2 Claude Opus 4.5 committed on
Commit af11910 · 1 Parent(s): bc451c3

fix: correct facial alignment issues and add API server


- Fix frame shape bug in inference.py line 216 (use ori_frame instead of frame)
- Adjust upper_boundary_ratio from 0.5 to 0.4 for better facial blending
- Add MuseTalk API server with multiple versions (see the usage sketch below)
- Add inference configs and helper scripts
- Update .gitignore for conda environments

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
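A minimal usage sketch for the new API server (defined in musetalk_api_server.py below). It assumes the server is already running on localhost:8000 with models loaded; the media paths are illustrative files on the server's filesystem, not part of this commit's guarantees.

import requests

BASE_URL = "http://localhost:8000"

# Check that models are loaded before submitting work.
print(requests.get(f"{BASE_URL}/health").json())

# Generate a lip-synced video from files that already exist on the server.
payload = {
    "video_path": "data/video/yongen.mp4",
    "audio_path": "data/audio/professor_pt.wav",
    "output_path": "results/professor_test.mp4",
    "fps": 25,
    "use_cache": True,
}
resp = requests.post(f"{BASE_URL}/generate", json=payload, timeout=600)
resp.raise_for_status()
print(resp.json()["timings"])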

.gitignore CHANGED
@@ -15,4 +15,12 @@ ffmprobe*
15
  ffplay*
16
  debug
17
  exp_out
18
- .gradio
15
  ffplay*
16
  debug
17
  exp_out
18
+ .gradio
19
+
20
+ # Conda environment (Lightning AI persistent)
21
+ .conda_env/
22
+ miniconda/
23
+ venv/
24
+
25
+ # Temporary installation files
26
+ =*
activate.sh ADDED
@@ -0,0 +1,19 @@
1
+ #!/bin/bash
2
+ # Script to activate the MuseTalk environment
3
+
4
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
5
+
6
+ # Activate the local conda environment
7
+ source "$SCRIPT_DIR/miniconda/bin/activate" musetalk
8
+
9
+ # Configure the HuggingFace token (set your HF_TOKEN variable or edit it here)
10
+ # export HF_TOKEN="your_token_here"
11
+
12
+ # Configure FFMPEG if needed
13
+ # export FFMPEG_PATH="$SCRIPT_DIR/ffmpeg"
14
+
15
+ cd "${SCRIPT_DIR}"
16
+ echo "✅ Ambiente MuseTalk ativado!"
17
+ echo "Diretório: ${SCRIPT_DIR}"
18
+ echo "Python: $(python --version)"
19
+ echo "PyTorch: $(python -c 'import torch; print(torch.__version__)')"
avatar_pipeline.py ADDED
@@ -0,0 +1,204 @@
1
+ """
2
+ Multimodal Avatar Pipeline
3
+ Audio Input -> Whisper -> LLM -> XTTS -> MuseTalk
4
+
5
+ This creates a complete avatar that can understand spoken Portuguese
6
+ and respond with lip-synced video.
7
+ """
8
+
9
+ import os
10
+ import requests
11
+ import tempfile
12
+ import soundfile as sf
13
+
14
+ os.environ["HF_HOME"] = "/workspace/MuseTalk/.cache/huggingface"
15
+ os.environ["COQUI_TOS_AGREED"] = "1"
16
+
17
+ from faster_whisper import WhisperModel
18
+ from llama_cpp import Llama
19
+ from TTS.api import TTS
20
+
21
+
22
+ class MultimodalAvatar:
23
+ def __init__(
24
+ self,
25
+ whisper_model: str = "tiny",
26
+ llm_model_path: str = "models/llm/qwen2.5-0.5b-instruct-q4_k_m.gguf",
27
+ reference_audio: str = "data/audio/mariana_ref.wav",
28
+ avatar_id: str = "mariana_hd",
29
+ musetalk_url: str = "http://localhost:8000",
30
+ system_prompt: str = None
31
+ ):
32
+ print("Initializing Multimodal Avatar Pipeline...")
33
+
34
+ # Whisper for speech-to-text
35
+ print(" Loading Whisper...")
36
+ self.whisper = WhisperModel(whisper_model, device="cpu", compute_type="int8")
37
+
38
+ # LLM for understanding and response
39
+ print(" Loading LLM...")
40
+ self.llm = Llama(
41
+ model_path=llm_model_path,
42
+ n_ctx=2048,
43
+ n_threads=4,
44
+ verbose=False
45
+ )
46
+
47
+ # XTTS for text-to-speech
48
+ print(" Loading XTTS...")
49
+ self.tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=False)
50
+
51
+ self.reference_audio = reference_audio
52
+ self.avatar_id = avatar_id
53
+ self.musetalk_url = musetalk_url
54
+
55
+ self.system_prompt = system_prompt or """Você é Mariana, uma assistente virtual brasileira.
56
+ Você é simpática, prestativa e sempre responde em português brasileiro.
57
+ Suas respostas são claras, concisas e naturais, como se estivesse conversando.
58
+ Evite respostas muito longas - prefira 2-3 frases no máximo."""
59
+
60
+ self.conversation_history = []
61
+ print("Avatar ready!")
62
+
63
+ def transcribe(self, audio_path: str) -> str:
64
+ """Transcribe audio to text using Whisper"""
65
+ segments, info = self.whisper.transcribe(audio_path, language="pt")
66
+ text = " ".join([segment.text for segment in segments]).strip()
67
+ return text
68
+
69
+ def think(self, user_message: str) -> str:
70
+ """Generate response using LLM"""
71
+ self.conversation_history.append({"role": "user", "content": user_message})
72
+
73
+ messages = [{"role": "system", "content": self.system_prompt}]
74
+ messages.extend(self.conversation_history[-10:]) # Keep last 10 messages
75
+
76
+ response = self.llm.create_chat_completion(
77
+ messages=messages,
78
+ max_tokens=200,
79
+ temperature=0.7,
80
+ stop=["<|im_end|>", "<|endoftext|>"]
81
+ )
82
+
83
+ assistant_message = response['choices'][0]['message']['content'].strip()
84
+ self.conversation_history.append({"role": "assistant", "content": assistant_message})
85
+
86
+ return assistant_message
87
+
88
+ def speak(self, text: str, output_path: str) -> str:
89
+ """Convert text to speech using XTTS"""
90
+ self.tts.tts_to_file(
91
+ text=text,
92
+ speaker_wav=self.reference_audio,
93
+ language="pt",
94
+ file_path=output_path
95
+ )
96
+ return output_path
97
+
98
+ def animate(self, audio_path: str, output_path: str) -> str:
99
+ """Generate lip-sync video using MuseTalk"""
100
+ with open(audio_path, 'rb') as f:
101
+ response = requests.post(
102
+ f"{self.musetalk_url}/inference",
103
+ files={"audio": f},
104
+ data={"avatar_id": self.avatar_id},
105
+ timeout=300
106
+ )
107
+
108
+ if response.status_code == 200:
109
+ with open(output_path, 'wb') as f:
110
+ f.write(response.content)
111
+ return output_path
112
+ else:
113
+ raise Exception(f"MuseTalk error: {response.text}")
114
+
115
+ def respond(self, audio_input: str, output_video: str) -> dict:
116
+ """
117
+ Complete pipeline: audio input -> transcribe -> think -> speak -> animate
118
+
119
+ Returns dict with all intermediate results
120
+ """
121
+ print("\n=== Processing Request ===")
122
+
123
+ # Step 1: Transcribe
124
+ print("1. Transcribing audio...")
125
+ user_text = self.transcribe(audio_input)
126
+ print(f" User said: {user_text}")
127
+
128
+ # Step 2: Think
129
+ print("2. Generating response...")
130
+ response_text = self.think(user_text)
131
+ print(f" Response: {response_text}")
132
+
133
+ # Step 3: Speak
134
+ print("3. Synthesizing speech...")
135
+ audio_output = output_video.replace('.mp4', '.wav')
136
+ self.speak(response_text, audio_output)
137
+
138
+ # Get audio duration
139
+ data, sr = sf.read(audio_output)
140
+ audio_duration = len(data) / sr
141
+ print(f" Audio duration: {audio_duration:.2f}s")
142
+
143
+ # Step 4: Animate
144
+ print("4. Generating lip-sync video...")
145
+ self.animate(audio_output, output_video)
146
+ print(f" Video saved: {output_video}")
147
+
148
+ return {
149
+ "user_text": user_text,
150
+ "response_text": response_text,
151
+ "audio_path": audio_output,
152
+ "video_path": output_video,
153
+ "audio_duration": audio_duration
154
+ }
155
+
156
+ def respond_to_text(self, user_text: str, output_video: str) -> dict:
157
+ """
158
+ Pipeline for text input: think -> speak -> animate
159
+ """
160
+ print("\n=== Processing Text Request ===")
161
+ print(f" User: {user_text}")
162
+
163
+ # Step 1: Think
164
+ print("1. Generating response...")
165
+ response_text = self.think(user_text)
166
+ print(f" Response: {response_text}")
167
+
168
+ # Step 2: Speak
169
+ print("2. Synthesizing speech...")
170
+ audio_output = output_video.replace('.mp4', '.wav')
171
+ self.speak(response_text, audio_output)
172
+
173
+ data, sr = sf.read(audio_output)
174
+ audio_duration = len(data) / sr
175
+ print(f" Audio duration: {audio_duration:.2f}s")
176
+
177
+ # Step 3: Animate
178
+ print("3. Generating lip-sync video...")
179
+ self.animate(audio_output, output_video)
180
+ print(f" Video saved: {output_video}")
181
+
182
+ return {
183
+ "user_text": user_text,
184
+ "response_text": response_text,
185
+ "audio_path": audio_output,
186
+ "video_path": output_video,
187
+ "audio_duration": audio_duration
188
+ }
189
+
190
+
191
+ if __name__ == "__main__":
192
+ # Initialize the avatar
193
+ avatar = MultimodalAvatar()
194
+
195
+ # Test with text input
196
+ result = avatar.respond_to_text(
197
+ user_text="Olá Mariana! Me conte sobre você.",
198
+ output_video="results/avatar_test.mp4"
199
+ )
200
+
201
+ print("\n=== Result ===")
202
+ print(f"User: {result['user_text']}")
203
+ print(f"Mariana: {result['response_text']}")
204
+ print(f"Video: {result['video_path']} ({result['audio_duration']:.1f}s)")
configs/inference/hello_world.yaml ADDED
@@ -0,0 +1,3 @@
1
+ task_hello_world:
2
+ video_path: "data/video/video_hd_1min_25fps.mp4"
3
+ audio_path: "data/audio/hello_world.wav"
configs/inference/professor_test.yaml ADDED
@@ -0,0 +1,4 @@
1
+ task_0:
2
+ video_path: "data/video/yongen.mp4"
3
+ audio_path: "data/audio/professor_pt.wav"
4
+ bbox_shift: 0
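For reference, a minimal sketch of how these task configs are structured, read here with OmegaConf (the same library the API servers below import); the keys and paths come from the YAML above, and the loop itself is illustrative rather than the project's own inference code.

from omegaconf import OmegaConf

# Load a task config and iterate its entries.
cfg = OmegaConf.load("configs/inference/professor_test.yaml")
for task_id, task in cfg.items():
    # Each task pairs a driving video with an audio track; bbox_shift is optional.
    print(task_id, task.video_path, task.audio_path, task.get("bbox_shift", 0))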
musetalk/utils/blending.py CHANGED
@@ -32,7 +32,7 @@ def face_seg(image, mode="raw", fp=None):
32
  return seg_image
33
 
34
 
35
- def get_image(image, face, face_box, upper_boundary_ratio=0.5, expand=1.5, mode="raw", fp=None):
36
  """
37
  将裁剪的面部图像粘贴回原始图像,并进行一些处理。
38
 
@@ -109,7 +109,7 @@ def get_image_blending(image, face, face_box, mask_array, crop_box):
109
  return body[:,:,::-1]
110
 
111
 
112
- def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=1.5, fp=None, mode="raw"):
113
  body = Image.fromarray(image[:,:,::-1])
114
 
115
  x, y, x1, y1 = face_box
 
32
  return seg_image
33
 
34
 
35
+ def get_image(image, face, face_box, upper_boundary_ratio=0.4, expand=1.5, mode="raw", fp=None):
36
  """
37
  将裁剪的面部图像粘贴回原始图像,并进行一些处理。
38
 
 
109
  return body[:,:,::-1]
110
 
111
 
112
+ def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.4, expand=1.5, fp=None, mode="raw"):
113
  body = Image.fromarray(image[:,:,::-1])
114
 
115
  x, y, x1, y1 = face_box
musetalk_api_server.py ADDED
@@ -0,0 +1,551 @@
1
+ """
2
+ MuseTalk HTTP API Server
3
+ Keeps models loaded in GPU memory for fast inference.
4
+ """
5
+ import os
6
+ import cv2
7
+ import copy
8
+ import torch
9
+ import glob
10
+ import shutil
11
+ import pickle
12
+ import numpy as np
13
+ import subprocess
14
+ import tempfile
15
+ import hashlib
16
+ import time
17
+ from pathlib import Path
18
+ from typing import Optional
19
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks
20
+ from fastapi.responses import FileResponse, JSONResponse
21
+ from fastapi.middleware.cors import CORSMiddleware
22
+ from pydantic import BaseModel
23
+ from tqdm import tqdm
24
+ from omegaconf import OmegaConf
25
+ from transformers import WhisperModel
26
+ import uvicorn
27
+
28
+ # MuseTalk imports
29
+ from musetalk.utils.blending import get_image
30
+ from musetalk.utils.face_parsing import FaceParsing
31
+ from musetalk.utils.audio_processor import AudioProcessor
32
+ from musetalk.utils.utils import get_file_type, datagen, load_all_model
33
+ from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
34
+
35
+
36
+ class MuseTalkServer:
37
+ """Singleton server that keeps models loaded in GPU memory."""
38
+
39
+ def __init__(self):
40
+ self.device = None
41
+ self.vae = None
42
+ self.unet = None
43
+ self.pe = None
44
+ self.whisper = None
45
+ self.audio_processor = None
46
+ self.fp = None
47
+ self.timesteps = None
48
+ self.weight_dtype = None
49
+ self.is_loaded = False
50
+
51
+ # Cache directories
52
+ self.cache_dir = Path("./cache")
53
+ self.cache_dir.mkdir(exist_ok=True)
54
+ self.landmarks_cache = self.cache_dir / "landmarks"
55
+ self.latents_cache = self.cache_dir / "latents"
56
+ self.whisper_cache = self.cache_dir / "whisper_features"
57
+ self.landmarks_cache.mkdir(exist_ok=True)
58
+ self.latents_cache.mkdir(exist_ok=True)
59
+ self.whisper_cache.mkdir(exist_ok=True)
60
+
61
+ # Config
62
+ self.fps = 25
63
+ self.batch_size = 8
64
+ self.use_float16 = True
65
+ self.version = "v15"
66
+ self.extra_margin = 10
67
+ self.parsing_mode = "jaw"
68
+ self.left_cheek_width = 90
69
+ self.right_cheek_width = 90
70
+ self.audio_padding_left = 2
71
+ self.audio_padding_right = 2
72
+
73
+ def load_models(
74
+ self,
75
+ gpu_id: int = 0,
76
+ unet_model_path: str = "./models/musetalkV15/unet.pth",
77
+ unet_config: str = "./models/musetalk/config.json",
78
+ vae_type: str = "sd-vae",
79
+ whisper_dir: str = "./models/whisper",
80
+ use_float16: bool = True,
81
+ version: str = "v15"
82
+ ):
83
+ """Load all models into GPU memory."""
84
+ if self.is_loaded:
85
+ print("Models already loaded!")
86
+ return
87
+
88
+ print("=" * 50)
89
+ print("Loading MuseTalk models into GPU memory...")
90
+ print("=" * 50)
91
+
92
+ start_time = time.time()
93
+
94
+ # Set device
95
+ self.device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
96
+ print(f"Using device: {self.device}")
97
+
98
+ # Load model weights
99
+ print("Loading VAE, UNet, PE...")
100
+ self.vae, self.unet, self.pe = load_all_model(
101
+ unet_model_path=unet_model_path,
102
+ vae_type=vae_type,
103
+ unet_config=unet_config,
104
+ device=self.device
105
+ )
106
+ self.timesteps = torch.tensor([0], device=self.device)
107
+
108
+ # Convert to float16 if enabled
109
+ self.use_float16 = use_float16
110
+ if use_float16:
111
+ print("Converting to float16...")
112
+ self.pe = self.pe.half()
113
+ self.vae.vae = self.vae.vae.half()
114
+ self.unet.model = self.unet.model.half()
115
+
116
+ # Move to device
117
+ self.pe = self.pe.to(self.device)
118
+ self.vae.vae = self.vae.vae.to(self.device)
119
+ self.unet.model = self.unet.model.to(self.device)
120
+
121
+ # Initialize audio processor and Whisper
122
+ print("Loading Whisper model...")
123
+ self.audio_processor = AudioProcessor(feature_extractor_path=whisper_dir)
124
+ self.weight_dtype = self.unet.model.dtype
125
+ self.whisper = WhisperModel.from_pretrained(whisper_dir)
126
+ self.whisper = self.whisper.to(device=self.device, dtype=self.weight_dtype).eval()
127
+ self.whisper.requires_grad_(False)
128
+
129
+ # Initialize face parser
130
+ self.version = version
131
+ if version == "v15":
132
+ self.fp = FaceParsing(
133
+ left_cheek_width=self.left_cheek_width,
134
+ right_cheek_width=self.right_cheek_width
135
+ )
136
+ else:
137
+ self.fp = FaceParsing()
138
+
139
+ self.is_loaded = True
140
+ load_time = time.time() - start_time
141
+ print(f"Models loaded in {load_time:.2f}s")
142
+ print("=" * 50)
143
+ print("Server ready for inference!")
144
+ print("=" * 50)
145
+
146
+ def _get_file_hash(self, file_path: str) -> str:
147
+ """Get MD5 hash of a file for caching."""
148
+ hash_md5 = hashlib.md5()
149
+ with open(file_path, "rb") as f:
150
+ for chunk in iter(lambda: f.read(4096), b""):
151
+ hash_md5.update(chunk)
152
+ return hash_md5.hexdigest()[:16]
153
+
154
+ def _get_cached_landmarks(self, video_hash: str, bbox_shift: int):
155
+ """Get cached landmarks if available."""
156
+ # Disabled due to tensor comparison issues
157
+ return None
158
+
159
+ def _save_landmarks_cache(self, video_hash: str, bbox_shift: int, coord_list, frame_list):
160
+ """Save landmarks to cache."""
161
+ cache_file = self.landmarks_cache / f"{video_hash}_shift{bbox_shift}.pkl"
162
+ with open(cache_file, 'wb') as f:
163
+ pickle.dump((coord_list, frame_list), f)
164
+
165
+ def _get_cached_latents(self, video_hash: str):
166
+ """Get cached VAE latents if available."""
167
+ # Disabled due to tensor comparison issues
168
+ return None
169
+
170
+ def _save_latents_cache(self, video_hash: str, latent_list):
171
+ """Save VAE latents to cache."""
172
+ cache_file = self.latents_cache / f"{video_hash}.pkl"
173
+ with open(cache_file, 'wb') as f:
174
+ pickle.dump(latent_list, f)
175
+
176
+ def _get_cached_whisper(self, audio_hash: str):
177
+ """Get cached Whisper features if available."""
178
+ # Disabled due to tensor comparison issues
179
+ return None
180
+
181
+ def _save_whisper_cache(self, audio_hash: str, whisper_data):
182
+ """Save Whisper features to cache."""
183
+ cache_file = self.whisper_cache / f"{audio_hash}.pkl"
184
+ with open(cache_file, 'wb') as f:
185
+ pickle.dump(whisper_data, f)
186
+
187
+ @torch.no_grad()
188
+ def generate(
189
+ self,
190
+ video_path: str,
191
+ audio_path: str,
192
+ output_path: str,
193
+ fps: Optional[int] = None,
194
+ use_cache: bool = True
195
+ ) -> dict:
196
+ """
197
+ Generate lip-synced video.
198
+
199
+ Returns dict with timing info.
200
+ """
201
+ if not self.is_loaded:
202
+ raise RuntimeError("Models not loaded! Call load_models() first.")
203
+
204
+ fps = fps or self.fps
205
+ timings = {"total": 0}
206
+ total_start = time.time()
207
+
208
+ # Get file hashes for caching
209
+ video_hash = self._get_file_hash(video_path)
210
+ audio_hash = self._get_file_hash(audio_path)
211
+
212
+ # Create temp directory
213
+ temp_dir = tempfile.mkdtemp()
214
+
215
+ try:
216
+ # 1. Extract frames
217
+ t0 = time.time()
218
+ input_basename = Path(video_path).stem
219
+ save_dir_full = os.path.join(temp_dir, "frames")
220
+ os.makedirs(save_dir_full, exist_ok=True)
221
+
222
+ if get_file_type(video_path) == "video":
223
+ cmd = f"ffmpeg -v fatal -i {video_path} -vf fps={fps} -start_number 0 {save_dir_full}/%08d.png"
224
+ os.system(cmd)
225
+ input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
226
+ elif get_file_type(video_path) == "image":
227
+ input_img_list = [video_path]
228
+ else:
229
+ raise ValueError(f"Unsupported video type: {video_path}")
230
+
231
+ timings["frame_extraction"] = time.time() - t0
232
+
233
+ # 2. Extract audio features (with caching)
234
+ t0 = time.time()
235
+ cached_whisper = self._get_cached_whisper(audio_hash) if use_cache else None
236
+
237
+ if cached_whisper:
238
+ whisper_chunks = cached_whisper
239
+ timings["whisper_source"] = "cache"
240
+ else:
241
+ whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
242
+ whisper_chunks = self.audio_processor.get_whisper_chunk(
243
+ whisper_input_features,
244
+ self.device,
245
+ self.weight_dtype,
246
+ self.whisper,
247
+ librosa_length,
248
+ fps=fps,
249
+ audio_padding_length_left=self.audio_padding_left,
250
+ audio_padding_length_right=self.audio_padding_right,
251
+ )
252
+ if use_cache:
253
+ self._save_whisper_cache(audio_hash, whisper_chunks)
254
+ timings["whisper_source"] = "computed"
255
+
256
+ timings["whisper_features"] = time.time() - t0
257
+
258
+ # 3. Get landmarks (with caching)
259
+ t0 = time.time()
260
+ bbox_shift = 0  # same default for v15 and earlier versions
261
+ cache_key = f"{video_hash}_{fps}"
262
+
263
+ cached_landmarks = self._get_cached_landmarks(cache_key, bbox_shift) if use_cache else None
264
+
265
+ if cached_landmarks:
266
+ coord_list, frame_list = cached_landmarks
267
+ timings["landmarks_source"] = "cache"
268
+ else:
269
+ coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
270
+ if use_cache:
271
+ self._save_landmarks_cache(cache_key, bbox_shift, coord_list, frame_list)
272
+ timings["landmarks_source"] = "computed"
273
+
274
+ timings["landmarks"] = time.time() - t0
275
+
276
+ # 4. Compute VAE latents (with caching)
277
+ t0 = time.time()
278
+ latent_cache_key = f"{video_hash}_{fps}_{self.version}"
279
+ cached_latents = self._get_cached_latents(latent_cache_key) if use_cache else None
280
+
281
+ if cached_latents:
282
+ input_latent_list = cached_latents
283
+ timings["latents_source"] = "cache"
284
+ else:
285
+ input_latent_list = []
286
+ for bbox, frame in zip(coord_list, frame_list):
287
+ if isinstance(bbox, (list, tuple)) and list(bbox) == list(coord_placeholder):
288
+ continue
289
+ x1, y1, x2, y2 = bbox
290
+ if self.version == "v15":
291
+ y2 = y2 + self.extra_margin
292
+ y2 = min(y2, frame.shape[0])
293
+ crop_frame = frame[y1:y2, x1:x2]
294
+ crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
295
+ latents = self.vae.get_latents_for_unet(crop_frame)
296
+ input_latent_list.append(latents)
297
+
298
+ if use_cache:
299
+ self._save_latents_cache(latent_cache_key, input_latent_list)
300
+ timings["latents_source"] = "computed"
301
+
302
+ timings["vae_encoding"] = time.time() - t0
303
+
304
+ # 5. Prepare cycled lists
305
+ frame_list_cycle = frame_list + frame_list[::-1]
306
+ coord_list_cycle = coord_list + coord_list[::-1]
307
+ input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
308
+
309
+ # 6. UNet inference
310
+ t0 = time.time()
311
+ video_num = len(whisper_chunks)
312
+ gen = datagen(
313
+ whisper_chunks=whisper_chunks,
314
+ vae_encode_latents=input_latent_list_cycle,
315
+ batch_size=self.batch_size,
316
+ delay_frame=0,
317
+ device=self.device,
318
+ )
319
+
320
+ res_frame_list = []
321
+ for whisper_batch, latent_batch in gen:
322
+ audio_feature_batch = self.pe(whisper_batch)
323
+ latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
324
+ pred_latents = self.unet.model(
325
+ latent_batch, self.timesteps,
326
+ encoder_hidden_states=audio_feature_batch
327
+ ).sample
328
+ recon = self.vae.decode_latents(pred_latents)
329
+ for res_frame in recon:
330
+ res_frame_list.append(res_frame)
331
+
332
+ timings["unet_inference"] = time.time() - t0
333
+
334
+ # 7. Face blending
335
+ t0 = time.time()
336
+ result_img_path = os.path.join(temp_dir, "results")
337
+ os.makedirs(result_img_path, exist_ok=True)
338
+
339
+ for i, res_frame in enumerate(res_frame_list):
340
+ bbox = coord_list_cycle[i % len(coord_list_cycle)]
341
+ ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
342
+ x1, y1, x2, y2 = bbox
343
+ if self.version == "v15":
344
+ y2 = y2 + self.extra_margin
345
+ y2 = min(y2, ori_frame.shape[0])
346
+ try:
347
+ res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
348
+ except:
349
+ continue
350
+
351
+ if self.version == "v15":
352
+ combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
353
+ mode=self.parsing_mode, fp=self.fp)
354
+ else:
355
+ combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=self.fp)
356
+
357
+ cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)
358
+
359
+ timings["face_blending"] = time.time() - t0
360
+
361
+ # 8. Encode video
362
+ t0 = time.time()
363
+ temp_vid = os.path.join(temp_dir, "temp.mp4")
364
+ cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_path}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid}"
365
+ os.system(cmd_img2video)
366
+
367
+ cmd_combine = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid} {output_path}"
368
+ os.system(cmd_combine)
369
+
370
+ timings["video_encoding"] = time.time() - t0
371
+
372
+ finally:
373
+ # Cleanup
374
+ shutil.rmtree(temp_dir, ignore_errors=True)
375
+
376
+ timings["total"] = time.time() - total_start
377
+ timings["frames_generated"] = len(res_frame_list)
378
+
379
+ return timings
380
+
381
+
382
+ # Global server instance
383
+ server = MuseTalkServer()
384
+
385
+ # FastAPI app
386
+ app = FastAPI(
387
+ title="MuseTalk API",
388
+ description="HTTP API for MuseTalk lip-sync generation",
389
+ version="1.0.0"
390
+ )
391
+
392
+ # CORS middleware
393
+ app.add_middleware(
394
+ CORSMiddleware,
395
+ allow_origins=["*"],
396
+ allow_credentials=True,
397
+ allow_methods=["*"],
398
+ allow_headers=["*"],
399
+ )
400
+
401
+
402
+ @app.on_event("startup")
403
+ async def startup_event():
404
+ """Load models on server startup."""
405
+ server.load_models()
406
+
407
+
408
+ @app.get("/health")
409
+ async def health_check():
410
+ """Check if server is ready."""
411
+ return {
412
+ "status": "ok" if server.is_loaded else "loading",
413
+ "models_loaded": server.is_loaded,
414
+ "device": str(server.device) if server.device else None
415
+ }
416
+
417
+
418
+ @app.get("/cache/stats")
419
+ async def cache_stats():
420
+ """Get cache statistics."""
421
+ landmarks_count = len(list(server.landmarks_cache.glob("*.pkl")))
422
+ latents_count = len(list(server.latents_cache.glob("*.pkl")))
423
+ whisper_count = len(list(server.whisper_cache.glob("*.pkl")))
424
+
425
+ return {
426
+ "landmarks_cached": landmarks_count,
427
+ "latents_cached": latents_count,
428
+ "whisper_features_cached": whisper_count
429
+ }
430
+
431
+
432
+ @app.post("/cache/clear")
433
+ async def clear_cache():
434
+ """Clear all caches."""
435
+ for cache_dir in [server.landmarks_cache, server.latents_cache, server.whisper_cache]:
436
+ for f in cache_dir.glob("*.pkl"):
437
+ f.unlink()
438
+ return {"status": "cleared"}
439
+
440
+
441
+ class GenerateRequest(BaseModel):
442
+ video_path: str
443
+ audio_path: str
444
+ output_path: str
445
+ fps: Optional[int] = 25
446
+ use_cache: bool = True
447
+
448
+
449
+ @app.post("/generate")
450
+ async def generate_from_paths(request: GenerateRequest):
451
+ """
452
+ Generate lip-synced video from file paths.
453
+
454
+ Use this when files are already on the server.
455
+ """
456
+ if not server.is_loaded:
457
+ raise HTTPException(status_code=503, detail="Models not loaded yet")
458
+
459
+ if not os.path.exists(request.video_path):
460
+ raise HTTPException(status_code=404, detail=f"Video not found: {request.video_path}")
461
+ if not os.path.exists(request.audio_path):
462
+ raise HTTPException(status_code=404, detail=f"Audio not found: {request.audio_path}")
463
+
464
+ try:
465
+ timings = server.generate(
466
+ video_path=request.video_path,
467
+ audio_path=request.audio_path,
468
+ output_path=request.output_path,
469
+ fps=request.fps,
470
+ use_cache=request.use_cache
471
+ )
472
+ return {
473
+ "status": "success",
474
+ "output_path": request.output_path,
475
+ "timings": timings
476
+ }
477
+ except Exception as e:
478
+ raise HTTPException(status_code=500, detail=str(e))
479
+
480
+
481
+ @app.post("/generate/upload")
482
+ async def generate_from_upload(
483
+ video: UploadFile = File(...),
484
+ audio: UploadFile = File(...),
485
+ fps: int = Form(25),
486
+ use_cache: bool = Form(True)
487
+ ):
488
+ """
489
+ Generate lip-synced video from uploaded files.
490
+
491
+ Returns the generated video file.
492
+ """
493
+ if not server.is_loaded:
494
+ raise HTTPException(status_code=503, detail="Models not loaded yet")
495
+
496
+ # Save uploaded files
497
+ temp_dir = tempfile.mkdtemp()
498
+ try:
499
+ video_path = os.path.join(temp_dir, video.filename)
500
+ audio_path = os.path.join(temp_dir, audio.filename)
501
+ output_path = os.path.join(temp_dir, "output.mp4")
502
+
503
+ with open(video_path, "wb") as f:
504
+ f.write(await video.read())
505
+ with open(audio_path, "wb") as f:
506
+ f.write(await audio.read())
507
+
508
+ timings = server.generate(
509
+ video_path=video_path,
510
+ audio_path=audio_path,
511
+ output_path=output_path,
512
+ fps=fps,
513
+ use_cache=use_cache
514
+ )
515
+
516
+ # Return the video file
517
+ return FileResponse(
518
+ output_path,
519
+ media_type="video/mp4",
520
+ filename="result.mp4",
521
+ headers={"X-Timings": str(timings)}
522
+ )
523
+ except Exception as e:
524
+ shutil.rmtree(temp_dir, ignore_errors=True)
525
+ raise HTTPException(status_code=500, detail=str(e))
526
+
527
+
528
+ if __name__ == "__main__":
529
+ import argparse
530
+
531
+ parser = argparse.ArgumentParser(description="MuseTalk API Server")
532
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind")
533
+ parser.add_argument("--port", type=int, default=8000, help="Port to bind")
534
+ parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID")
535
+ parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth")
536
+ parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json")
537
+ parser.add_argument("--whisper_dir", type=str, default="./models/whisper")
538
+ parser.add_argument("--no_float16", action="store_true", help="Disable float16")
539
+ args = parser.parse_args()
540
+
541
+ # Pre-configure server
542
+ server.load_models(
543
+ gpu_id=args.gpu_id,
544
+ unet_model_path=args.unet_model_path,
545
+ unet_config=args.unet_config,
546
+ whisper_dir=args.whisper_dir,
547
+ use_float16=not args.no_float16
548
+ )
549
+
550
+ # Start server
551
+ uvicorn.run(app, host=args.host, port=args.port)
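A hedged client-side sketch for the /generate/upload endpoint defined above: it assumes the server is running on localhost:8000 with models loaded, and the sample file names are illustrative.

import requests

# Upload a driving video and an audio clip; the server returns the rendered MP4.
with open("data/video/yongen.mp4", "rb") as vf, open("data/audio/professor_pt.wav", "rb") as af:
    resp = requests.post(
        "http://localhost:8000/generate/upload",
        files={"video": vf, "audio": af},
        data={"fps": 25, "use_cache": "true"},
        timeout=600,
    )
resp.raise_for_status()
with open("result.mp4", "wb") as out:
    out.write(resp.content)
print(resp.headers.get("X-Timings"))  # per-stage timing summary set by the server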
musetalk_api_server_v2.py ADDED
@@ -0,0 +1,445 @@
1
+ """
2
+ MuseTalk HTTP API Server v2
3
+ Optimized for repeated use of the same avatar.
4
+ """
5
+ import os
6
+ import cv2
7
+ import copy
8
+ import torch
9
+ import glob
10
+ import shutil
11
+ import pickle
12
+ import numpy as np
13
+ import subprocess
14
+ import tempfile
15
+ import hashlib
16
+ import time
17
+ from pathlib import Path
18
+ from typing import Optional
19
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks
20
+ from fastapi.responses import FileResponse, JSONResponse
21
+ from fastapi.middleware.cors import CORSMiddleware
22
+ from pydantic import BaseModel
23
+ from tqdm import tqdm
24
+ from omegaconf import OmegaConf
25
+ from transformers import WhisperModel
26
+ import uvicorn
27
+
28
+ # MuseTalk imports
29
+ from musetalk.utils.blending import get_image
30
+ from musetalk.utils.face_parsing import FaceParsing
31
+ from musetalk.utils.audio_processor import AudioProcessor
32
+ from musetalk.utils.utils import get_file_type, datagen, load_all_model
33
+ from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
34
+
35
+
36
+ class MuseTalkServerV2:
37
+ """Server optimized for pre-processed avatars."""
38
+
39
+ def __init__(self):
40
+ self.device = None
41
+ self.vae = None
42
+ self.unet = None
43
+ self.pe = None
44
+ self.whisper = None
45
+ self.audio_processor = None
46
+ self.fp = None
47
+ self.timesteps = None
48
+ self.weight_dtype = None
49
+ self.is_loaded = False
50
+
51
+ # Avatar cache (in-memory)
52
+ self.loaded_avatars = {}
53
+ self.avatar_dir = Path("./avatars")
54
+
55
+ # Config
56
+ self.fps = 25
57
+ self.batch_size = 8
58
+ self.use_float16 = True
59
+ self.version = "v15"
60
+ self.extra_margin = 10
61
+ self.parsing_mode = "jaw"
62
+ self.left_cheek_width = 90
63
+ self.right_cheek_width = 90
64
+ self.audio_padding_left = 2
65
+ self.audio_padding_right = 2
66
+
67
+ def load_models(
68
+ self,
69
+ gpu_id: int = 0,
70
+ unet_model_path: str = "./models/musetalkV15/unet.pth",
71
+ unet_config: str = "./models/musetalk/config.json",
72
+ vae_type: str = "sd-vae",
73
+ whisper_dir: str = "./models/whisper",
74
+ use_float16: bool = True,
75
+ version: str = "v15"
76
+ ):
77
+ if self.is_loaded:
78
+ print("Models already loaded!")
79
+ return
80
+
81
+ print("=" * 50)
82
+ print("Loading MuseTalk models into GPU memory...")
83
+ print("=" * 50)
84
+
85
+ start_time = time.time()
86
+ self.device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
87
+ print(f"Using device: {self.device}")
88
+
89
+ print("Loading VAE, UNet, PE...")
90
+ self.vae, self.unet, self.pe = load_all_model(
91
+ unet_model_path=unet_model_path,
92
+ vae_type=vae_type,
93
+ unet_config=unet_config,
94
+ device=self.device
95
+ )
96
+ self.timesteps = torch.tensor([0], device=self.device)
97
+
98
+ self.use_float16 = use_float16
99
+ if use_float16:
100
+ print("Converting to float16...")
101
+ self.pe = self.pe.half()
102
+ self.vae.vae = self.vae.vae.half()
103
+ self.unet.model = self.unet.model.half()
104
+
105
+ self.pe = self.pe.to(self.device)
106
+ self.vae.vae = self.vae.vae.to(self.device)
107
+ self.unet.model = self.unet.model.to(self.device)
108
+
109
+ print("Loading Whisper model...")
110
+ self.audio_processor = AudioProcessor(feature_extractor_path=whisper_dir)
111
+ self.weight_dtype = self.unet.model.dtype
112
+ self.whisper = WhisperModel.from_pretrained(whisper_dir)
113
+ self.whisper = self.whisper.to(device=self.device, dtype=self.weight_dtype).eval()
114
+ self.whisper.requires_grad_(False)
115
+
116
+ self.version = version
117
+ if version == "v15":
118
+ self.fp = FaceParsing(
119
+ left_cheek_width=self.left_cheek_width,
120
+ right_cheek_width=self.right_cheek_width
121
+ )
122
+ else:
123
+ self.fp = FaceParsing()
124
+
125
+ self.is_loaded = True
126
+ print(f"Models loaded in {time.time() - start_time:.2f}s")
127
+ print("=" * 50)
128
+
129
+ def load_avatar(self, avatar_name: str) -> dict:
130
+ """Load a preprocessed avatar into memory."""
131
+ if avatar_name in self.loaded_avatars:
132
+ return self.loaded_avatars[avatar_name]
133
+
134
+ avatar_path = self.avatar_dir / avatar_name
135
+ if not avatar_path.exists():
136
+ raise FileNotFoundError(f"Avatar not found: {avatar_name}")
137
+
138
+ print(f"Loading avatar '{avatar_name}' into memory...")
139
+ t0 = time.time()
140
+
141
+ avatar_data = {}
142
+
143
+ # Load metadata
144
+ with open(avatar_path / "metadata.pkl", 'rb') as f:
145
+ avatar_data['metadata'] = pickle.load(f)
146
+
147
+ # Load coords
148
+ with open(avatar_path / "coords.pkl", 'rb') as f:
149
+ avatar_data['coord_list'] = pickle.load(f)
150
+
151
+ # Load frames
152
+ with open(avatar_path / "frames.pkl", 'rb') as f:
153
+ avatar_data['frame_list'] = pickle.load(f)
154
+
155
+ # Load latents and convert to GPU tensors
156
+ with open(avatar_path / "latents.pkl", 'rb') as f:
157
+ latents_np = pickle.load(f)
158
+ avatar_data['latent_list'] = [
159
+ torch.from_numpy(l).to(self.device) for l in latents_np
160
+ ]
161
+
162
+ # Load crop info
163
+ with open(avatar_path / "crop_info.pkl", 'rb') as f:
164
+ avatar_data['crop_info'] = pickle.load(f)
165
+
166
+ # Load parsing data (optional)
167
+ parsing_path = avatar_path / "parsing.pkl"
168
+ if parsing_path.exists():
169
+ with open(parsing_path, 'rb') as f:
170
+ avatar_data['parsing_data'] = pickle.load(f)
171
+
172
+ self.loaded_avatars[avatar_name] = avatar_data
173
+ print(f"Avatar loaded in {time.time() - t0:.2f}s")
174
+
175
+ return avatar_data
176
+
177
+ def unload_avatar(self, avatar_name: str):
178
+ """Unload avatar from memory."""
179
+ if avatar_name in self.loaded_avatars:
180
+ del self.loaded_avatars[avatar_name]
181
+ torch.cuda.empty_cache()
182
+
183
+ @torch.no_grad()
184
+ def generate_with_avatar(
185
+ self,
186
+ avatar_name: str,
187
+ audio_path: str,
188
+ output_path: str,
189
+ fps: Optional[int] = None
190
+ ) -> dict:
191
+ """Generate video using pre-processed avatar. Much faster!"""
192
+ if not self.is_loaded:
193
+ raise RuntimeError("Models not loaded!")
194
+
195
+ fps = fps or self.fps
196
+ timings = {}
197
+ total_start = time.time()
198
+
199
+ # Load avatar (cached in memory)
200
+ t0 = time.time()
201
+ avatar = self.load_avatar(avatar_name)
202
+ timings["avatar_load"] = time.time() - t0
203
+
204
+ coord_list = avatar['coord_list']
205
+ frame_list = avatar['frame_list']
206
+ input_latent_list = avatar['latent_list']
207
+
208
+ temp_dir = tempfile.mkdtemp()
209
+
210
+ try:
211
+ # 1. Extract audio features (only audio-dependent step that's heavy)
212
+ t0 = time.time()
213
+ whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
214
+ whisper_chunks = self.audio_processor.get_whisper_chunk(
215
+ whisper_input_features,
216
+ self.device,
217
+ self.weight_dtype,
218
+ self.whisper,
219
+ librosa_length,
220
+ fps=fps,
221
+ audio_padding_length_left=self.audio_padding_left,
222
+ audio_padding_length_right=self.audio_padding_right,
223
+ )
224
+ timings["whisper_features"] = time.time() - t0
225
+
226
+ # 2. Prepare cycled lists
227
+ frame_list_cycle = frame_list + frame_list[::-1]
228
+ coord_list_cycle = coord_list + coord_list[::-1]
229
+ input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
230
+
231
+ # 3. UNet inference
232
+ t0 = time.time()
233
+ gen = datagen(
234
+ whisper_chunks=whisper_chunks,
235
+ vae_encode_latents=input_latent_list_cycle,
236
+ batch_size=self.batch_size,
237
+ delay_frame=0,
238
+ device=self.device,
239
+ )
240
+
241
+ res_frame_list = []
242
+ for whisper_batch, latent_batch in gen:
243
+ audio_feature_batch = self.pe(whisper_batch)
244
+ latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
245
+ pred_latents = self.unet.model(
246
+ latent_batch, self.timesteps,
247
+ encoder_hidden_states=audio_feature_batch
248
+ ).sample
249
+ recon = self.vae.decode_latents(pred_latents)
250
+ for res_frame in recon:
251
+ res_frame_list.append(res_frame)
252
+
253
+ timings["unet_inference"] = time.time() - t0
254
+
255
+ # 4. Face blending
256
+ t0 = time.time()
257
+ result_img_path = os.path.join(temp_dir, "results")
258
+ os.makedirs(result_img_path, exist_ok=True)
259
+
260
+ for i, res_frame in enumerate(res_frame_list):
261
+ bbox = coord_list_cycle[i % len(coord_list_cycle)]
262
+ ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
263
+ x1, y1, x2, y2 = bbox
264
+
265
+ if self.version == "v15":
266
+ y2 = y2 + self.extra_margin
267
+ y2 = min(y2, ori_frame.shape[0])
268
+
269
+ try:
270
+ res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
271
+ except:
272
+ continue
273
+
274
+ if self.version == "v15":
275
+ combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
276
+ mode=self.parsing_mode, fp=self.fp)
277
+ else:
278
+ combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=self.fp)
279
+
280
+ cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)
281
+
282
+ timings["face_blending"] = time.time() - t0
283
+
284
+ # 5. Encode video
285
+ t0 = time.time()
286
+ temp_vid = os.path.join(temp_dir, "temp.mp4")
287
+ cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_path}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid}"
288
+ os.system(cmd_img2video)
289
+
290
+ cmd_combine = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid} {output_path}"
291
+ os.system(cmd_combine)
292
+
293
+ timings["video_encoding"] = time.time() - t0
294
+
295
+ finally:
296
+ shutil.rmtree(temp_dir, ignore_errors=True)
297
+
298
+ timings["total"] = time.time() - total_start
299
+ timings["frames_generated"] = len(res_frame_list)
300
+
301
+ return timings
302
+
303
+
304
+ # Global server instance
305
+ server = MuseTalkServerV2()
306
+
307
+ # FastAPI app
308
+ app = FastAPI(
309
+ title="MuseTalk API v2",
310
+ description="Optimized API for repeated avatar usage",
311
+ version="2.0.0"
312
+ )
313
+
314
+ app.add_middleware(
315
+ CORSMiddleware,
316
+ allow_origins=["*"],
317
+ allow_credentials=True,
318
+ allow_methods=["*"],
319
+ allow_headers=["*"],
320
+ )
321
+
322
+
323
+ @app.on_event("startup")
324
+ async def startup_event():
325
+ server.load_models()
326
+
327
+
328
+ @app.get("/health")
329
+ async def health_check():
330
+ return {
331
+ "status": "ok" if server.is_loaded else "loading",
332
+ "models_loaded": server.is_loaded,
333
+ "device": str(server.device) if server.device else None,
334
+ "loaded_avatars": list(server.loaded_avatars.keys())
335
+ }
336
+
337
+
338
+ @app.get("/avatars")
339
+ async def list_avatars():
340
+ """List all available preprocessed avatars."""
341
+ avatars = []
342
+ for p in server.avatar_dir.iterdir():
343
+ if p.is_dir() and (p / "metadata.pkl").exists():
344
+ with open(p / "metadata.pkl", 'rb') as f:
345
+ metadata = pickle.load(f)
346
+ metadata['loaded'] = p.name in server.loaded_avatars
347
+ avatars.append(metadata)
348
+ return {"avatars": avatars}
349
+
350
+
351
+ @app.post("/avatars/{avatar_name}/load")
352
+ async def load_avatar(avatar_name: str):
353
+ """Pre-load an avatar into GPU memory."""
354
+ try:
355
+ server.load_avatar(avatar_name)
356
+ return {"status": "loaded", "avatar_name": avatar_name}
357
+ except FileNotFoundError as e:
358
+ raise HTTPException(status_code=404, detail=str(e))
359
+
360
+
361
+ @app.post("/avatars/{avatar_name}/unload")
362
+ async def unload_avatar(avatar_name: str):
363
+ """Unload an avatar from memory."""
364
+ server.unload_avatar(avatar_name)
365
+ return {"status": "unloaded", "avatar_name": avatar_name}
366
+
367
+
368
+ class GenerateWithAvatarRequest(BaseModel):
369
+ avatar_name: str
370
+ audio_path: str
371
+ output_path: str
372
+ fps: Optional[int] = 25
373
+
374
+
375
+ @app.post("/generate/avatar")
376
+ async def generate_with_avatar(request: GenerateWithAvatarRequest):
377
+ """Generate video using pre-processed avatar. FAST!"""
378
+ if not server.is_loaded:
379
+ raise HTTPException(status_code=503, detail="Models not loaded")
380
+
381
+ if not os.path.exists(request.audio_path):
382
+ raise HTTPException(status_code=404, detail=f"Audio not found: {request.audio_path}")
383
+
384
+ try:
385
+ timings = server.generate_with_avatar(
386
+ avatar_name=request.avatar_name,
387
+ audio_path=request.audio_path,
388
+ output_path=request.output_path,
389
+ fps=request.fps
390
+ )
391
+ return {
392
+ "status": "success",
393
+ "output_path": request.output_path,
394
+ "timings": timings
395
+ }
396
+ except FileNotFoundError as e:
397
+ raise HTTPException(status_code=404, detail=str(e))
398
+ except Exception as e:
399
+ raise HTTPException(status_code=500, detail=str(e))
400
+
401
+
402
+ @app.post("/generate/avatar/upload")
403
+ async def generate_with_avatar_upload(
404
+ avatar_name: str = Form(...),
405
+ audio: UploadFile = File(...),
406
+ fps: int = Form(25)
407
+ ):
408
+ """Generate video from uploaded audio using pre-processed avatar."""
409
+ if not server.is_loaded:
410
+ raise HTTPException(status_code=503, detail="Models not loaded")
411
+
412
+ temp_dir = tempfile.mkdtemp()
413
+ try:
414
+ audio_path = os.path.join(temp_dir, audio.filename)
415
+ output_path = os.path.join(temp_dir, "output.mp4")
416
+
417
+ with open(audio_path, "wb") as f:
418
+ f.write(await audio.read())
419
+
420
+ timings = server.generate_with_avatar(
421
+ avatar_name=avatar_name,
422
+ audio_path=audio_path,
423
+ output_path=output_path,
424
+ fps=fps
425
+ )
426
+
427
+ return FileResponse(
428
+ output_path,
429
+ media_type="video/mp4",
430
+ filename="result.mp4",
431
+ headers={"X-Timings": str(timings)}
432
+ )
433
+ except Exception as e:
434
+ shutil.rmtree(temp_dir, ignore_errors=True)
435
+ raise HTTPException(status_code=500, detail=str(e))
436
+
437
+
438
+ if __name__ == "__main__":
439
+ import argparse
440
+ parser = argparse.ArgumentParser()
441
+ parser.add_argument("--host", type=str, default="0.0.0.0")
442
+ parser.add_argument("--port", type=int, default=8000)
443
+ args = parser.parse_args()
444
+
445
+ uvicorn.run(app, host=args.host, port=args.port)
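A similar sketch for the v2 preprocessed-avatar flow: it assumes a server on localhost:8000 and an existing ./avatars/mariana_hd folder (the avatar name is only an example, borrowed from avatar_pipeline.py).

import requests

BASE_URL = "http://localhost:8000"

# Optionally pre-load the avatar; /generate/avatar would also load it lazily.
requests.post(f"{BASE_URL}/avatars/mariana_hd/load").raise_for_status()

# Generate using audio already present on the server.
resp = requests.post(
    f"{BASE_URL}/generate/avatar",
    json={
        "avatar_name": "mariana_hd",
        "audio_path": "data/audio/professor_pt.wav",
        "output_path": "results/mariana_reply.mp4",
        "fps": 25,
    },
    timeout=600,
)
resp.raise_for_status()
print(resp.json()["timings"])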
musetalk_api_server_v3.py ADDED
@@ -0,0 +1,651 @@
1
+ """
2
+ MuseTalk HTTP API Server v3
3
+ Ultra-optimized with:
4
+ 1. Parallel face blending (thread-pool workers)
5
+ 2. NVENC hardware video encoding
6
+ 3. Batch audio processing
7
+ """
8
+ import os
9
+ import cv2
10
+ import copy
11
+ import torch
12
+ import glob
13
+ import shutil
14
+ import pickle
15
+ import numpy as np
16
+ import subprocess
17
+ import tempfile
18
+ import hashlib
19
+ import time
20
+ import asyncio
21
+ from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
22
+ from pathlib import Path
23
+ from typing import Optional, List
24
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks
25
+ from fastapi.responses import FileResponse, JSONResponse
26
+ from fastapi.middleware.cors import CORSMiddleware
27
+ from pydantic import BaseModel
28
+ from tqdm import tqdm
29
+ from omegaconf import OmegaConf
30
+ from transformers import WhisperModel
31
+ import uvicorn
32
+ import multiprocessing as mp
33
+
34
+ # MuseTalk imports
35
+ from musetalk.utils.blending import get_image
36
+ from musetalk.utils.face_parsing import FaceParsing
37
+ from musetalk.utils.audio_processor import AudioProcessor
38
+ from musetalk.utils.utils import get_file_type, datagen, load_all_model
39
+ from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
40
+
41
+
42
+ def blend_single_frame(args):
43
+ """Worker function for parallel face blending."""
44
+ i, res_frame, bbox, ori_frame, extra_margin, version, parsing_mode, fp_config = args
45
+
46
+ x1, y1, x2, y2 = bbox
47
+ if version == "v15":
48
+ y2 = y2 + extra_margin
49
+ y2 = min(y2, ori_frame.shape[0])
50
+
51
+ try:
52
+ res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
53
+ except:
54
+ return i, None
55
+
56
+ # Create FaceParsing instance for this worker
57
+ fp = FaceParsing(
58
+ left_cheek_width=fp_config['left_cheek_width'],
59
+ right_cheek_width=fp_config['right_cheek_width']
60
+ )
61
+
62
+ if version == "v15":
63
+ combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
64
+ mode=parsing_mode, fp=fp)
65
+ else:
66
+ combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp)
67
+
68
+ return i, combine_frame
69
+
70
+
71
+ class MuseTalkServerV3:
72
+ """Ultra-optimized server."""
73
+
74
+ def __init__(self):
75
+ self.device = None
76
+ self.vae = None
77
+ self.unet = None
78
+ self.pe = None
79
+ self.whisper = None
80
+ self.audio_processor = None
81
+ self.fp = None
82
+ self.timesteps = None
83
+ self.weight_dtype = None
84
+ self.is_loaded = False
85
+
86
+ # Avatar cache
87
+ self.loaded_avatars = {}
88
+ self.avatar_dir = Path("./avatars")
89
+
90
+ # Config
91
+ self.fps = 25
92
+ self.batch_size = 8
93
+ self.use_float16 = True
94
+ self.version = "v15"
95
+ self.extra_margin = 10
96
+ self.parsing_mode = "jaw"
97
+ self.left_cheek_width = 90
98
+ self.right_cheek_width = 90
99
+ self.audio_padding_left = 2
100
+ self.audio_padding_right = 2
101
+
102
+ # Thread pool for parallel blending
103
+ self.num_workers = min(8, mp.cpu_count())
104
+ self.thread_pool = ThreadPoolExecutor(max_workers=self.num_workers)
105
+
106
+ # NVENC settings
107
+ self.use_nvenc = True
108
+ self.nvenc_preset = "p4" # p1(fastest) to p7(best quality)
109
+ self.crf = 23
110
+
111
+ def load_models(
112
+ self,
113
+ gpu_id: int = 0,
114
+ unet_model_path: str = "./models/musetalkV15/unet.pth",
115
+ unet_config: str = "./models/musetalk/config.json",
116
+ vae_type: str = "sd-vae",
117
+ whisper_dir: str = "./models/whisper",
118
+ use_float16: bool = True,
119
+ version: str = "v15"
120
+ ):
121
+ if self.is_loaded:
122
+ print("Models already loaded!")
123
+ return
124
+
125
+ print("=" * 50)
126
+ print("Loading MuseTalk models (v3 Ultra-Optimized)...")
127
+ print("=" * 50)
128
+
129
+ start_time = time.time()
130
+ self.device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
131
+ print(f"Using device: {self.device}")
132
+ print(f"Parallel workers: {self.num_workers}")
133
+ print(f"NVENC encoding: {self.use_nvenc}")
134
+
135
+ print("Loading VAE, UNet, PE...")
136
+ self.vae, self.unet, self.pe = load_all_model(
137
+ unet_model_path=unet_model_path,
138
+ vae_type=vae_type,
139
+ unet_config=unet_config,
140
+ device=self.device
141
+ )
142
+ self.timesteps = torch.tensor([0], device=self.device)
143
+
144
+ self.use_float16 = use_float16
145
+ if use_float16:
146
+ print("Converting to float16...")
147
+ self.pe = self.pe.half()
148
+ self.vae.vae = self.vae.vae.half()
149
+ self.unet.model = self.unet.model.half()
150
+
151
+ self.pe = self.pe.to(self.device)
152
+ self.vae.vae = self.vae.vae.to(self.device)
153
+ self.unet.model = self.unet.model.to(self.device)
154
+
155
+ print("Loading Whisper model...")
156
+ self.audio_processor = AudioProcessor(feature_extractor_path=whisper_dir)
157
+ self.weight_dtype = self.unet.model.dtype
158
+ self.whisper = WhisperModel.from_pretrained(whisper_dir)
159
+ self.whisper = self.whisper.to(device=self.device, dtype=self.weight_dtype).eval()
160
+ self.whisper.requires_grad_(False)
161
+
162
+ self.version = version
163
+ if version == "v15":
164
+ self.fp = FaceParsing(
165
+ left_cheek_width=self.left_cheek_width,
166
+ right_cheek_width=self.right_cheek_width
167
+ )
168
+ else:
169
+ self.fp = FaceParsing()
170
+
171
+ self.is_loaded = True
172
+ print(f"Models loaded in {time.time() - start_time:.2f}s")
173
+ print("=" * 50)
174
+
175
+ def load_avatar(self, avatar_name: str) -> dict:
176
+ if avatar_name in self.loaded_avatars:
177
+ return self.loaded_avatars[avatar_name]
178
+
179
+ avatar_path = self.avatar_dir / avatar_name
180
+ if not avatar_path.exists():
181
+ raise FileNotFoundError(f"Avatar not found: {avatar_name}")
182
+
183
+ print(f"Loading avatar '{avatar_name}' into memory...")
184
+ t0 = time.time()
185
+
186
+ avatar_data = {}
187
+
188
+ with open(avatar_path / "metadata.pkl", 'rb') as f:
189
+ avatar_data['metadata'] = pickle.load(f)
190
+
191
+ with open(avatar_path / "coords.pkl", 'rb') as f:
192
+ avatar_data['coord_list'] = pickle.load(f)
193
+
194
+ with open(avatar_path / "frames.pkl", 'rb') as f:
195
+ avatar_data['frame_list'] = pickle.load(f)
196
+
197
+ with open(avatar_path / "latents.pkl", 'rb') as f:
198
+ latents_np = pickle.load(f)
199
+ avatar_data['latent_list'] = [
200
+ torch.from_numpy(l).to(self.device) for l in latents_np
201
+ ]
202
+
203
+ with open(avatar_path / "crop_info.pkl", 'rb') as f:
204
+ avatar_data['crop_info'] = pickle.load(f)
205
+
206
+ self.loaded_avatars[avatar_name] = avatar_data
207
+ print(f"Avatar loaded in {time.time() - t0:.2f}s")
208
+
209
+ return avatar_data
210
+
211
+ def unload_avatar(self, avatar_name: str):
212
+ if avatar_name in self.loaded_avatars:
213
+ del self.loaded_avatars[avatar_name]
214
+ torch.cuda.empty_cache()
215
+
216
+ def _encode_video_nvenc(self, frames_dir: str, audio_path: str, output_path: str, fps: int) -> float:
217
+ """Encode video using NVENC hardware acceleration."""
218
+ t0 = time.time()
219
+ temp_vid = frames_dir.replace('/results', '/temp.mp4')
220
+
221
+ if self.use_nvenc:
222
+ # NVENC H.264 encoding (much faster)
223
+ cmd_img2video = (
224
+ f"ffmpeg -y -v warning -r {fps} -f image2 -i {frames_dir}/%08d.png "
225
+ f"-c:v h264_nvenc -preset {self.nvenc_preset} -cq {self.crf} "
226
+ f"-pix_fmt yuv420p {temp_vid}"
227
+ )
228
+ else:
229
+ # Fallback to CPU encoding
230
+ cmd_img2video = (
231
+ f"ffmpeg -y -v warning -r {fps} -f image2 -i {frames_dir}/%08d.png "
232
+ f"-vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid}"
233
+ )
234
+
235
+ os.system(cmd_img2video)
236
+
237
+ # Add audio
238
+ cmd_combine = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid} -c:v copy -c:a aac {output_path}"
239
+ os.system(cmd_combine)
240
+
241
+ # Cleanup temp video
242
+ if os.path.exists(temp_vid):
243
+ os.remove(temp_vid)
244
+
245
+ return time.time() - t0
246
+
247
+ def _parallel_face_blending(self, res_frame_list, coord_list_cycle, frame_list_cycle, result_img_path) -> float:
248
+ """Parallel face blending using thread pool."""
249
+ t0 = time.time()
250
+
251
+ fp_config = {
252
+ 'left_cheek_width': self.left_cheek_width,
253
+ 'right_cheek_width': self.right_cheek_width
254
+ }
255
+
256
+ # Prepare all tasks
257
+ tasks = []
258
+ for i, res_frame in enumerate(res_frame_list):
259
+ bbox = coord_list_cycle[i % len(coord_list_cycle)]
260
+ ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
261
+ tasks.append((
262
+ i, res_frame, bbox, ori_frame,
263
+ self.extra_margin, self.version, self.parsing_mode, fp_config
264
+ ))
265
+
266
+ # Process in parallel
267
+ results = list(self.thread_pool.map(blend_single_frame, tasks))
268
+
269
+ # Sort and save results
270
+ results.sort(key=lambda x: x[0])
271
+ for i, combine_frame in results:
272
+ if combine_frame is not None:
273
+ cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)
274
+
275
+ return time.time() - t0
276
+
277
+ @torch.no_grad()
278
+ def generate_with_avatar(
279
+ self,
280
+ avatar_name: str,
281
+ audio_path: str,
282
+ output_path: str,
283
+ fps: Optional[int] = None,
284
+ use_parallel_blending: bool = True
285
+ ) -> dict:
286
+ """Generate video using pre-processed avatar with all optimizations."""
287
+ if not self.is_loaded:
288
+ raise RuntimeError("Models not loaded!")
289
+
290
+ fps = fps or self.fps
291
+ timings = {}
292
+ total_start = time.time()
293
+
294
+ # Load avatar
295
+ t0 = time.time()
296
+ avatar = self.load_avatar(avatar_name)
297
+ timings["avatar_load"] = time.time() - t0
298
+
299
+ coord_list = avatar['coord_list']
300
+ frame_list = avatar['frame_list']
301
+ input_latent_list = avatar['latent_list']
302
+
303
+ temp_dir = tempfile.mkdtemp()
304
+
305
+ try:
306
+ # 1. Extract audio features
307
+ t0 = time.time()
308
+ whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
309
+ whisper_chunks = self.audio_processor.get_whisper_chunk(
310
+ whisper_input_features,
311
+ self.device,
312
+ self.weight_dtype,
313
+ self.whisper,
314
+ librosa_length,
315
+ fps=fps,
316
+ audio_padding_length_left=self.audio_padding_left,
317
+ audio_padding_length_right=self.audio_padding_right,
318
+ )
319
+ timings["whisper_features"] = time.time() - t0
320
+
321
+ # 2. Prepare cycled lists
322
+ frame_list_cycle = frame_list + frame_list[::-1]
323
+ coord_list_cycle = coord_list + coord_list[::-1]
324
+ input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
325
+
326
+ # 3. UNet inference
327
+ t0 = time.time()
328
+ gen = datagen(
329
+ whisper_chunks=whisper_chunks,
330
+ vae_encode_latents=input_latent_list_cycle,
331
+ batch_size=self.batch_size,
332
+ delay_frame=0,
333
+ device=self.device,
334
+ )
335
+
336
+ res_frame_list = []
337
+ for whisper_batch, latent_batch in gen:
338
+ audio_feature_batch = self.pe(whisper_batch)
339
+ latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
340
+ pred_latents = self.unet.model(
341
+ latent_batch, self.timesteps,
342
+ encoder_hidden_states=audio_feature_batch
343
+ ).sample
344
+ recon = self.vae.decode_latents(pred_latents)
345
+ for res_frame in recon:
346
+ res_frame_list.append(res_frame)
347
+
348
+ timings["unet_inference"] = time.time() - t0
349
+
350
+ # 4. Face blending (parallel or sequential)
351
+ result_img_path = os.path.join(temp_dir, "results")
352
+ os.makedirs(result_img_path, exist_ok=True)
353
+
354
+ if use_parallel_blending:
355
+ timings["face_blending"] = self._parallel_face_blending(
356
+ res_frame_list, coord_list_cycle, frame_list_cycle, result_img_path
357
+ )
358
+ timings["blending_mode"] = "parallel"
359
+ else:
360
+ t0 = time.time()
361
+ for i, res_frame in enumerate(res_frame_list):
362
+ bbox = coord_list_cycle[i % len(coord_list_cycle)]
363
+ ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
364
+ x1, y1, x2, y2 = bbox
365
+
366
+ if self.version == "v15":
367
+ y2 = y2 + self.extra_margin
368
+ y2 = min(y2, ori_frame.shape[0])
369
+
370
+ try:
371
+ res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
372
+ except Exception:
373
+ continue
374
+
375
+ if self.version == "v15":
376
+ combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
377
+ mode=self.parsing_mode, fp=self.fp)
378
+ else:
379
+ combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=self.fp)
380
+
381
+ cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)
382
+ timings["face_blending"] = time.time() - t0
383
+ timings["blending_mode"] = "sequential"
384
+
385
+ # 5. Video encoding (NVENC)
386
+ timings["video_encoding"] = self._encode_video_nvenc(
387
+ result_img_path, audio_path, output_path, fps
388
+ )
389
+ timings["encoding_mode"] = "nvenc" if self.use_nvenc else "cpu"
390
+
391
+ finally:
392
+ shutil.rmtree(temp_dir, ignore_errors=True)
393
+
394
+ timings["total"] = time.time() - total_start
395
+ timings["frames_generated"] = len(res_frame_list)
396
+
397
+ return timings
398
+
399
+ @torch.no_grad()
400
+ def generate_batch(
401
+ self,
402
+ avatar_name: str,
403
+ audio_paths: List[str],
404
+ output_dir: str,
405
+ fps: Optional[int] = None
406
+ ) -> dict:
407
+ """Generate multiple videos from multiple audios efficiently."""
408
+ if not self.is_loaded:
409
+ raise RuntimeError("Models not loaded!")
410
+
411
+ fps = fps or self.fps
412
+ batch_timings = {"videos": [], "total": 0}
413
+ total_start = time.time()
414
+
415
+ # Load avatar once
416
+ t0 = time.time()
417
+ avatar = self.load_avatar(avatar_name)
418
+ batch_timings["avatar_load"] = time.time() - t0
419
+
420
+ coord_list = avatar['coord_list']
421
+ frame_list = avatar['frame_list']
422
+ input_latent_list = avatar['latent_list']
423
+
424
+ # Prepare cycled lists once
425
+ frame_list_cycle = frame_list + frame_list[::-1]
426
+ coord_list_cycle = coord_list + coord_list[::-1]
427
+ input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
428
+
429
+ os.makedirs(output_dir, exist_ok=True)
430
+
431
+ for idx, audio_path in enumerate(audio_paths):
432
+ video_start = time.time()
433
+ timings = {}
434
+
435
+ audio_name = Path(audio_path).stem
436
+ output_path = os.path.join(output_dir, f"{audio_name}.mp4")
437
+
438
+ temp_dir = tempfile.mkdtemp()
439
+
440
+ try:
441
+ # 1. Extract audio features
442
+ t0 = time.time()
443
+ whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
444
+ whisper_chunks = self.audio_processor.get_whisper_chunk(
445
+ whisper_input_features,
446
+ self.device,
447
+ self.weight_dtype,
448
+ self.whisper,
449
+ librosa_length,
450
+ fps=fps,
451
+ audio_padding_length_left=self.audio_padding_left,
452
+ audio_padding_length_right=self.audio_padding_right,
453
+ )
454
+ timings["whisper_features"] = time.time() - t0
455
+
456
+ # 2. UNet inference
457
+ t0 = time.time()
458
+ gen = datagen(
459
+ whisper_chunks=whisper_chunks,
460
+ vae_encode_latents=input_latent_list_cycle,
461
+ batch_size=self.batch_size,
462
+ delay_frame=0,
463
+ device=self.device,
464
+ )
465
+
466
+ res_frame_list = []
467
+ for whisper_batch, latent_batch in gen:
468
+ audio_feature_batch = self.pe(whisper_batch)
469
+ latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
470
+ pred_latents = self.unet.model(
471
+ latent_batch, self.timesteps,
472
+ encoder_hidden_states=audio_feature_batch
473
+ ).sample
474
+ recon = self.vae.decode_latents(pred_latents)
475
+ for res_frame in recon:
476
+ res_frame_list.append(res_frame)
477
+
478
+ timings["unet_inference"] = time.time() - t0
479
+
480
+ # 3. Face blending (parallel)
481
+ result_img_path = os.path.join(temp_dir, "results")
482
+ os.makedirs(result_img_path, exist_ok=True)
483
+ timings["face_blending"] = self._parallel_face_blending(
484
+ res_frame_list, coord_list_cycle, frame_list_cycle, result_img_path
485
+ )
486
+
487
+ # 4. Video encoding (NVENC)
488
+ timings["video_encoding"] = self._encode_video_nvenc(
489
+ result_img_path, audio_path, output_path, fps
490
+ )
491
+
492
+ finally:
493
+ shutil.rmtree(temp_dir, ignore_errors=True)
494
+
495
+ timings["total"] = time.time() - video_start
496
+ timings["frames_generated"] = len(res_frame_list)
497
+ timings["output_path"] = output_path
498
+ timings["audio_path"] = audio_path
499
+
500
+ batch_timings["videos"].append(timings)
501
+ print(f" [{idx+1}/{len(audio_paths)}] {audio_name}: {timings['total']:.2f}s")
502
+
503
+ batch_timings["total"] = time.time() - total_start
504
+ batch_timings["num_videos"] = len(audio_paths)
505
+ batch_timings["avg_per_video"] = batch_timings["total"] / len(audio_paths) if audio_paths else 0
506
+
507
+ return batch_timings
508
+
509
+
510
+ # Global server
511
+ server = MuseTalkServerV3()
512
+
513
+ # FastAPI app
514
+ app = FastAPI(
515
+ title="MuseTalk API v3",
516
+ description="Ultra-optimized API with parallel blending, NVENC, and batch processing",
517
+ version="3.0.0"
518
+ )
519
+
520
+ app.add_middleware(
521
+ CORSMiddleware,
522
+ allow_origins=["*"],
523
+ allow_credentials=True,
524
+ allow_methods=["*"],
525
+ allow_headers=["*"],
526
+ )
527
+
528
+
529
+ @app.on_event("startup")
530
+ async def startup_event():
531
+ server.load_models()
532
+
533
+
534
+ @app.get("/health")
535
+ async def health_check():
536
+ return {
537
+ "status": "ok" if server.is_loaded else "loading",
538
+ "models_loaded": server.is_loaded,
539
+ "device": str(server.device) if server.device else None,
540
+ "loaded_avatars": list(server.loaded_avatars.keys()),
541
+ "optimizations": {
542
+ "parallel_workers": server.num_workers,
543
+ "nvenc_enabled": server.use_nvenc,
544
+ "nvenc_preset": server.nvenc_preset
545
+ }
546
+ }
547
+
548
+
549
+ @app.get("/avatars")
550
+ async def list_avatars():
551
+ avatars = []
552
+ for p in server.avatar_dir.iterdir():
553
+ if p.is_dir() and (p / "metadata.pkl").exists():
554
+ with open(p / "metadata.pkl", 'rb') as f:
555
+ metadata = pickle.load(f)
556
+ metadata['loaded'] = p.name in server.loaded_avatars
557
+ avatars.append(metadata)
558
+ return {"avatars": avatars}
559
+
560
+
561
+ @app.post("/avatars/{avatar_name}/load")
562
+ async def load_avatar(avatar_name: str):
563
+ try:
564
+ server.load_avatar(avatar_name)
565
+ return {"status": "loaded", "avatar_name": avatar_name}
566
+ except FileNotFoundError as e:
567
+ raise HTTPException(status_code=404, detail=str(e))
568
+
569
+
570
+ @app.post("/avatars/{avatar_name}/unload")
571
+ async def unload_avatar(avatar_name: str):
572
+ server.unload_avatar(avatar_name)
573
+ return {"status": "unloaded", "avatar_name": avatar_name}
574
+
575
+
576
+ class GenerateRequest(BaseModel):
577
+ avatar_name: str
578
+ audio_path: str
579
+ output_path: str
580
+ fps: Optional[int] = 25
581
+ use_parallel_blending: bool = True
582
+
583
+
584
+ @app.post("/generate/avatar")
585
+ async def generate_with_avatar(request: GenerateRequest):
586
+ if not server.is_loaded:
587
+ raise HTTPException(status_code=503, detail="Models not loaded")
588
+
589
+ if not os.path.exists(request.audio_path):
590
+ raise HTTPException(status_code=404, detail=f"Audio not found: {request.audio_path}")
591
+
592
+ try:
593
+ timings = server.generate_with_avatar(
594
+ avatar_name=request.avatar_name,
595
+ audio_path=request.audio_path,
596
+ output_path=request.output_path,
597
+ fps=request.fps,
598
+ use_parallel_blending=request.use_parallel_blending
599
+ )
600
+ return {
601
+ "status": "success",
602
+ "output_path": request.output_path,
603
+ "timings": timings
604
+ }
605
+ except FileNotFoundError as e:
606
+ raise HTTPException(status_code=404, detail=str(e))
607
+ except Exception as e:
608
+ raise HTTPException(status_code=500, detail=str(e))
609
+
610
+
611
+ class BatchGenerateRequest(BaseModel):
612
+ avatar_name: str
613
+ audio_paths: List[str]
614
+ output_dir: str
615
+ fps: Optional[int] = 25
616
+
617
+
618
+ @app.post("/generate/batch")
619
+ async def generate_batch(request: BatchGenerateRequest):
620
+ """Generate multiple videos from multiple audios."""
621
+ if not server.is_loaded:
622
+ raise HTTPException(status_code=503, detail="Models not loaded")
623
+
624
+ for audio_path in request.audio_paths:
625
+ if not os.path.exists(audio_path):
626
+ raise HTTPException(status_code=404, detail=f"Audio not found: {audio_path}")
627
+
628
+ try:
629
+ timings = server.generate_batch(
630
+ avatar_name=request.avatar_name,
631
+ audio_paths=request.audio_paths,
632
+ output_dir=request.output_dir,
633
+ fps=request.fps
634
+ )
635
+ return {
636
+ "status": "success",
637
+ "output_dir": request.output_dir,
638
+ "timings": timings
639
+ }
640
+ except Exception as e:
641
+ raise HTTPException(status_code=500, detail=str(e))
642
+
643
+
644
+ if __name__ == "__main__":
645
+ import argparse
646
+ parser = argparse.ArgumentParser()
647
+ parser.add_argument("--host", type=str, default="0.0.0.0")
648
+ parser.add_argument("--port", type=int, default=8000)
649
+ args = parser.parse_args()
650
+
651
+ uvicorn.run(app, host=args.host, port=args.port)
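For reference, a minimal client sketch for the two generation endpoints defined above. The base URL, avatar name, and file paths are assumptions for illustration; they must exist on the machine running the server, since the API works with server-side paths rather than uploads.

```python
# Hypothetical client for the v3 API above; URL, avatar name and paths are placeholders.
import requests

BASE_URL = "http://localhost:8000"  # assumed host/port (matches the argparse defaults above)

# Single video from a pre-processed avatar
resp = requests.post(f"{BASE_URL}/generate/avatar", json={
    "avatar_name": "my_avatar",               # assumed avatar id under ./avatars
    "audio_path": "data/audio/question.wav",  # server-side path
    "output_path": "results/answer.mp4",
    "fps": 25,
    "use_parallel_blending": True,
})
resp.raise_for_status()
print(resp.json()["timings"])

# Several audios against the same avatar in one call
resp = requests.post(f"{BASE_URL}/generate/batch", json={
    "avatar_name": "my_avatar",
    "audio_paths": ["data/audio/a1.wav", "data/audio/a2.wav"],
    "output_dir": "results/batch",
    "fps": 25,
})
print(resp.json()["timings"]["avg_per_video"])
```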
musetalk_api_server_v3_fixed.py ADDED
@@ -0,0 +1,371 @@
1
+ """
2
+ MuseTalk HTTP API Server v3 (Fixed)
3
+ Optimized with:
4
+ 1. Sequential face blending (the parallel version added more overhead than it saved)
5
+ 2. NVENC hardware video encoding
6
+ 3. Batch audio processing
7
+ """
8
+ import os
9
+ import cv2
10
+ import copy
11
+ import torch
12
+ import glob
13
+ import shutil
14
+ import pickle
15
+ import numpy as np
16
+ import subprocess
17
+ import tempfile
18
+ import hashlib
19
+ import time
20
+ from pathlib import Path
21
+ from typing import Optional, List
22
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
23
+ from fastapi.responses import FileResponse, JSONResponse
24
+ from fastapi.middleware.cors import CORSMiddleware
25
+ from pydantic import BaseModel
26
+ from tqdm import tqdm
27
+ from transformers import WhisperModel
28
+ import uvicorn
29
+
30
+ # MuseTalk imports
31
+ from musetalk.utils.blending import get_image
32
+ from musetalk.utils.face_parsing import FaceParsing
33
+ from musetalk.utils.audio_processor import AudioProcessor
34
+ from musetalk.utils.utils import get_file_type, datagen, load_all_model
35
+ from musetalk.utils.preprocessing import coord_placeholder
36
+
37
+
38
+ class MuseTalkServerV3:
39
+ def __init__(self):
40
+ self.device = None
41
+ self.vae = None
42
+ self.unet = None
43
+ self.pe = None
44
+ self.whisper = None
45
+ self.audio_processor = None
46
+ self.fp = None
47
+ self.timesteps = None
48
+ self.weight_dtype = None
49
+ self.is_loaded = False
50
+
51
+ self.loaded_avatars = {}
52
+ self.avatar_dir = Path("./avatars")
53
+
54
+ self.fps = 25
55
+ self.batch_size = 8
56
+ self.use_float16 = True
57
+ self.version = "v15"
58
+ self.extra_margin = 10
59
+ self.parsing_mode = "jaw"
60
+ self.left_cheek_width = 90
61
+ self.right_cheek_width = 90
62
+ self.audio_padding_left = 2
63
+ self.audio_padding_right = 2
64
+
65
+ # NVENC
66
+ self.use_nvenc = True
67
+ self.nvenc_preset = "p4"
68
+ self.crf = 23
69
+
70
+ def load_models(self, gpu_id: int = 0):
71
+ if self.is_loaded:
72
+ print("Models already loaded!")
73
+ return
74
+
75
+ print("=" * 50)
76
+ print("Loading MuseTalk models (v3 Optimized)...")
77
+ print("=" * 50)
78
+
79
+ start_time = time.time()
80
+ self.device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
81
+
82
+ self.vae, self.unet, self.pe = load_all_model(
83
+ unet_model_path="./models/musetalkV15/unet.pth",
84
+ vae_type="sd-vae",
85
+ unet_config="./models/musetalk/config.json",
86
+ device=self.device
87
+ )
88
+ self.timesteps = torch.tensor([0], device=self.device)
89
+
90
+ self.pe = self.pe.half().to(self.device)
91
+ self.vae.vae = self.vae.vae.half().to(self.device)
92
+ self.unet.model = self.unet.model.half().to(self.device)
93
+
94
+ self.audio_processor = AudioProcessor(feature_extractor_path="./models/whisper")
95
+ self.weight_dtype = self.unet.model.dtype
96
+ self.whisper = WhisperModel.from_pretrained("./models/whisper")
97
+ self.whisper = self.whisper.to(device=self.device, dtype=self.weight_dtype).eval()
98
+ self.whisper.requires_grad_(False)
99
+
100
+ self.fp = FaceParsing(
101
+ left_cheek_width=self.left_cheek_width,
102
+ right_cheek_width=self.right_cheek_width
103
+ )
104
+
105
+ self.is_loaded = True
106
+ print(f"Models loaded in {time.time() - start_time:.2f}s")
107
+
108
+ def load_avatar(self, avatar_name: str) -> dict:
109
+ if avatar_name in self.loaded_avatars:
110
+ return self.loaded_avatars[avatar_name]
111
+
112
+ avatar_path = self.avatar_dir / avatar_name
113
+ if not avatar_path.exists():
114
+ raise FileNotFoundError(f"Avatar not found: {avatar_name}")
115
+
116
+ avatar_data = {}
117
+ with open(avatar_path / "metadata.pkl", 'rb') as f:
118
+ avatar_data['metadata'] = pickle.load(f)
119
+ with open(avatar_path / "coords.pkl", 'rb') as f:
120
+ avatar_data['coord_list'] = pickle.load(f)
121
+ with open(avatar_path / "frames.pkl", 'rb') as f:
122
+ avatar_data['frame_list'] = pickle.load(f)
123
+ with open(avatar_path / "latents.pkl", 'rb') as f:
124
+ latents_np = pickle.load(f)
125
+ avatar_data['latent_list'] = [torch.from_numpy(l).to(self.device) for l in latents_np]
126
+
127
+ self.loaded_avatars[avatar_name] = avatar_data
128
+ return avatar_data
129
+
130
+ def _encode_video_nvenc(self, frames_dir: str, audio_path: str, output_path: str, fps: int) -> float:
131
+ t0 = time.time()
132
+ temp_vid = output_path.replace('.mp4', '_temp.mp4')
133
+
134
+ if self.use_nvenc:
135
+ cmd = (
136
+ f"ffmpeg -y -v warning -r {fps} -f image2 -i {frames_dir}/%08d.png "
137
+ f"-c:v h264_nvenc -preset {self.nvenc_preset} -cq {self.crf} -pix_fmt yuv420p {temp_vid}"
138
+ )
139
+ else:
140
+ cmd = (
141
+ f"ffmpeg -y -v warning -r {fps} -f image2 -i {frames_dir}/%08d.png "
142
+ f"-vcodec libx264 -crf 18 -pix_fmt yuv420p {temp_vid}"
143
+ )
144
+ os.system(cmd)
145
+
146
+ os.system(f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid} -c:v copy -c:a aac {output_path}")
147
+ if os.path.exists(temp_vid): os.remove(temp_vid)
148
+
149
+ return time.time() - t0
150
+
151
+ @torch.no_grad()
152
+ def generate_with_avatar(self, avatar_name: str, audio_path: str, output_path: str, fps: int = 25) -> dict:
153
+ if not self.is_loaded:
154
+ raise RuntimeError("Models not loaded!")
155
+
156
+ timings = {}
157
+ total_start = time.time()
158
+
159
+ t0 = time.time()
160
+ avatar = self.load_avatar(avatar_name)
161
+ timings["avatar_load"] = time.time() - t0
162
+
163
+ coord_list = avatar['coord_list']
164
+ frame_list = avatar['frame_list']
165
+ input_latent_list = avatar['latent_list']
166
+
167
+ temp_dir = tempfile.mkdtemp()
168
+
169
+ try:
170
+ # Whisper
171
+ t0 = time.time()
172
+ whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
173
+ whisper_chunks = self.audio_processor.get_whisper_chunk(
174
+ whisper_input_features, self.device, self.weight_dtype, self.whisper,
175
+ librosa_length, fps=fps,
176
+ audio_padding_length_left=self.audio_padding_left,
177
+ audio_padding_length_right=self.audio_padding_right,
178
+ )
179
+ timings["whisper_features"] = time.time() - t0
180
+
181
+ # Cycle lists
182
+ frame_list_cycle = frame_list + frame_list[::-1]
183
+ coord_list_cycle = coord_list + coord_list[::-1]
184
+ input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
185
+
186
+ # UNet
187
+ t0 = time.time()
188
+ gen = datagen(whisper_chunks=whisper_chunks, vae_encode_latents=input_latent_list_cycle,
189
+ batch_size=self.batch_size, delay_frame=0, device=self.device)
190
+
191
+ res_frame_list = []
192
+ for whisper_batch, latent_batch in gen:
193
+ audio_feature_batch = self.pe(whisper_batch)
194
+ latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
195
+ pred_latents = self.unet.model(latent_batch, self.timesteps,
196
+ encoder_hidden_states=audio_feature_batch).sample
197
+ recon = self.vae.decode_latents(pred_latents)
198
+ res_frame_list.extend(recon)
199
+ timings["unet_inference"] = time.time() - t0
200
+
201
+ # Face blending (sequential - faster than parallel due to FP overhead)
202
+ t0 = time.time()
203
+ result_img_path = os.path.join(temp_dir, "results")
204
+ os.makedirs(result_img_path, exist_ok=True)
205
+
206
+ for i, res_frame in enumerate(res_frame_list):
207
+ bbox = coord_list_cycle[i % len(coord_list_cycle)]
208
+ ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
209
+ x1, y1, x2, y2 = bbox
210
+ y2 = min(y2 + self.extra_margin, ori_frame.shape[0])
211
+
212
+ try:
213
+ res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
214
+ combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
215
+ mode=self.parsing_mode, fp=self.fp)
216
+ cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)
217
+ except Exception:
218
+ continue
219
+ timings["face_blending"] = time.time() - t0
220
+
221
+ # NVENC encoding
222
+ timings["video_encoding"] = self._encode_video_nvenc(result_img_path, audio_path, output_path, fps)
223
+
224
+ finally:
225
+ shutil.rmtree(temp_dir, ignore_errors=True)
226
+
227
+ timings["total"] = time.time() - total_start
228
+ timings["frames_generated"] = len(res_frame_list)
229
+ return timings
230
+
231
+ @torch.no_grad()
232
+ def generate_batch(self, avatar_name: str, audio_paths: List[str], output_dir: str, fps: int = 25) -> dict:
233
+ if not self.is_loaded:
234
+ raise RuntimeError("Models not loaded!")
235
+
236
+ batch_timings = {"videos": [], "total": 0}
237
+ total_start = time.time()
238
+
239
+ t0 = time.time()
240
+ avatar = self.load_avatar(avatar_name)
241
+ batch_timings["avatar_load"] = time.time() - t0
242
+
243
+ coord_list = avatar['coord_list']
244
+ frame_list = avatar['frame_list']
245
+ input_latent_list = avatar['latent_list']
246
+ frame_list_cycle = frame_list + frame_list[::-1]
247
+ coord_list_cycle = coord_list + coord_list[::-1]
248
+ input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
249
+
250
+ os.makedirs(output_dir, exist_ok=True)
251
+
252
+ for idx, audio_path in enumerate(audio_paths):
253
+ video_start = time.time()
254
+ timings = {}
255
+ output_path = os.path.join(output_dir, f"{Path(audio_path).stem}.mp4")
256
+ temp_dir = tempfile.mkdtemp()
257
+
258
+ try:
259
+ t0 = time.time()
260
+ whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
261
+ whisper_chunks = self.audio_processor.get_whisper_chunk(
262
+ whisper_input_features, self.device, self.weight_dtype, self.whisper,
263
+ librosa_length, fps=fps,
264
+ audio_padding_length_left=self.audio_padding_left,
265
+ audio_padding_length_right=self.audio_padding_right,
266
+ )
267
+ timings["whisper"] = time.time() - t0
268
+
269
+ t0 = time.time()
270
+ gen = datagen(whisper_chunks=whisper_chunks, vae_encode_latents=input_latent_list_cycle,
271
+ batch_size=self.batch_size, delay_frame=0, device=self.device)
272
+ res_frame_list = []
273
+ for whisper_batch, latent_batch in gen:
274
+ audio_feature_batch = self.pe(whisper_batch)
275
+ latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
276
+ pred_latents = self.unet.model(latent_batch, self.timesteps,
277
+ encoder_hidden_states=audio_feature_batch).sample
278
+ res_frame_list.extend(self.vae.decode_latents(pred_latents))
279
+ timings["unet"] = time.time() - t0
280
+
281
+ t0 = time.time()
282
+ result_img_path = os.path.join(temp_dir, "results")
283
+ os.makedirs(result_img_path, exist_ok=True)
284
+ for i, res_frame in enumerate(res_frame_list):
285
+ bbox = coord_list_cycle[i % len(coord_list_cycle)]
286
+ ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
287
+ x1, y1, x2, y2 = bbox
288
+ y2 = min(y2 + self.extra_margin, ori_frame.shape[0])
289
+ try:
290
+ res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
291
+ combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
292
+ mode=self.parsing_mode, fp=self.fp)
293
+ cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)
294
+ except Exception:
295
+ continue
296
+ timings["blending"] = time.time() - t0
297
+
298
+ timings["encoding"] = self._encode_video_nvenc(result_img_path, audio_path, output_path, fps)
299
+
300
+ finally:
301
+ shutil.rmtree(temp_dir, ignore_errors=True)
302
+
303
+ timings["total"] = time.time() - video_start
304
+ timings["frames"] = len(res_frame_list)
305
+ timings["output"] = output_path
306
+ batch_timings["videos"].append(timings)
307
+ print(f" [{idx+1}/{len(audio_paths)}] {Path(audio_path).stem}: {timings['total']:.2f}s")
308
+
309
+ batch_timings["total"] = time.time() - total_start
310
+ batch_timings["num_videos"] = len(audio_paths)
311
+ batch_timings["avg_per_video"] = batch_timings["total"] / len(audio_paths) if audio_paths else 0
312
+ return batch_timings
313
+
314
+
315
+ server = MuseTalkServerV3()
316
+ app = FastAPI(title="MuseTalk API v3", version="3.0.0")
317
+ app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
318
+
319
+ @app.on_event("startup")
320
+ async def startup():
321
+ server.load_models()
322
+
323
+ @app.get("/health")
324
+ async def health():
325
+ return {"status": "ok" if server.is_loaded else "loading", "device": str(server.device),
326
+ "avatars": list(server.loaded_avatars.keys()), "nvenc": server.use_nvenc}
327
+
328
+ @app.get("/avatars")
329
+ async def list_avatars():
330
+ avatars = []
331
+ for p in server.avatar_dir.iterdir():
332
+ if p.is_dir() and (p / "metadata.pkl").exists():
333
+ with open(p / "metadata.pkl", 'rb') as f:
334
+ avatars.append(pickle.load(f))
335
+ return {"avatars": avatars}
336
+
337
+ class GenReq(BaseModel):
338
+ avatar_name: str
339
+ audio_path: str
340
+ output_path: str
341
+ fps: int = 25
342
+
343
+ @app.post("/generate/avatar")
344
+ async def generate(req: GenReq):
345
+ if not os.path.exists(req.audio_path):
346
+ raise HTTPException(404, f"Audio not found: {req.audio_path}")
347
+ try:
348
+ timings = server.generate_with_avatar(req.avatar_name, req.audio_path, req.output_path, req.fps)
349
+ return {"status": "success", "output_path": req.output_path, "timings": timings}
350
+ except Exception as e:
351
+ raise HTTPException(500, str(e))
352
+
353
+ class BatchReq(BaseModel):
354
+ avatar_name: str
355
+ audio_paths: List[str]
356
+ output_dir: str
357
+ fps: int = 25
358
+
359
+ @app.post("/generate/batch")
360
+ async def batch(req: BatchReq):
361
+ for p in req.audio_paths:
362
+ if not os.path.exists(p):
363
+ raise HTTPException(404, f"Audio not found: {p}")
364
+ try:
365
+ timings = server.generate_batch(req.avatar_name, req.audio_paths, req.output_dir, req.fps)
366
+ return {"status": "success", "output_dir": req.output_dir, "timings": timings}
367
+ except Exception as e:
368
+ raise HTTPException(500, str(e))
369
+
370
+ if __name__ == "__main__":
371
+ uvicorn.run(app, host="0.0.0.0", port=8000)
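Both server variants hardcode `use_nvenc = True`; if the local ffmpeg build was compiled without the h264_nvenc encoder, the encode command fails and no temp video is produced. A small startup probe along these lines (not part of this commit, just a sketch) would let the server fall back to libx264 automatically:

```python
# Sketch: check whether the ffmpeg on PATH exposes the h264_nvenc encoder before enabling NVENC.
import subprocess

def nvenc_available() -> bool:
    """Return True if ffmpeg lists the h264_nvenc encoder."""
    try:
        out = subprocess.run(
            ["ffmpeg", "-hide_banner", "-encoders"],
            capture_output=True, text=True, check=True,
        ).stdout
    except (OSError, subprocess.CalledProcessError):
        return False
    return "h264_nvenc" in out

# e.g. in MuseTalkServerV3.__init__():  self.use_nvenc = nvenc_available()
```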
run_inference.py ADDED
@@ -0,0 +1,276 @@
1
+ import os
2
+ import cv2
3
+ import math
4
+ import copy
5
+ import torch
6
+ import glob
7
+ import shutil
8
+ import pickle
9
+ import argparse
10
+ import numpy as np
11
+ import subprocess
12
+ from tqdm import tqdm
13
+ from omegaconf import OmegaConf
14
+ from transformers import WhisperModel
15
+ import sys
16
+
17
+ from musetalk.utils.blending import get_image
18
+ from musetalk.utils.face_parsing import FaceParsing
19
+ from musetalk.utils.audio_processor import AudioProcessor
20
+ from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
21
+ from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
22
+
23
+ def fast_check_ffmpeg():
24
+ try:
25
+ subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
26
+ return True
27
+ except Exception:
28
+ return False
29
+
30
+ @torch.no_grad()
31
+ def main(args):
32
+ # Configure ffmpeg path
33
+ if not fast_check_ffmpeg():
34
+ print("Adding ffmpeg to PATH")
35
+ # Choose path separator based on operating system
36
+ path_separator = ';' if sys.platform == 'win32' else ':'
37
+ os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
38
+ if not fast_check_ffmpeg():
39
+ print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
40
+
41
+ # Set computing device
42
+ device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
43
+ # Load model weights
44
+ vae, unet, pe = load_all_model(
45
+ unet_model_path=args.unet_model_path,
46
+ vae_type=args.vae_type,
47
+ unet_config=args.unet_config,
48
+ device=device
49
+ )
50
+ timesteps = torch.tensor([0], device=device)
51
+
52
+ # Convert models to half precision if float16 is enabled
53
+ if args.use_float16:
54
+ pe = pe.half()
55
+ vae.vae = vae.vae.half()
56
+ unet.model = unet.model.half()
57
+
58
+ # Move models to specified device
59
+ pe = pe.to(device)
60
+ vae.vae = vae.vae.to(device)
61
+ unet.model = unet.model.to(device)
62
+
63
+ # Initialize audio processor and Whisper model
64
+ audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
65
+ weight_dtype = unet.model.dtype
66
+ whisper = WhisperModel.from_pretrained(args.whisper_dir)
67
+ whisper = whisper.to(device=device, dtype=weight_dtype).eval()
68
+ whisper.requires_grad_(False)
69
+
70
+ # Initialize face parser with configurable parameters based on version
71
+ if args.version == "v15":
72
+ fp = FaceParsing(
73
+ left_cheek_width=args.left_cheek_width,
74
+ right_cheek_width=args.right_cheek_width
75
+ )
76
+ else: # v1
77
+ fp = FaceParsing()
78
+
79
+ # Load inference configuration
80
+ inference_config = OmegaConf.load(args.inference_config)
81
+ print("Loaded inference config:", inference_config)
82
+
83
+ # Process each task
84
+ for task_id in inference_config:
85
+ try:
86
+ # Get task configuration
87
+ video_path = inference_config[task_id]["video_path"]
88
+ audio_path = inference_config[task_id]["audio_path"]
89
+ if "result_name" in inference_config[task_id]:
90
+ args.output_vid_name = inference_config[task_id]["result_name"]
91
+
92
+ # Set bbox_shift based on version
93
+ if args.version == "v15":
94
+ bbox_shift = 0 # v15 uses fixed bbox_shift
95
+ else:
96
+ bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift) # v1 uses config or default
97
+
98
+ # Set output paths
99
+ input_basename = os.path.basename(video_path).split('.')[0]
100
+ audio_basename = os.path.basename(audio_path).split('.')[0]
101
+ output_basename = f"{input_basename}_{audio_basename}"
102
+
103
+ # Create temporary directories
104
+ temp_dir = os.path.join(args.result_dir, f"{args.version}")
105
+ os.makedirs(temp_dir, exist_ok=True)
106
+
107
+ # Set result save paths
108
+ result_img_save_path = os.path.join(temp_dir, output_basename)
109
+ crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl")
110
+ os.makedirs(result_img_save_path, exist_ok=True)
111
+
112
+ # Set output video paths
113
+ if args.output_vid_name is None:
114
+ output_vid_name = os.path.join(temp_dir, output_basename + ".mp4")
115
+ else:
116
+ output_vid_name = os.path.join(temp_dir, args.output_vid_name)
117
+ output_vid_name_concat = os.path.join(temp_dir, output_basename + "_concat.mp4")
118
+
119
+ # Extract frames from source video
120
+ if get_file_type(video_path) == "video":
121
+ save_dir_full = os.path.join(temp_dir, input_basename)
122
+ os.makedirs(save_dir_full, exist_ok=True)
123
+ cmd = f"ffmpeg -v fatal -i {video_path} -vf fps={args.fps} -start_number 0 {save_dir_full}/%08d.png" # PATCHED: extract at target fps
124
+ os.system(cmd)
125
+ input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
126
+ fps = args.fps # PATCHED: use target fps instead of video fps
127
+ elif get_file_type(video_path) == "image":
128
+ input_img_list = [video_path]
129
+ fps = args.fps
130
+ elif os.path.isdir(video_path):
131
+ input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
132
+ input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
133
+ fps = args.fps
134
+ else:
135
+ raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
136
+
137
+ # Extract audio features
138
+ whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
139
+ whisper_chunks = audio_processor.get_whisper_chunk(
140
+ whisper_input_features,
141
+ device,
142
+ weight_dtype,
143
+ whisper,
144
+ librosa_length,
145
+ fps=fps,
146
+ audio_padding_length_left=args.audio_padding_length_left,
147
+ audio_padding_length_right=args.audio_padding_length_right,
148
+ )
149
+
150
+ # Preprocess input images
151
+ if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
152
+ print("Using saved coordinates")
153
+ with open(crop_coord_save_path, 'rb') as f:
154
+ coord_list = pickle.load(f)
155
+ frame_list = read_imgs(input_img_list)
156
+ else:
157
+ print("Extracting landmarks... time-consuming operation")
158
+ coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
159
+ with open(crop_coord_save_path, 'wb') as f:
160
+ pickle.dump(coord_list, f)
161
+
162
+ print(f"Number of frames: {len(frame_list)}")
163
+
164
+ # Process each frame
165
+ input_latent_list = []
166
+ for bbox, frame in zip(coord_list, frame_list):
167
+ if bbox == coord_placeholder:
168
+ continue
169
+ x1, y1, x2, y2 = bbox
170
+ if args.version == "v15":
171
+ y2 = y2 + args.extra_margin
172
+ y2 = min(y2, frame.shape[0])
173
+ crop_frame = frame[y1:y2, x1:x2]
174
+ crop_frame = cv2.resize(crop_frame, (256,256), interpolation=cv2.INTER_LANCZOS4)
175
+ latents = vae.get_latents_for_unet(crop_frame)
176
+ input_latent_list.append(latents)
177
+
178
+ # Smooth first and last frames
179
+ frame_list_cycle = frame_list + frame_list[::-1]
180
+ coord_list_cycle = coord_list + coord_list[::-1]
181
+ input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
182
+
183
+ # Batch inference
184
+ print("Starting inference")
185
+ video_num = len(whisper_chunks)
186
+ batch_size = args.batch_size
187
+ gen = datagen(
188
+ whisper_chunks=whisper_chunks,
189
+ vae_encode_latents=input_latent_list_cycle,
190
+ batch_size=batch_size,
191
+ delay_frame=0,
192
+ device=device,
193
+ )
194
+
195
+ res_frame_list = []
196
+ total = int(np.ceil(float(video_num) / batch_size))
197
+
198
+ # Execute inference
199
+ for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total)):
200
+ audio_feature_batch = pe(whisper_batch)
201
+ latent_batch = latent_batch.to(dtype=unet.model.dtype)
202
+
203
+ pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
204
+ recon = vae.decode_latents(pred_latents)
205
+ for res_frame in recon:
206
+ res_frame_list.append(res_frame)
207
+
208
+ # Pad generated images to original video size
209
+ print("Padding generated images to original video size")
210
+ for i, res_frame in enumerate(tqdm(res_frame_list)):
211
+ bbox = coord_list_cycle[i%(len(coord_list_cycle))]
212
+ ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
213
+ x1, y1, x2, y2 = bbox
214
+ if args.version == "v15":
215
+ y2 = y2 + args.extra_margin
216
+ y2 = min(y2, ori_frame.shape[0])
217
+ try:
218
+ res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
219
+ except Exception:
220
+ continue
221
+
222
+ # Merge results with version-specific parameters
223
+ if args.version == "v15":
224
+ combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
225
+ else:
226
+ combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp)
227
+ cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)
228
+
229
+ # Save prediction results
230
+ temp_vid_path = f"{temp_dir}/temp_{input_basename}_{audio_basename}.mp4"
231
+ cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid_path}"
232
+ print("Video generation command:", cmd_img2video)
233
+ os.system(cmd_img2video)
234
+
235
+ cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid_path} {output_vid_name}"
236
+ print("Audio combination command:", cmd_combine_audio)
237
+ os.system(cmd_combine_audio)
238
+
239
+ # Clean up temporary files
240
+ shutil.rmtree(result_img_save_path)
241
+ os.remove(temp_vid_path)
242
+
243
+ shutil.rmtree(save_dir_full)
244
+ if not args.saved_coord:
245
+ os.remove(crop_coord_save_path)
246
+
247
+ print(f"Results saved to {output_vid_name}")
248
+ except Exception as e:
249
+ print("Error occurred during processing:", e)
250
+
251
+ if __name__ == "__main__":
252
+ parser = argparse.ArgumentParser()
253
+ parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
254
+ parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
255
+ parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
256
+ parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
257
+ parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth", help="Path to UNet model weights")
258
+ parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
259
+ parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file")
260
+ parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
261
+ parser.add_argument("--result_dir", default='./results', help="Directory for output results")
262
+ parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
263
+ parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
264
+ parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
265
+ parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
266
+ parser.add_argument("--batch_size", type=int, default=8, help="Batch size for inference")
267
+ parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
268
+ parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
269
+ parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
270
+ parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
271
+ parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
272
+ parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
273
+ parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
274
+ parser.add_argument("--version", type=str, default="v15", choices=["v1", "v15"], help="Model version to use")
275
+ args = parser.parse_args()
276
+ main(args)
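run_inference.py expects the --inference_config YAML to contain one entry per task with video_path, audio_path and, optionally, result_name and bbox_shift (the latter is only honored with --version v1). A minimal sketch of such a config built with OmegaConf; the task names and media paths are illustrative only:

```python
# Illustrative config matching the fields read in main(); paths are placeholders.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "task_0": {
        "video_path": "data/video/speaker.mp4",
        "audio_path": "data/audio/question.wav",
        "result_name": "speaker_question.mp4",  # optional
    },
    "task_1": {
        "video_path": "assets/demo/face.png",
        "audio_path": "data/audio/intro.wav",
        "bbox_shift": 5,                        # only used when --version v1
    },
})
OmegaConf.save(cfg, "configs/inference/custom.yaml")
# python run_inference.py --inference_config configs/inference/custom.yaml
```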
scripts/inference.py CHANGED
@@ -213,7 +213,7 @@ def main(args):
213
  x1, y1, x2, y2 = bbox
214
  if args.version == "v15":
215
  y2 = y2 + args.extra_margin
216
- y2 = min(y2, frame.shape[0])
217
  try:
218
  res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
219
  except:
 
213
  x1, y1, x2, y2 = bbox
214
  if args.version == "v15":
215
  y2 = y2 + args.extra_margin
216
+ y2 = min(y2, ori_frame.shape[0])
217
  try:
218
  res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
219
  except:
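This one-line change is the frame-shape fix from the commit message: in the padding loop only `ori_frame` refers to the frame being composited, while `frame` is a leftover binding from the earlier latent-preparation loop, so clamping against it can leave `y2` below the bottom of `ori_frame` and the resized mouth patch no longer matches the region it is pasted into. A small sketch of the failure mode, with made-up shapes:

```python
# Illustrative only: made-up shapes showing why the clamp must use ori_frame.
import cv2
import numpy as np

ori_frame = np.zeros((720, 1280, 3), np.uint8)   # frame being composited
frame = np.zeros((1080, 1920, 3), np.uint8)      # stale binding from a previous loop
x1, y1, x2, y2 = 500, 300, 756, 715              # face bbox near the bottom edge
res_frame = np.zeros((256, 256, 3), np.uint8)    # generated mouth crop
extra_margin = 10

y2_bad = min(y2 + extra_margin, frame.shape[0])      # 725: past ori_frame's bottom
y2_ok = min(y2 + extra_margin, ori_frame.shape[0])   # 720: clamped correctly

resized = cv2.resize(res_frame, (x2 - x1, y2_bad - y1))
print(resized.shape)                       # (425, 256, 3)
print(ori_frame[y1:y2_bad, x1:x2].shape)   # (420, 256, 3) -> pasting 'resized' here fails
print(ori_frame[y1:y2_ok, x1:x2].shape)    # (420, 256, 3) matches a correctly sized resize
```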
server.py ADDED
@@ -0,0 +1,607 @@
1
+ """
2
+ MuseTalk Real-Time Server
3
+ FastAPI server for real-time lip-sync
4
+ """
5
+ import os
6
+ import sys
7
+ import io
8
+ import time
9
+ import json
10
+ import uuid
11
+ import queue
12
+ import pickle
13
+ import shutil
14
+ import asyncio
15
+ import threading
16
+ from pathlib import Path
17
+ from typing import Optional
18
+ import tempfile
19
+
20
+ import cv2
21
+ import glob
22
+ import copy
23
+ import torch
24
+ import numpy as np
25
+ from tqdm import tqdm
26
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
27
+ from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
28
+ from pydantic import BaseModel
29
+ import uvicorn
30
+
31
+ # Suppress warnings
32
+ import warnings
33
+ warnings.filterwarnings("ignore")
34
+
35
+ # MuseTalk imports
36
+ from musetalk.utils.utils import datagen, load_all_model
37
+ from musetalk.utils.blending import get_image_prepare_material, get_image_blending
38
+ from musetalk.utils.audio_processor import AudioProcessor
39
+ from musetalk.utils.preprocessing_simple import get_landmark_and_bbox, read_imgs
40
+ from transformers import WhisperModel
41
+
42
+ app = FastAPI(title="MuseTalk Real-Time Server", version="1.5")
43
+
44
+ # Global model instances
45
+ models = {}
46
+ avatars = {}
47
+
48
+ class AvatarConfig(BaseModel):
49
+ avatar_id: str
50
+ video_path: str
51
+ bbox_shift: int = 0
52
+
53
+ class InferenceRequest(BaseModel):
54
+ avatar_id: str
55
+ fps: int = 25
56
+
57
+ def video2imgs(vid_path, save_path):
58
+ """Extract frames from video"""
59
+ cap = cv2.VideoCapture(vid_path)
60
+ count = 0
61
+ while True:
62
+ ret, frame = cap.read()
63
+ if ret:
64
+ cv2.imwrite(f"{save_path}/{count:08d}.png", frame)
65
+ count += 1
66
+ else:
67
+ break
68
+ cap.release()
69
+ return count
70
+
71
+
72
+ @app.on_event("startup")
73
+ async def load_models():
74
+ """Load all models at startup"""
75
+ global models
76
+
77
+ print("Loading MuseTalk models...")
78
+ # Force CPU if FORCE_CPU env var is set or if CUDA kernels are incompatible
79
+ force_cpu = os.environ.get("FORCE_CPU", "0") == "1"
80
+ if force_cpu or not torch.cuda.is_available():
81
+ device = torch.device("cpu")
82
+ else:
83
+ try:
84
+ # Test if CUDA kernels work for this GPU
85
+ test_tensor = torch.zeros(1).cuda()
86
+ _ = test_tensor.half()
87
+ device = torch.device("cuda:0")
88
+ except RuntimeError as e:
89
+ print(f"CUDA kernel test failed: {e}")
90
+ print("Falling back to CPU...")
91
+ device = torch.device("cpu")
92
+ print(f"Using device: {device}")
93
+
94
+ # Model paths
95
+ unet_model_path = "./models/musetalkV15/unet.pth"
96
+ unet_config = "./models/musetalkV15/musetalk.json"
97
+ whisper_dir = "./models/whisper"
98
+ vae_type = "sd-vae"
99
+
100
+ # Load models
101
+ vae, unet, pe = load_all_model(
102
+ unet_model_path=unet_model_path,
103
+ vae_type=vae_type,
104
+ unet_config=unet_config,
105
+ device=device
106
+ )
107
+
108
+ # Move to device, use half precision only for GPU
109
+ if device.type == "cuda":
110
+ pe = pe.half().to(device)
111
+ vae.vae = vae.vae.half().to(device)
112
+ unet.model = unet.model.half().to(device)
113
+ else:
114
+ pe = pe.to(device)
115
+ vae.vae = vae.vae.to(device)
116
+ unet.model = unet.model.to(device)
117
+
118
+ # Load whisper
119
+ audio_processor = AudioProcessor(feature_extractor_path=whisper_dir)
120
+ whisper = WhisperModel.from_pretrained(whisper_dir)
121
+ weight_dtype = unet.model.dtype if device.type == "cuda" else torch.float32
122
+ whisper = whisper.to(device=device, dtype=weight_dtype).eval()
123
+ whisper.requires_grad_(False)
124
+
125
+ # Initialize face parser
126
+ from musetalk.utils.face_parsing import FaceParsing
127
+ fp = FaceParsing(left_cheek_width=90, right_cheek_width=90)
128
+
129
+ timesteps = torch.tensor([0], device=device)
130
+
131
+ models = {
132
+ "vae": vae,
133
+ "unet": unet,
134
+ "pe": pe,
135
+ "whisper": whisper,
136
+ "audio_processor": audio_processor,
137
+ "fp": fp,
138
+ "device": device,
139
+ "timesteps": timesteps,
140
+ "weight_dtype": weight_dtype
141
+ }
142
+
143
+ print("Models loaded successfully!")
144
+
145
+ @app.get("/")
146
+ async def root():
147
+ return {"status": "ok", "message": "MuseTalk Real-Time Server"}
148
+
149
+ @app.get("/health")
150
+ async def health():
151
+ return {
152
+ "status": "healthy",
153
+ "models_loaded": len(models) > 0,
154
+ "avatars_count": len(avatars),
155
+ "gpu_available": torch.cuda.is_available()
156
+ }
157
+
158
+ @app.post("/avatar/prepare")
159
+ async def prepare_avatar(
160
+ avatar_id: str = Form(...),
161
+ video: UploadFile = File(...),
162
+ bbox_shift: int = Form(0, description="Adjusts mouth openness: positive = more open, negative = less open (-9 to 9)"),
163
+ extra_margin: int = Form(10, description="Extra margin for jaw movement"),
164
+ parsing_mode: str = Form("jaw", description="Parsing mode: 'jaw' (v1.5) or 'raw' (v1.0)"),
165
+ left_cheek_width: int = Form(90, description="Left cheek region width"),
166
+ right_cheek_width: int = Form(90, description="Right cheek region width")
167
+ ):
168
+ """Prepare an avatar from video for real-time inference"""
169
+ global avatars
170
+
171
+ if not models:
172
+ raise HTTPException(status_code=503, detail="Models not loaded")
173
+
174
+ # Save uploaded video
175
+ avatar_path = f"./results/v15/avatars/{avatar_id}"
176
+ full_imgs_path = f"{avatar_path}/full_imgs"
177
+ mask_out_path = f"{avatar_path}/mask"
178
+
179
+ os.makedirs(avatar_path, exist_ok=True)
180
+ os.makedirs(full_imgs_path, exist_ok=True)
181
+ os.makedirs(mask_out_path, exist_ok=True)
182
+
183
+ # Save video
184
+ video_path = f"{avatar_path}/source_video{Path(video.filename).suffix}"
185
+ with open(video_path, "wb") as f:
186
+ content = await video.read()
187
+ f.write(content)
188
+
189
+ # Extract frames
190
+ print(f"Extracting frames from video...")
191
+ frame_count = video2imgs(video_path, full_imgs_path)
192
+ print(f"Extracted {frame_count} frames")
193
+
194
+ input_img_list = sorted(glob.glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]')))
195
+
196
+ print("Extracting landmarks...")
197
+ # bbox_shift controls mouth openness: positive=more open, negative=less open
198
+ coord_list_raw, frame_list_raw = get_landmark_and_bbox(input_img_list, upperbondrange=bbox_shift)
199
+
200
+ # Generate latents - filter out frames without detected faces
201
+ input_latent_list = []
202
+ valid_coord_list = []
203
+ valid_frame_list = []
204
+ coord_placeholder = (0.0, 0.0, 0.0, 0.0)
205
+
206
+ vae = models["vae"]
207
+
208
+ # Create FaceParsing with custom cheek widths for this avatar
209
+ from musetalk.utils.face_parsing import FaceParsing
210
+ fp_avatar = FaceParsing(left_cheek_width=left_cheek_width, right_cheek_width=right_cheek_width)
211
+
212
+ for bbox, frame in zip(coord_list_raw, frame_list_raw):
213
+ if bbox == coord_placeholder:
214
+ continue
215
+ x1, y1, x2, y2 = bbox
216
+ # Validate bbox dimensions
217
+ if x2 <= x1 or y2 <= y1:
218
+ continue
219
+ # Add extra margin for jaw movement (v1.5 feature)
220
+ y2 = min(y2 + extra_margin, frame.shape[0])
221
+
222
+ # Store valid frame and coordinates
223
+ valid_coord_list.append([x1, y1, x2, y2])
224
+ valid_frame_list.append(frame)
225
+
226
+ crop_frame = frame[y1:y2, x1:x2]
227
+ if crop_frame.size == 0:
228
+ valid_coord_list.pop()
229
+ valid_frame_list.pop()
230
+ continue
231
+ resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
232
+ latents = vae.get_latents_for_unet(resized_crop_frame)
233
+ input_latent_list.append(latents)
234
+
235
+ print(f"Valid frames with detected faces: {len(valid_frame_list)}/{len(frame_list_raw)}")
236
+
237
+ if len(valid_frame_list) == 0:
238
+ raise HTTPException(status_code=400, detail="No faces detected in video. Please use a video with a clear frontal face.")
239
+
240
+ # Create cycles from valid frames only
241
+ frame_list_cycle = valid_frame_list + valid_frame_list[::-1]
242
+ coord_list_cycle = valid_coord_list + valid_coord_list[::-1]
243
+ input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
244
+
245
+ # Generate masks
246
+ mask_list_cycle = []
247
+ mask_coords_list_cycle = []
248
+
249
+ print(f"Generating masks with mode={parsing_mode}...")
250
+ for i, frame in enumerate(tqdm(frame_list_cycle)):
251
+ x1, y1, x2, y2 = coord_list_cycle[i]
252
+ mask, crop_box = get_image_prepare_material(frame, [x1, y1, x2, y2], fp=fp_avatar, mode=parsing_mode)
253
+ cv2.imwrite(f"{mask_out_path}/{str(i).zfill(8)}.png", mask)
254
+ mask_coords_list_cycle.append(crop_box)
255
+ mask_list_cycle.append(mask)
256
+
257
+ # Save preprocessed data
258
+ with open(f"{avatar_path}/coords.pkl", 'wb') as f:
259
+ pickle.dump(coord_list_cycle, f)
260
+
261
+ with open(f"{avatar_path}/mask_coords.pkl", 'wb') as f:
262
+ pickle.dump(mask_coords_list_cycle, f)
263
+
264
+ # Save quality settings
265
+ quality_settings = {
266
+ "bbox_shift": bbox_shift,
267
+ "extra_margin": extra_margin,
268
+ "parsing_mode": parsing_mode,
269
+ "left_cheek_width": left_cheek_width,
270
+ "right_cheek_width": right_cheek_width
271
+ }
272
+ with open(f"{avatar_path}/quality_settings.json", 'w') as f:
273
+ json.dump(quality_settings, f)
274
+
275
+ torch.save(input_latent_list_cycle, f"{avatar_path}/latents.pt")
276
+
277
+ # Store in memory - keep latents on CPU to save GPU memory
278
+ input_latent_list_cpu = [lat.cpu() for lat in input_latent_list_cycle]
279
+
280
+ avatars[avatar_id] = {
281
+ "path": avatar_path,
282
+ "frame_list_cycle": frame_list_cycle,
283
+ "coord_list_cycle": coord_list_cycle,
284
+ "input_latent_list_cycle": input_latent_list_cpu,
285
+ "mask_list_cycle": mask_list_cycle,
286
+ "mask_coords_list_cycle": mask_coords_list_cycle,
287
+ "quality_settings": quality_settings
288
+ }
289
+
290
+ # Clear GPU cache after preparation
291
+ import gc
292
+ gc.collect()
293
+ torch.cuda.empty_cache()
294
+
295
+ return {
296
+ "status": "success",
297
+ "avatar_id": avatar_id,
298
+ "frame_count": len(frame_list_cycle),
299
+ "quality_settings": quality_settings
300
+ }
301
+
302
+ @app.post("/avatar/load/{avatar_id}")
303
+ async def load_avatar(avatar_id: str):
304
+ """Load a previously prepared avatar"""
305
+ global avatars
306
+
307
+ avatar_path = f"./results/v15/avatars/{avatar_id}"
308
+
309
+ if not os.path.exists(avatar_path):
310
+ raise HTTPException(status_code=404, detail=f"Avatar {avatar_id} not found")
311
+
312
+ full_imgs_path = f"{avatar_path}/full_imgs"
313
+ mask_out_path = f"{avatar_path}/mask"
314
+
315
+ # Load preprocessed data
316
+ input_latent_list_cycle = torch.load(f"{avatar_path}/latents.pt")
317
+
318
+ with open(f"{avatar_path}/coords.pkl", 'rb') as f:
319
+ coord_list_cycle = pickle.load(f)
320
+
321
+ with open(f"{avatar_path}/mask_coords.pkl", 'rb') as f:
322
+ mask_coords_list_cycle = pickle.load(f)
323
+
324
+ # Load quality settings (with defaults for backwards compatibility)
325
+ quality_settings_path = f"{avatar_path}/quality_settings.json"
326
+ if os.path.exists(quality_settings_path):
327
+ with open(quality_settings_path, 'r') as f:
328
+ quality_settings = json.load(f)
329
+ else:
330
+ quality_settings = {
331
+ "bbox_shift": 0,
332
+ "extra_margin": 10,
333
+ "parsing_mode": "jaw",
334
+ "left_cheek_width": 90,
335
+ "right_cheek_width": 90
336
+ }
337
+
338
+ # Load frames
339
+ input_img_list = sorted(glob.glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]')))
340
+ frame_list_cycle = read_imgs(input_img_list)
341
+
342
+ # Load masks
343
+ input_mask_list = sorted(glob.glob(os.path.join(mask_out_path, '*.[jpJP][pnPN]*[gG]')))
344
+ mask_list_cycle = read_imgs(input_mask_list)
345
+
346
+ # Keep latents on CPU to save GPU memory
347
+ input_latent_list_cpu = [lat.cpu() if hasattr(lat, 'cpu') else lat for lat in input_latent_list_cycle]
348
+
349
+ avatars[avatar_id] = {
350
+ "path": avatar_path,
351
+ "frame_list_cycle": frame_list_cycle,
352
+ "coord_list_cycle": coord_list_cycle,
353
+ "input_latent_list_cycle": input_latent_list_cpu,
354
+ "mask_list_cycle": mask_list_cycle,
355
+ "mask_coords_list_cycle": mask_coords_list_cycle,
356
+ "quality_settings": quality_settings
357
+ }
358
+
359
+ # Clear GPU cache
360
+ import gc
361
+ gc.collect()
362
+ torch.cuda.empty_cache()
363
+
364
+ return {
365
+ "status": "success",
366
+ "avatar_id": avatar_id,
367
+ "frame_count": len(frame_list_cycle),
368
+ "quality_settings": quality_settings
369
+ }
370
+
371
+ @app.get("/avatars")
372
+ async def list_avatars():
373
+ """List all available avatars"""
374
+ avatar_dir = "./results/v15/avatars"
375
+ if not os.path.exists(avatar_dir):
376
+ return {"avatars": [], "loaded": list(avatars.keys())}
377
+
378
+ available = [d for d in os.listdir(avatar_dir) if os.path.isdir(os.path.join(avatar_dir, d))]
379
+ return {"avatars": available, "loaded": list(avatars.keys())}
380
+
381
+ @app.post("/inference")
382
+ async def inference(
383
+ avatar_id: str = Form(...),
384
+ audio: UploadFile = File(...),
385
+ fps: int = Form(25)
386
+ ):
387
+ """Run inference with uploaded audio and return video"""
388
+
389
+ if avatar_id not in avatars:
390
+ raise HTTPException(status_code=404, detail=f"Avatar {avatar_id} not loaded. Use /avatar/load first")
391
+
392
+ if not models:
393
+ raise HTTPException(status_code=503, detail="Models not loaded")
394
+
395
+     avatar = avatars[avatar_id]
+     device = models["device"]
+
+     # Save audio temporarily
+     with tempfile.NamedTemporaryFile(suffix=Path(audio.filename).suffix, delete=False) as tmp:
+         content = await audio.read()
+         tmp.write(content)
+         audio_path = tmp.name
+
+     try:
+         start_time = time.time()
+
+         # Extract audio features
+         audio_processor = models["audio_processor"]
+         whisper = models["whisper"]
+         weight_dtype = models["weight_dtype"]
+
+         whisper_input_features, librosa_length = audio_processor.get_audio_feature(
+             audio_path, weight_dtype=weight_dtype
+         )
+         whisper_chunks = audio_processor.get_whisper_chunk(
+             whisper_input_features,
+             device,
+             weight_dtype,
+             whisper,
+             librosa_length,
+             fps=fps,
+             audio_padding_length_left=2,
+             audio_padding_length_right=2,
+         )
+
+         print(f"Audio processing: {(time.time() - start_time)*1000:.0f}ms")
+
+         # Inference
+         vae = models["vae"]
+         unet = models["unet"]
+         pe = models["pe"]
+         timesteps = models["timesteps"]
+
+         video_num = len(whisper_chunks)
+         batch_size = 4  # Reduced batch size to save GPU memory
+
+         gen = datagen(whisper_chunks, avatar["input_latent_list_cycle"], batch_size)
+
+         result_frames = []
+         inference_start = time.time()
+
+         for i, (whisper_batch, latent_batch) in enumerate(gen):
+             audio_feature_batch = pe(whisper_batch.to(device))
+             latent_batch = latent_batch.to(device=device, dtype=unet.model.dtype)
+
+             pred_latents = unet.model(
+                 latent_batch,
+                 timesteps,
+                 encoder_hidden_states=audio_feature_batch
+             ).sample
+
+             pred_latents = pred_latents.to(device=device, dtype=vae.vae.dtype)
+             recon = vae.decode_latents(pred_latents)
+
+             for idx_in_batch, res_frame in enumerate(recon):
+                 frame_idx = i * batch_size + idx_in_batch
+                 if frame_idx >= video_num:
+                     break
+
+                 bbox = avatar["coord_list_cycle"][frame_idx % len(avatar["coord_list_cycle"])]
+                 ori_frame = copy.deepcopy(avatar["frame_list_cycle"][frame_idx % len(avatar["frame_list_cycle"])])
+                 x1, y1, x2, y2 = bbox
+
+                 res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
+                 mask = avatar["mask_list_cycle"][frame_idx % len(avatar["mask_list_cycle"])]
+                 mask_crop_box = avatar["mask_coords_list_cycle"][frame_idx % len(avatar["mask_coords_list_cycle"])]
+
+                 combine_frame = get_image_blending(ori_frame, res_frame, bbox, mask, mask_crop_box)
+                 result_frames.append(combine_frame)
+
+         print(f"Inference: {(time.time() - inference_start)*1000:.0f}ms for {video_num} frames")
+         print(f"FPS: {video_num / (time.time() - inference_start):.1f}")
+
+         # Create video
+         output_path = tempfile.mktemp(suffix=".mp4")
+         h, w = result_frames[0].shape[:2]
+
+         fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+         out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
+
+         for frame in result_frames:
+             out.write(frame)
+         out.release()
+
+         # Combine with audio using ffmpeg
+         final_output = tempfile.mktemp(suffix=".mp4")
+         os.system(f"ffmpeg -y -v warning -i {audio_path} -i {output_path} -c:v libx264 -c:a aac {final_output}")
+
+         os.unlink(output_path)
+         os.unlink(audio_path)
+
+         total_time = time.time() - start_time
+         print(f"Total time: {total_time*1000:.0f}ms")
+
+         return FileResponse(
+             final_output,
+             media_type="video/mp4",
+             filename=f"output_{avatar_id}.mp4",
+             headers={"X-Processing-Time": f"{total_time:.2f}s"}
+         )
+
+     except Exception as e:
+         if os.path.exists(audio_path):
+             os.unlink(audio_path)
+         raise HTTPException(status_code=500, detail=str(e))
+
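Editor's note on the mux step above: `tempfile.mktemp` is deprecated, and the `os.system` f-string will break on paths containing spaces or shell metacharacters. Below is a minimal sketch of a safer equivalent, assuming the same `audio_path` and silent-video `output_path` produced by the handler above; the helper name is illustrative and not part of this commit.

```python
import os
import subprocess
import tempfile


def mux_audio_video(audio_path: str, video_path: str) -> str:
    """Mux the input audio with the generated silent video into a new temp .mp4."""
    fd, final_output = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)  # ffmpeg reopens the path itself; -y lets it overwrite the empty file
    subprocess.run(
        [
            "ffmpeg", "-y", "-v", "warning",
            "-i", audio_path,
            "-i", video_path,
            "-c:v", "libx264",
            "-c:a", "aac",
            final_output,
        ],
        check=True,  # surface ffmpeg failures instead of ignoring the exit code
    )
    return final_output
```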
+ @app.post("/inference/frames")
+ async def inference_frames(
+     avatar_id: str = Form(...),
+     audio: UploadFile = File(...),
+     fps: int = Form(25)
+ ):
+     """Run inference and return frames as JSON (for streaming)"""
+
+     if avatar_id not in avatars:
+         raise HTTPException(status_code=404, detail=f"Avatar {avatar_id} not loaded")
+
+     avatar = avatars[avatar_id]
+     device = models["device"]
+
+     # Save audio temporarily
+     with tempfile.NamedTemporaryFile(suffix=Path(audio.filename).suffix, delete=False) as tmp:
+         content = await audio.read()
+         tmp.write(content)
+         audio_path = tmp.name
+
+     try:
+         # Extract audio features
+         audio_processor = models["audio_processor"]
+         whisper = models["whisper"]
+         weight_dtype = models["weight_dtype"]
+
+         whisper_input_features, librosa_length = audio_processor.get_audio_feature(
+             audio_path, weight_dtype=weight_dtype
+         )
+         whisper_chunks = audio_processor.get_whisper_chunk(
+             whisper_input_features,
+             device,
+             weight_dtype,
+             whisper,
+             librosa_length,
+             fps=fps,
+         )
+
+         # Inference
+         vae = models["vae"]
+         unet = models["unet"]
+         pe = models["pe"]
+         timesteps = models["timesteps"]
+
+         video_num = len(whisper_chunks)
+         batch_size = 4  # Reduced batch size to save GPU memory
+
+         gen = datagen(whisper_chunks, avatar["input_latent_list_cycle"], batch_size)
+
+         frames_data = []
+         import base64  # used to serialize JPEG frames in the loop below
+
+         for i, (whisper_batch, latent_batch) in enumerate(gen):
+             audio_feature_batch = pe(whisper_batch.to(device))
+             latent_batch = latent_batch.to(device=device, dtype=unet.model.dtype)
+
+             pred_latents = unet.model(
+                 latent_batch,
+                 timesteps,
+                 encoder_hidden_states=audio_feature_batch
+             ).sample
+
+             pred_latents = pred_latents.to(device=device, dtype=vae.vae.dtype)
+             recon = vae.decode_latents(pred_latents)
+
+             for idx_in_batch, res_frame in enumerate(recon):
+                 frame_idx = i * batch_size + idx_in_batch
+                 if frame_idx >= video_num:
+                     break
+
+                 bbox = avatar["coord_list_cycle"][frame_idx % len(avatar["coord_list_cycle"])]
+                 ori_frame = copy.deepcopy(avatar["frame_list_cycle"][frame_idx % len(avatar["frame_list_cycle"])])
+                 x1, y1, x2, y2 = bbox
+
+                 res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
+                 mask = avatar["mask_list_cycle"][frame_idx % len(avatar["mask_list_cycle"])]
+                 mask_crop_box = avatar["mask_coords_list_cycle"][frame_idx % len(avatar["mask_coords_list_cycle"])]
+
+                 combine_frame = get_image_blending(ori_frame, res_frame, bbox, mask, mask_crop_box)
+
+                 # Encode frame as JPEG
+                 _, buffer = cv2.imencode('.jpg', combine_frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
+                 frame_b64 = base64.b64encode(buffer).decode('utf-8')
+                 frames_data.append(frame_b64)
+
+         os.unlink(audio_path)
+
+         return {
+             "frames": frames_data,
+             "fps": fps,
+             "total_frames": len(frames_data)
+         }
+
+     except Exception as e:
+         if os.path.exists(audio_path):
+             os.unlink(audio_path)
+         raise HTTPException(status_code=500, detail=str(e))
+
+
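For reference, a minimal client sketch for the `/inference/frames` endpoint above (not part of the commit). It assumes the server is reachable at http://localhost:8000, and `mariana_hd` / `my_audio.wav` are placeholder values.

```python
import base64

import cv2
import numpy as np
import requests

# Post an audio file and an avatar id to the frames endpoint (multipart form).
with open("my_audio.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/inference/frames",
        data={"avatar_id": "mariana_hd", "fps": 25},
        files={"audio": ("my_audio.wav", f, "audio/wav")},
    )
resp.raise_for_status()
payload = resp.json()

# Each entry in payload["frames"] is a base64-encoded JPEG; decode to BGR arrays.
frames = [
    cv2.imdecode(np.frombuffer(base64.b64decode(b64), dtype=np.uint8), cv2.IMREAD_COLOR)
    for b64 in payload["frames"]
]
print(f"Received {payload['total_frames']} frames at {payload['fps']} fps")
```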
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=8000)
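If the decoded frames from the client sketch above need to be turned back into a playable file, the same `cv2.VideoWriter` pattern used server-side applies; a short hedged sketch follows (the output filename is a placeholder, and `frames` / `payload` come from the previous sketch). Note that this client-side file has no audio track; the audio would need to be muxed in separately, for example with the ffmpeg sketch shown earlier.

```python
import cv2

# `frames` and `payload` come from the client sketch above.
h, w = frames[0].shape[:2]
writer = cv2.VideoWriter(
    "avatar_output.mp4",
    cv2.VideoWriter_fourcc(*"mp4v"),
    payload["fps"],
    (w, h),
)
for frame in frames:
    writer.write(frame)
writer.release()
```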