Commit af11910 · Parent: bc451c3

fix: correct facial alignment issues and add API server

- Fix frame shape bug in inference.py line 216 (use ori_frame instead of frame)
- Adjust upper_boundary_ratio from 0.5 to 0.4 for better facial blending
- Add MuseTalk API server with multiple versions
- Add inference configs and helper scripts
- Update .gitignore for conda environments

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
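The scripts/inference.py hunk (+1 -1) is not rendered in this view; based on the first bullet above and the matching guard in the API servers below, the one-line fix presumably swaps which frame the y2 clamp uses. An illustrative sketch, not the verbatim hunk (variable names assumed from the servers below):

    # scripts/inference.py, around line 216 (illustrative)
    y2 = y2 + args.extra_margin
-   y2 = min(y2, frame.shape[0])      # clamped against the cropped face, wrong height
+   y2 = min(y2, ori_frame.shape[0])  # clamp against the original full frame

Changed files: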
- .gitignore +9 -1
- activate.sh +19 -0
- avatar_pipeline.py +204 -0
- configs/inference/hello_world.yaml +3 -0
- configs/inference/professor_test.yaml +4 -0
- musetalk/utils/blending.py +2 -2
- musetalk_api_server.py +551 -0
- musetalk_api_server_v2.py +445 -0
- musetalk_api_server_v3.py +651 -0
- musetalk_api_server_v3_fixed.py +371 -0
- run_inference.py +276 -0
- scripts/inference.py +1 -1
- server.py +607 -0
.gitignore CHANGED

@@ -15,4 +15,12 @@ ffmprobe*
 ffplay*
 debug
 exp_out
-.gradio
+.gradio
+
+# Conda environment (Lightning AI persistent)
+.conda_env/
+miniconda/
+venv/
+
+# Temporary installation files
+=*
activate.sh ADDED (+19 lines)

#!/bin/bash
# Script to activate the MuseTalk environment

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

# Activate the local conda install
source "$SCRIPT_DIR/miniconda/bin/activate" musetalk

# Configure the HuggingFace token (set your HF_TOKEN variable or edit here)
# export HF_TOKEN="your_token_here"

# Configure FFMPEG if needed
# export FFMPEG_PATH="$SCRIPT_DIR/ffmpeg"

cd "${SCRIPT_DIR}"
echo "✅ MuseTalk environment activated!"
echo "Directory: ${SCRIPT_DIR}"
echo "Python: $(python --version)"
echo "PyTorch: $(python -c 'import torch; print(torch.__version__)')"
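Because the script activates conda and changes the working directory, it is meant to be sourced in the current shell (e.g. `source ./activate.sh`) rather than executed as a subprocess.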
avatar_pipeline.py ADDED (+204 lines)

"""
Multimodal Avatar Pipeline
Audio Input -> Whisper -> LLM -> XTTS -> MuseTalk

This creates a complete avatar that can understand spoken Portuguese
and respond with lip-synced video.
"""

import os
import requests
import tempfile
import soundfile as sf

os.environ["HF_HOME"] = "/workspace/MuseTalk/.cache/huggingface"
os.environ["COQUI_TOS_AGREED"] = "1"

from faster_whisper import WhisperModel
from llama_cpp import Llama
from TTS.api import TTS


class MultimodalAvatar:
    def __init__(
        self,
        whisper_model: str = "tiny",
        llm_model_path: str = "models/llm/qwen2.5-0.5b-instruct-q4_k_m.gguf",
        reference_audio: str = "data/audio/mariana_ref.wav",
        avatar_id: str = "mariana_hd",
        musetalk_url: str = "http://localhost:8000",
        system_prompt: str = None
    ):
        print("Initializing Multimodal Avatar Pipeline...")

        # Whisper for speech-to-text
        print(" Loading Whisper...")
        self.whisper = WhisperModel(whisper_model, device="cpu", compute_type="int8")

        # LLM for understanding and response
        print(" Loading LLM...")
        self.llm = Llama(
            model_path=llm_model_path,
            n_ctx=2048,
            n_threads=4,
            verbose=False
        )

        # XTTS for text-to-speech
        print(" Loading XTTS...")
        self.tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=False)

        self.reference_audio = reference_audio
        self.avatar_id = avatar_id
        self.musetalk_url = musetalk_url

        self.system_prompt = system_prompt or """Você é Mariana, uma assistente virtual brasileira.
Você é simpática, prestativa e sempre responde em português brasileiro.
Suas respostas são claras, concisas e naturais, como se estivesse conversando.
Evite respostas muito longas - prefira 2-3 frases no máximo."""

        self.conversation_history = []
        print("Avatar ready!")

    def transcribe(self, audio_path: str) -> str:
        """Transcribe audio to text using Whisper"""
        segments, info = self.whisper.transcribe(audio_path, language="pt")
        text = " ".join([segment.text for segment in segments]).strip()
        return text

    def think(self, user_message: str) -> str:
        """Generate response using LLM"""
        self.conversation_history.append({"role": "user", "content": user_message})

        messages = [{"role": "system", "content": self.system_prompt}]
        messages.extend(self.conversation_history[-10:])  # Keep last 10 messages

        response = self.llm.create_chat_completion(
            messages=messages,
            max_tokens=200,
            temperature=0.7,
            stop=["<|im_end|>", "<|endoftext|>"]
        )

        assistant_message = response['choices'][0]['message']['content'].strip()
        self.conversation_history.append({"role": "assistant", "content": assistant_message})

        return assistant_message

    def speak(self, text: str, output_path: str) -> str:
        """Convert text to speech using XTTS"""
        self.tts.tts_to_file(
            text=text,
            speaker_wav=self.reference_audio,
            language="pt",
            file_path=output_path
        )
        return output_path

    def animate(self, audio_path: str, output_path: str) -> str:
        """Generate lip-sync video using MuseTalk"""
        with open(audio_path, 'rb') as f:
            response = requests.post(
                f"{self.musetalk_url}/inference",
                files={"audio": f},
                data={"avatar_id": self.avatar_id},
                timeout=300
            )

        if response.status_code == 200:
            with open(output_path, 'wb') as f:
                f.write(response.content)
            return output_path
        else:
            raise Exception(f"MuseTalk error: {response.text}")

    def respond(self, audio_input: str, output_video: str) -> dict:
        """
        Complete pipeline: audio input -> transcribe -> think -> speak -> animate

        Returns dict with all intermediate results
        """
        print("\n=== Processing Request ===")

        # Step 1: Transcribe
        print("1. Transcribing audio...")
        user_text = self.transcribe(audio_input)
        print(f" User said: {user_text}")

        # Step 2: Think
        print("2. Generating response...")
        response_text = self.think(user_text)
        print(f" Response: {response_text}")

        # Step 3: Speak
        print("3. Synthesizing speech...")
        audio_output = output_video.replace('.mp4', '.wav')
        self.speak(response_text, audio_output)

        # Get audio duration
        data, sr = sf.read(audio_output)
        audio_duration = len(data) / sr
        print(f" Audio duration: {audio_duration:.2f}s")

        # Step 4: Animate
        print("4. Generating lip-sync video...")
        self.animate(audio_output, output_video)
        print(f" Video saved: {output_video}")

        return {
            "user_text": user_text,
            "response_text": response_text,
            "audio_path": audio_output,
            "video_path": output_video,
            "audio_duration": audio_duration
        }

    def respond_to_text(self, user_text: str, output_video: str) -> dict:
        """
        Pipeline for text input: think -> speak -> animate
        """
        print("\n=== Processing Text Request ===")
        print(f" User: {user_text}")

        # Step 1: Think
        print("1. Generating response...")
        response_text = self.think(user_text)
        print(f" Response: {response_text}")

        # Step 2: Speak
        print("2. Synthesizing speech...")
        audio_output = output_video.replace('.mp4', '.wav')
        self.speak(response_text, audio_output)

        data, sr = sf.read(audio_output)
        audio_duration = len(data) / sr
        print(f" Audio duration: {audio_duration:.2f}s")

        # Step 3: Animate
        print("3. Generating lip-sync video...")
        self.animate(audio_output, output_video)
        print(f" Video saved: {output_video}")

        return {
            "user_text": user_text,
            "response_text": response_text,
            "audio_path": audio_output,
            "video_path": output_video,
            "audio_duration": audio_duration
        }


if __name__ == "__main__":
    # Initialize the avatar
    avatar = MultimodalAvatar()

    # Test with text input
    result = avatar.respond_to_text(
        user_text="Olá Mariana! Me conte sobre você.",
        output_video="results/avatar_test.mp4"
    )

    print("\n=== Result ===")
    print(f"User: {result['user_text']}")
    print(f"Mariana: {result['response_text']}")
    print(f"Video: {result['video_path']} ({result['audio_duration']:.1f}s)")
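A minimal driver for the audio-in / video-out path, assuming a MuseTalk server is reachable at the default musetalk_url and exposes the /inference route that animate() posts to; file paths below are illustrative placeholders:

    # Illustrative usage of MultimodalAvatar.respond(); paths are hypothetical.
    from avatar_pipeline import MultimodalAvatar

    avatar = MultimodalAvatar()                    # loads Whisper, the GGUF LLM and XTTS once
    result = avatar.respond(
        audio_input="data/audio/question_pt.wav",  # hypothetical user recording
        output_video="results/answer.mp4",
    )
    print(result["response_text"], f"{result['audio_duration']:.1f}s")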
configs/inference/hello_world.yaml ADDED (+3 lines)

task_hello_world:
  video_path: "data/video/video_hd_1min_25fps.mp4"
  audio_path: "data/audio/hello_world.wav"
configs/inference/professor_test.yaml ADDED (+4 lines)

task_0:
  video_path: "data/video/yongen.mp4"
  audio_path: "data/audio/professor_pt.wav"
  bbox_shift: 0
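Both task configs follow the same layout: one mapping per task with video_path, audio_path and an optional bbox_shift. A sketch of how such a file can be read, assuming OmegaConf (which the API servers below already import); this is illustrative, not the project's actual CLI:

    # Illustrative: iterate the tasks in an inference config.
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("configs/inference/professor_test.yaml")
    for task_id, task in cfg.items():
        print(task_id, task.video_path, task.audio_path, task.get("bbox_shift", 0))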
musetalk/utils/blending.py CHANGED

@@ -32,7 +32,7 @@ def face_seg(image, mode="raw", fp=None):
     return seg_image


-def get_image(image, face, face_box, upper_boundary_ratio=0.5, expand=1.5, mode="raw", fp=None):
+def get_image(image, face, face_box, upper_boundary_ratio=0.4, expand=1.5, mode="raw", fp=None):
     """
     将裁剪的面部图像粘贴回原始图像,并进行一些处理。

@@ -109,7 +109,7 @@ def get_image_blending(image, face, face_box, mask_array, crop_box):
     return body[:,:,::-1]


-def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=1.5, fp=None, mode="raw"):
+def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.4, expand=1.5, fp=None, mode="raw"):
     body = Image.fromarray(image[:,:,::-1])

     x, y, x1, y1 = face_box
musetalk_api_server.py ADDED (+551 lines)

"""
MuseTalk HTTP API Server
Keeps models loaded in GPU memory for fast inference.
"""
import os
import cv2
import copy
import torch
import glob
import shutil
import pickle
import numpy as np
import subprocess
import tempfile
import hashlib
import time
from pathlib import Path
from typing import Optional
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from tqdm import tqdm
from omegaconf import OmegaConf
from transformers import WhisperModel
import uvicorn

# MuseTalk imports
from musetalk.utils.blending import get_image
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.utils import get_file_type, datagen, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder


class MuseTalkServer:
    """Singleton server that keeps models loaded in GPU memory."""

    def __init__(self):
        self.device = None
        self.vae = None
        self.unet = None
        self.pe = None
        self.whisper = None
        self.audio_processor = None
        self.fp = None
        self.timesteps = None
        self.weight_dtype = None
        self.is_loaded = False

        # Cache directories
        self.cache_dir = Path("./cache")
        self.cache_dir.mkdir(exist_ok=True)
        self.landmarks_cache = self.cache_dir / "landmarks"
        self.latents_cache = self.cache_dir / "latents"
        self.whisper_cache = self.cache_dir / "whisper_features"
        self.landmarks_cache.mkdir(exist_ok=True)
        self.latents_cache.mkdir(exist_ok=True)
        self.whisper_cache.mkdir(exist_ok=True)

        # Config
        self.fps = 25
        self.batch_size = 8
        self.use_float16 = True
        self.version = "v15"
        self.extra_margin = 10
        self.parsing_mode = "jaw"
        self.left_cheek_width = 90
        self.right_cheek_width = 90
        self.audio_padding_left = 2
        self.audio_padding_right = 2

    def load_models(
        self,
        gpu_id: int = 0,
        unet_model_path: str = "./models/musetalkV15/unet.pth",
        unet_config: str = "./models/musetalk/config.json",
        vae_type: str = "sd-vae",
        whisper_dir: str = "./models/whisper",
        use_float16: bool = True,
        version: str = "v15"
    ):
        """Load all models into GPU memory."""
        if self.is_loaded:
            print("Models already loaded!")
            return

        print("=" * 50)
        print("Loading MuseTalk models into GPU memory...")
        print("=" * 50)

        start_time = time.time()

        # Set device
        self.device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Load model weights
        print("Loading VAE, UNet, PE...")
        self.vae, self.unet, self.pe = load_all_model(
            unet_model_path=unet_model_path,
            vae_type=vae_type,
            unet_config=unet_config,
            device=self.device
        )
        self.timesteps = torch.tensor([0], device=self.device)

        # Convert to float16 if enabled
        self.use_float16 = use_float16
        if use_float16:
            print("Converting to float16...")
            self.pe = self.pe.half()
            self.vae.vae = self.vae.vae.half()
            self.unet.model = self.unet.model.half()

        # Move to device
        self.pe = self.pe.to(self.device)
        self.vae.vae = self.vae.vae.to(self.device)
        self.unet.model = self.unet.model.to(self.device)

        # Initialize audio processor and Whisper
        print("Loading Whisper model...")
        self.audio_processor = AudioProcessor(feature_extractor_path=whisper_dir)
        self.weight_dtype = self.unet.model.dtype
        self.whisper = WhisperModel.from_pretrained(whisper_dir)
        self.whisper = self.whisper.to(device=self.device, dtype=self.weight_dtype).eval()
        self.whisper.requires_grad_(False)

        # Initialize face parser
        self.version = version
        if version == "v15":
            self.fp = FaceParsing(
                left_cheek_width=self.left_cheek_width,
                right_cheek_width=self.right_cheek_width
            )
        else:
            self.fp = FaceParsing()

        self.is_loaded = True
        load_time = time.time() - start_time
        print(f"Models loaded in {load_time:.2f}s")
        print("=" * 50)
        print("Server ready for inference!")
        print("=" * 50)

    def _get_file_hash(self, file_path: str) -> str:
        """Get MD5 hash of a file for caching."""
        hash_md5 = hashlib.md5()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()[:16]

    def _get_cached_landmarks(self, video_hash: str, bbox_shift: int):
        """Get cached landmarks if available."""
        # Disabled due to tensor comparison issues
        return None

    def _save_landmarks_cache(self, video_hash: str, bbox_shift: int, coord_list, frame_list):
        """Save landmarks to cache."""
        cache_file = self.landmarks_cache / f"{video_hash}_shift{bbox_shift}.pkl"
        with open(cache_file, 'wb') as f:
            pickle.dump((coord_list, frame_list), f)

    def _get_cached_latents(self, video_hash: str):
        """Get cached VAE latents if available."""
        # Disabled due to tensor comparison issues
        return None

    def _save_latents_cache(self, video_hash: str, latent_list):
        """Save VAE latents to cache."""
        cache_file = self.latents_cache / f"{video_hash}.pkl"
        with open(cache_file, 'wb') as f:
            pickle.dump(latent_list, f)

    def _get_cached_whisper(self, audio_hash: str):
        """Get cached Whisper features if available."""
        # Disabled due to tensor comparison issues
        return None

    def _save_whisper_cache(self, audio_hash: str, whisper_data):
        """Save Whisper features to cache."""
        cache_file = self.whisper_cache / f"{audio_hash}.pkl"
        with open(cache_file, 'wb') as f:
            pickle.dump(whisper_data, f)

    @torch.no_grad()
    def generate(
        self,
        video_path: str,
        audio_path: str,
        output_path: str,
        fps: Optional[int] = None,
        use_cache: bool = True
    ) -> dict:
        """
        Generate lip-synced video.

        Returns dict with timing info.
        """
        if not self.is_loaded:
            raise RuntimeError("Models not loaded! Call load_models() first.")

        fps = fps or self.fps
        timings = {"total": 0}
        total_start = time.time()

        # Get file hashes for caching
        video_hash = self._get_file_hash(video_path)
        audio_hash = self._get_file_hash(audio_path)

        # Create temp directory
        temp_dir = tempfile.mkdtemp()

        try:
            # 1. Extract frames
            t0 = time.time()
            input_basename = Path(video_path).stem
            save_dir_full = os.path.join(temp_dir, "frames")
            os.makedirs(save_dir_full, exist_ok=True)

            if get_file_type(video_path) == "video":
                cmd = f"ffmpeg -v fatal -i {video_path} -vf fps={fps} -start_number 0 {save_dir_full}/%08d.png"
                os.system(cmd)
                input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
            elif get_file_type(video_path) == "image":
                input_img_list = [video_path]
            else:
                raise ValueError(f"Unsupported video type: {video_path}")

            timings["frame_extraction"] = time.time() - t0

            # 2. Extract audio features (with caching)
            t0 = time.time()
            cached_whisper = self._get_cached_whisper(audio_hash) if use_cache else None

            if cached_whisper:
                whisper_chunks = cached_whisper
                timings["whisper_source"] = "cache"
            else:
                whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
                whisper_chunks = self.audio_processor.get_whisper_chunk(
                    whisper_input_features,
                    self.device,
                    self.weight_dtype,
                    self.whisper,
                    librosa_length,
                    fps=fps,
                    audio_padding_length_left=self.audio_padding_left,
                    audio_padding_length_right=self.audio_padding_right,
                )
                if use_cache:
                    self._save_whisper_cache(audio_hash, whisper_chunks)
                timings["whisper_source"] = "computed"

            timings["whisper_features"] = time.time() - t0

            # 3. Get landmarks (with caching)
            t0 = time.time()
            bbox_shift = 0 if self.version == "v15" else 0
            cache_key = f"{video_hash}_{fps}"

            cached_landmarks = self._get_cached_landmarks(cache_key, bbox_shift) if use_cache else None

            if cached_landmarks:
                coord_list, frame_list = cached_landmarks
                timings["landmarks_source"] = "cache"
            else:
                coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
                if use_cache:
                    self._save_landmarks_cache(cache_key, bbox_shift, coord_list, frame_list)
                timings["landmarks_source"] = "computed"

            timings["landmarks"] = time.time() - t0

            # 4. Compute VAE latents (with caching)
            t0 = time.time()
            latent_cache_key = f"{video_hash}_{fps}_{self.version}"
            cached_latents = self._get_cached_latents(latent_cache_key) if use_cache else None

            if cached_latents:
                input_latent_list = cached_latents
                timings["latents_source"] = "cache"
            else:
                input_latent_list = []
                for bbox, frame in zip(coord_list, frame_list):
                    if isinstance(bbox, (list, tuple)) and list(bbox) == list(coord_placeholder):
                        continue
                    x1, y1, x2, y2 = bbox
                    if self.version == "v15":
                        y2 = y2 + self.extra_margin
                        y2 = min(y2, frame.shape[0])
                    crop_frame = frame[y1:y2, x1:x2]
                    crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
                    latents = self.vae.get_latents_for_unet(crop_frame)
                    input_latent_list.append(latents)

                if use_cache:
                    self._save_latents_cache(latent_cache_key, input_latent_list)
                timings["latents_source"] = "computed"

            timings["vae_encoding"] = time.time() - t0

            # 5. Prepare cycled lists
            frame_list_cycle = frame_list + frame_list[::-1]
            coord_list_cycle = coord_list + coord_list[::-1]
            input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

            # 6. UNet inference
            t0 = time.time()
            video_num = len(whisper_chunks)
            gen = datagen(
                whisper_chunks=whisper_chunks,
                vae_encode_latents=input_latent_list_cycle,
                batch_size=self.batch_size,
                delay_frame=0,
                device=self.device,
            )

            res_frame_list = []
            for whisper_batch, latent_batch in gen:
                audio_feature_batch = self.pe(whisper_batch)
                latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
                pred_latents = self.unet.model(
                    latent_batch, self.timesteps,
                    encoder_hidden_states=audio_feature_batch
                ).sample
                recon = self.vae.decode_latents(pred_latents)
                for res_frame in recon:
                    res_frame_list.append(res_frame)

            timings["unet_inference"] = time.time() - t0

            # 7. Face blending
            t0 = time.time()
            result_img_path = os.path.join(temp_dir, "results")
            os.makedirs(result_img_path, exist_ok=True)

            for i, res_frame in enumerate(res_frame_list):
                bbox = coord_list_cycle[i % len(coord_list_cycle)]
                ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
                x1, y1, x2, y2 = bbox
                if self.version == "v15":
                    y2 = y2 + self.extra_margin
                    y2 = min(y2, ori_frame.shape[0])
                try:
                    res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
                except:
                    continue

                if self.version == "v15":
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
                                              mode=self.parsing_mode, fp=self.fp)
                else:
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=self.fp)

                cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)

            timings["face_blending"] = time.time() - t0

            # 8. Encode video
            t0 = time.time()
            temp_vid = os.path.join(temp_dir, "temp.mp4")
            cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_path}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid}"
            os.system(cmd_img2video)

            cmd_combine = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid} {output_path}"
            os.system(cmd_combine)

            timings["video_encoding"] = time.time() - t0

        finally:
            # Cleanup
            shutil.rmtree(temp_dir, ignore_errors=True)

        timings["total"] = time.time() - total_start
        timings["frames_generated"] = len(res_frame_list)

        return timings


# Global server instance
server = MuseTalkServer()

# FastAPI app
app = FastAPI(
    title="MuseTalk API",
    description="HTTP API for MuseTalk lip-sync generation",
    version="1.0.0"
)

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
async def startup_event():
    """Load models on server startup."""
    server.load_models()


@app.get("/health")
async def health_check():
    """Check if server is ready."""
    return {
        "status": "ok" if server.is_loaded else "loading",
        "models_loaded": server.is_loaded,
        "device": str(server.device) if server.device else None
    }


@app.get("/cache/stats")
async def cache_stats():
    """Get cache statistics."""
    landmarks_count = len(list(server.landmarks_cache.glob("*.pkl")))
    latents_count = len(list(server.latents_cache.glob("*.pkl")))
    whisper_count = len(list(server.whisper_cache.glob("*.pkl")))

    return {
        "landmarks_cached": landmarks_count,
        "latents_cached": latents_count,
        "whisper_features_cached": whisper_count
    }


@app.post("/cache/clear")
async def clear_cache():
    """Clear all caches."""
    for cache_dir in [server.landmarks_cache, server.latents_cache, server.whisper_cache]:
        for f in cache_dir.glob("*.pkl"):
            f.unlink()
    return {"status": "cleared"}


class GenerateRequest(BaseModel):
    video_path: str
    audio_path: str
    output_path: str
    fps: Optional[int] = 25
    use_cache: bool = True


@app.post("/generate")
async def generate_from_paths(request: GenerateRequest):
    """
    Generate lip-synced video from file paths.

    Use this when files are already on the server.
    """
    if not server.is_loaded:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    if not os.path.exists(request.video_path):
        raise HTTPException(status_code=404, detail=f"Video not found: {request.video_path}")
    if not os.path.exists(request.audio_path):
        raise HTTPException(status_code=404, detail=f"Audio not found: {request.audio_path}")

    try:
        timings = server.generate(
            video_path=request.video_path,
            audio_path=request.audio_path,
            output_path=request.output_path,
            fps=request.fps,
            use_cache=request.use_cache
        )
        return {
            "status": "success",
            "output_path": request.output_path,
            "timings": timings
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/generate/upload")
async def generate_from_upload(
    video: UploadFile = File(...),
    audio: UploadFile = File(...),
    fps: int = Form(25),
    use_cache: bool = Form(True)
):
    """
    Generate lip-synced video from uploaded files.

    Returns the generated video file.
    """
    if not server.is_loaded:
        raise HTTPException(status_code=503, detail="Models not loaded yet")

    # Save uploaded files
    temp_dir = tempfile.mkdtemp()
    try:
        video_path = os.path.join(temp_dir, video.filename)
        audio_path = os.path.join(temp_dir, audio.filename)
        output_path = os.path.join(temp_dir, "output.mp4")

        with open(video_path, "wb") as f:
            f.write(await video.read())
        with open(audio_path, "wb") as f:
            f.write(await audio.read())

        timings = server.generate(
            video_path=video_path,
            audio_path=audio_path,
            output_path=output_path,
            fps=fps,
            use_cache=use_cache
        )

        # Return the video file
        return FileResponse(
            output_path,
            media_type="video/mp4",
            filename="result.mp4",
            headers={"X-Timings": str(timings)}
        )
    except Exception as e:
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="MuseTalk API Server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind")
    parser.add_argument("--port", type=int, default=8000, help="Port to bind")
    parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID")
    parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth")
    parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json")
    parser.add_argument("--whisper_dir", type=str, default="./models/whisper")
    parser.add_argument("--no_float16", action="store_true", help="Disable float16")
    args = parser.parse_args()

    # Pre-configure server
    server.load_models(
        gpu_id=args.gpu_id,
        unet_model_path=args.unet_model_path,
        unet_config=args.unet_config,
        whisper_dir=args.whisper_dir,
        use_float16=not args.no_float16
    )

    # Start server
    uvicorn.run(app, host=args.host, port=args.port)
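A client-side sketch for the /generate endpoint above; host, port and paths are illustrative, and the video/audio files must already exist on the server machine:

    # Illustrative client for musetalk_api_server.py (server assumed at localhost:8000).
    import requests

    health = requests.get("http://localhost:8000/health", timeout=10).json()
    assert health["models_loaded"], "server still loading models"

    resp = requests.post(
        "http://localhost:8000/generate",
        json={
            "video_path": "data/video/yongen.mp4",
            "audio_path": "data/audio/professor_pt.wav",
            "output_path": "results/professor_test.mp4",
            "fps": 25,
            "use_cache": True,
        },
        timeout=600,
    )
    resp.raise_for_status()
    print(resp.json()["timings"])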
musetalk_api_server_v2.py
ADDED
|
@@ -0,0 +1,445 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MuseTalk HTTP API Server v2
|
| 3 |
+
Optimized for repeated use of the same avatar.
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import cv2
|
| 7 |
+
import copy
|
| 8 |
+
import torch
|
| 9 |
+
import glob
|
| 10 |
+
import shutil
|
| 11 |
+
import pickle
|
| 12 |
+
import numpy as np
|
| 13 |
+
import subprocess
|
| 14 |
+
import tempfile
|
| 15 |
+
import hashlib
|
| 16 |
+
import time
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Optional
|
| 19 |
+
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks
|
| 20 |
+
from fastapi.responses import FileResponse, JSONResponse
|
| 21 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 22 |
+
from pydantic import BaseModel
|
| 23 |
+
from tqdm import tqdm
|
| 24 |
+
from omegaconf import OmegaConf
|
| 25 |
+
from transformers import WhisperModel
|
| 26 |
+
import uvicorn
|
| 27 |
+
|
| 28 |
+
# MuseTalk imports
|
| 29 |
+
from musetalk.utils.blending import get_image
|
| 30 |
+
from musetalk.utils.face_parsing import FaceParsing
|
| 31 |
+
from musetalk.utils.audio_processor import AudioProcessor
|
| 32 |
+
from musetalk.utils.utils import get_file_type, datagen, load_all_model
|
| 33 |
+
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class MuseTalkServerV2:
|
| 37 |
+
"""Server optimized for pre-processed avatars."""
|
| 38 |
+
|
| 39 |
+
def __init__(self):
|
| 40 |
+
self.device = None
|
| 41 |
+
self.vae = None
|
| 42 |
+
self.unet = None
|
| 43 |
+
self.pe = None
|
| 44 |
+
self.whisper = None
|
| 45 |
+
self.audio_processor = None
|
| 46 |
+
self.fp = None
|
| 47 |
+
self.timesteps = None
|
| 48 |
+
self.weight_dtype = None
|
| 49 |
+
self.is_loaded = False
|
| 50 |
+
|
| 51 |
+
# Avatar cache (in-memory)
|
| 52 |
+
self.loaded_avatars = {}
|
| 53 |
+
self.avatar_dir = Path("./avatars")
|
| 54 |
+
|
| 55 |
+
# Config
|
| 56 |
+
self.fps = 25
|
| 57 |
+
self.batch_size = 8
|
| 58 |
+
self.use_float16 = True
|
| 59 |
+
self.version = "v15"
|
| 60 |
+
self.extra_margin = 10
|
| 61 |
+
self.parsing_mode = "jaw"
|
| 62 |
+
self.left_cheek_width = 90
|
| 63 |
+
self.right_cheek_width = 90
|
| 64 |
+
self.audio_padding_left = 2
|
| 65 |
+
self.audio_padding_right = 2
|
| 66 |
+
|
| 67 |
+
def load_models(
|
| 68 |
+
self,
|
| 69 |
+
gpu_id: int = 0,
|
| 70 |
+
unet_model_path: str = "./models/musetalkV15/unet.pth",
|
| 71 |
+
unet_config: str = "./models/musetalk/config.json",
|
| 72 |
+
vae_type: str = "sd-vae",
|
| 73 |
+
whisper_dir: str = "./models/whisper",
|
| 74 |
+
use_float16: bool = True,
|
| 75 |
+
version: str = "v15"
|
| 76 |
+
):
|
| 77 |
+
if self.is_loaded:
|
| 78 |
+
print("Models already loaded!")
|
| 79 |
+
return
|
| 80 |
+
|
| 81 |
+
print("=" * 50)
|
| 82 |
+
print("Loading MuseTalk models into GPU memory...")
|
| 83 |
+
print("=" * 50)
|
| 84 |
+
|
| 85 |
+
start_time = time.time()
|
| 86 |
+
self.device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
|
| 87 |
+
print(f"Using device: {self.device}")
|
| 88 |
+
|
| 89 |
+
print("Loading VAE, UNet, PE...")
|
| 90 |
+
self.vae, self.unet, self.pe = load_all_model(
|
| 91 |
+
unet_model_path=unet_model_path,
|
| 92 |
+
vae_type=vae_type,
|
| 93 |
+
unet_config=unet_config,
|
| 94 |
+
device=self.device
|
| 95 |
+
)
|
| 96 |
+
self.timesteps = torch.tensor([0], device=self.device)
|
| 97 |
+
|
| 98 |
+
self.use_float16 = use_float16
|
| 99 |
+
if use_float16:
|
| 100 |
+
print("Converting to float16...")
|
| 101 |
+
self.pe = self.pe.half()
|
| 102 |
+
self.vae.vae = self.vae.vae.half()
|
| 103 |
+
self.unet.model = self.unet.model.half()
|
| 104 |
+
|
| 105 |
+
self.pe = self.pe.to(self.device)
|
| 106 |
+
self.vae.vae = self.vae.vae.to(self.device)
|
| 107 |
+
self.unet.model = self.unet.model.to(self.device)
|
| 108 |
+
|
| 109 |
+
print("Loading Whisper model...")
|
| 110 |
+
self.audio_processor = AudioProcessor(feature_extractor_path=whisper_dir)
|
| 111 |
+
self.weight_dtype = self.unet.model.dtype
|
| 112 |
+
self.whisper = WhisperModel.from_pretrained(whisper_dir)
|
| 113 |
+
self.whisper = self.whisper.to(device=self.device, dtype=self.weight_dtype).eval()
|
| 114 |
+
self.whisper.requires_grad_(False)
|
| 115 |
+
|
| 116 |
+
self.version = version
|
| 117 |
+
if version == "v15":
|
| 118 |
+
self.fp = FaceParsing(
|
| 119 |
+
left_cheek_width=self.left_cheek_width,
|
| 120 |
+
right_cheek_width=self.right_cheek_width
|
| 121 |
+
)
|
| 122 |
+
else:
|
| 123 |
+
self.fp = FaceParsing()
|
| 124 |
+
|
| 125 |
+
self.is_loaded = True
|
| 126 |
+
print(f"Models loaded in {time.time() - start_time:.2f}s")
|
| 127 |
+
print("=" * 50)
|
| 128 |
+
|
| 129 |
+
def load_avatar(self, avatar_name: str) -> dict:
|
| 130 |
+
"""Load a preprocessed avatar into memory."""
|
| 131 |
+
if avatar_name in self.loaded_avatars:
|
| 132 |
+
return self.loaded_avatars[avatar_name]
|
| 133 |
+
|
| 134 |
+
avatar_path = self.avatar_dir / avatar_name
|
| 135 |
+
if not avatar_path.exists():
|
| 136 |
+
raise FileNotFoundError(f"Avatar not found: {avatar_name}")
|
| 137 |
+
|
| 138 |
+
print(f"Loading avatar '{avatar_name}' into memory...")
|
| 139 |
+
t0 = time.time()
|
| 140 |
+
|
| 141 |
+
avatar_data = {}
|
| 142 |
+
|
| 143 |
+
# Load metadata
|
| 144 |
+
with open(avatar_path / "metadata.pkl", 'rb') as f:
|
| 145 |
+
avatar_data['metadata'] = pickle.load(f)
|
| 146 |
+
|
| 147 |
+
# Load coords
|
| 148 |
+
with open(avatar_path / "coords.pkl", 'rb') as f:
|
| 149 |
+
avatar_data['coord_list'] = pickle.load(f)
|
| 150 |
+
|
| 151 |
+
# Load frames
|
| 152 |
+
with open(avatar_path / "frames.pkl", 'rb') as f:
|
| 153 |
+
avatar_data['frame_list'] = pickle.load(f)
|
| 154 |
+
|
| 155 |
+
# Load latents and convert to GPU tensors
|
| 156 |
+
with open(avatar_path / "latents.pkl", 'rb') as f:
|
| 157 |
+
latents_np = pickle.load(f)
|
| 158 |
+
avatar_data['latent_list'] = [
|
| 159 |
+
torch.from_numpy(l).to(self.device) for l in latents_np
|
| 160 |
+
]
|
| 161 |
+
|
| 162 |
+
# Load crop info
|
| 163 |
+
with open(avatar_path / "crop_info.pkl", 'rb') as f:
|
| 164 |
+
avatar_data['crop_info'] = pickle.load(f)
|
| 165 |
+
|
| 166 |
+
# Load parsing data (optional)
|
| 167 |
+
parsing_path = avatar_path / "parsing.pkl"
|
| 168 |
+
if parsing_path.exists():
|
| 169 |
+
with open(parsing_path, 'rb') as f:
|
| 170 |
+
avatar_data['parsing_data'] = pickle.load(f)
|
| 171 |
+
|
| 172 |
+
self.loaded_avatars[avatar_name] = avatar_data
|
| 173 |
+
print(f"Avatar loaded in {time.time() - t0:.2f}s")
|
| 174 |
+
|
| 175 |
+
return avatar_data
|
| 176 |
+
|
| 177 |
+
def unload_avatar(self, avatar_name: str):
|
| 178 |
+
"""Unload avatar from memory."""
|
| 179 |
+
if avatar_name in self.loaded_avatars:
|
| 180 |
+
del self.loaded_avatars[avatar_name]
|
| 181 |
+
torch.cuda.empty_cache()
|
| 182 |
+
|
| 183 |
+
@torch.no_grad()
|
| 184 |
+
def generate_with_avatar(
|
| 185 |
+
self,
|
| 186 |
+
avatar_name: str,
|
| 187 |
+
audio_path: str,
|
| 188 |
+
output_path: str,
|
| 189 |
+
fps: Optional[int] = None
|
| 190 |
+
) -> dict:
|
| 191 |
+
"""Generate video using pre-processed avatar. Much faster!"""
|
| 192 |
+
if not self.is_loaded:
|
| 193 |
+
raise RuntimeError("Models not loaded!")
|
| 194 |
+
|
| 195 |
+
fps = fps or self.fps
|
| 196 |
+
timings = {}
|
| 197 |
+
total_start = time.time()
|
| 198 |
+
|
| 199 |
+
# Load avatar (cached in memory)
|
| 200 |
+
t0 = time.time()
|
| 201 |
+
avatar = self.load_avatar(avatar_name)
|
| 202 |
+
timings["avatar_load"] = time.time() - t0
|
| 203 |
+
|
| 204 |
+
coord_list = avatar['coord_list']
|
| 205 |
+
frame_list = avatar['frame_list']
|
| 206 |
+
input_latent_list = avatar['latent_list']
|
| 207 |
+
|
| 208 |
+
temp_dir = tempfile.mkdtemp()
|
| 209 |
+
|
| 210 |
+
try:
|
| 211 |
+
# 1. Extract audio features (only audio-dependent step that's heavy)
|
| 212 |
+
t0 = time.time()
|
| 213 |
+
whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
|
| 214 |
+
whisper_chunks = self.audio_processor.get_whisper_chunk(
|
| 215 |
+
whisper_input_features,
|
| 216 |
+
self.device,
|
| 217 |
+
self.weight_dtype,
|
| 218 |
+
self.whisper,
|
| 219 |
+
librosa_length,
|
| 220 |
+
fps=fps,
|
| 221 |
+
audio_padding_length_left=self.audio_padding_left,
|
| 222 |
+
audio_padding_length_right=self.audio_padding_right,
|
| 223 |
+
)
|
| 224 |
+
timings["whisper_features"] = time.time() - t0
|
| 225 |
+
|
| 226 |
+
# 2. Prepare cycled lists
|
| 227 |
+
frame_list_cycle = frame_list + frame_list[::-1]
|
| 228 |
+
coord_list_cycle = coord_list + coord_list[::-1]
|
| 229 |
+
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
|
| 230 |
+
|
| 231 |
+
# 3. UNet inference
|
| 232 |
+
t0 = time.time()
|
| 233 |
+
gen = datagen(
|
| 234 |
+
whisper_chunks=whisper_chunks,
|
| 235 |
+
vae_encode_latents=input_latent_list_cycle,
|
| 236 |
+
batch_size=self.batch_size,
|
| 237 |
+
delay_frame=0,
|
| 238 |
+
device=self.device,
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
res_frame_list = []
|
| 242 |
+
for whisper_batch, latent_batch in gen:
|
| 243 |
+
audio_feature_batch = self.pe(whisper_batch)
|
| 244 |
+
latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
|
| 245 |
+
pred_latents = self.unet.model(
|
| 246 |
+
latent_batch, self.timesteps,
|
| 247 |
+
encoder_hidden_states=audio_feature_batch
|
| 248 |
+
).sample
|
| 249 |
+
recon = self.vae.decode_latents(pred_latents)
|
| 250 |
+
for res_frame in recon:
|
| 251 |
+
res_frame_list.append(res_frame)
|
| 252 |
+
|
| 253 |
+
timings["unet_inference"] = time.time() - t0
|
| 254 |
+
|
| 255 |
+
# 4. Face blending
|
| 256 |
+
t0 = time.time()
|
| 257 |
+
result_img_path = os.path.join(temp_dir, "results")
|
| 258 |
+
os.makedirs(result_img_path, exist_ok=True)
|
| 259 |
+
|
| 260 |
+
for i, res_frame in enumerate(res_frame_list):
|
| 261 |
+
bbox = coord_list_cycle[i % len(coord_list_cycle)]
|
| 262 |
+
ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
|
| 263 |
+
x1, y1, x2, y2 = bbox
|
| 264 |
+
|
| 265 |
+
if self.version == "v15":
|
| 266 |
+
y2 = y2 + self.extra_margin
|
| 267 |
+
y2 = min(y2, ori_frame.shape[0])
|
| 268 |
+
|
| 269 |
+
try:
|
| 270 |
+
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
|
| 271 |
+
except:
|
| 272 |
+
continue
|
| 273 |
+
|
| 274 |
+
if self.version == "v15":
|
| 275 |
+
combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
|
| 276 |
+
mode=self.parsing_mode, fp=self.fp)
|
| 277 |
+
else:
|
| 278 |
+
combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=self.fp)
|
| 279 |
+
|
| 280 |
+
cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)
|
| 281 |
+
|
| 282 |
+
timings["face_blending"] = time.time() - t0
|
| 283 |
+
|
| 284 |
+
# 5. Encode video
|
| 285 |
+
t0 = time.time()
|
| 286 |
+
temp_vid = os.path.join(temp_dir, "temp.mp4")
|
| 287 |
+
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_path}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid}"
|
| 288 |
+
os.system(cmd_img2video)
|
| 289 |
+
|
| 290 |
+
cmd_combine = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid} {output_path}"
|
| 291 |
+
os.system(cmd_combine)
|
| 292 |
+
|
| 293 |
+
timings["video_encoding"] = time.time() - t0
|
| 294 |
+
|
| 295 |
+
finally:
|
| 296 |
+
shutil.rmtree(temp_dir, ignore_errors=True)
|
| 297 |
+
|
| 298 |
+
timings["total"] = time.time() - total_start
|
| 299 |
+
timings["frames_generated"] = len(res_frame_list)
|
| 300 |
+
|
| 301 |
+
return timings
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
# Global server instance
|
| 305 |
+
server = MuseTalkServerV2()
|
| 306 |
+
|
| 307 |
+
# FastAPI app
|
| 308 |
+
app = FastAPI(
|
| 309 |
+
title="MuseTalk API v2",
|
| 310 |
+
description="Optimized API for repeated avatar usage",
|
| 311 |
+
version="2.0.0"
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
app.add_middleware(
|
| 315 |
+
CORSMiddleware,
|
| 316 |
+
allow_origins=["*"],
|
| 317 |
+
allow_credentials=True,
|
| 318 |
+
allow_methods=["*"],
|
| 319 |
+
allow_headers=["*"],
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
@app.on_event("startup")
|
| 324 |
+
async def startup_event():
|
| 325 |
+
server.load_models()
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
@app.get("/health")
|
| 329 |
+
async def health_check():
|
| 330 |
+
return {
|
| 331 |
+
"status": "ok" if server.is_loaded else "loading",
|
| 332 |
+
"models_loaded": server.is_loaded,
|
| 333 |
+
"device": str(server.device) if server.device else None,
        "loaded_avatars": list(server.loaded_avatars.keys())
    }


@app.get("/avatars")
async def list_avatars():
    """List all available preprocessed avatars."""
    avatars = []
    for p in server.avatar_dir.iterdir():
        if p.is_dir() and (p / "metadata.pkl").exists():
            with open(p / "metadata.pkl", 'rb') as f:
                metadata = pickle.load(f)
            metadata['loaded'] = p.name in server.loaded_avatars
            avatars.append(metadata)
    return {"avatars": avatars}


@app.post("/avatars/{avatar_name}/load")
async def load_avatar(avatar_name: str):
    """Pre-load an avatar into GPU memory."""
    try:
        server.load_avatar(avatar_name)
        return {"status": "loaded", "avatar_name": avatar_name}
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e))


@app.post("/avatars/{avatar_name}/unload")
async def unload_avatar(avatar_name: str):
    """Unload an avatar from memory."""
    server.unload_avatar(avatar_name)
    return {"status": "unloaded", "avatar_name": avatar_name}


class GenerateWithAvatarRequest(BaseModel):
    avatar_name: str
    audio_path: str
    output_path: str
    fps: Optional[int] = 25


@app.post("/generate/avatar")
async def generate_with_avatar(request: GenerateWithAvatarRequest):
    """Generate video using pre-processed avatar. FAST!"""
    if not server.is_loaded:
        raise HTTPException(status_code=503, detail="Models not loaded")

    if not os.path.exists(request.audio_path):
        raise HTTPException(status_code=404, detail=f"Audio not found: {request.audio_path}")

    try:
        timings = server.generate_with_avatar(
            avatar_name=request.avatar_name,
            audio_path=request.audio_path,
            output_path=request.output_path,
            fps=request.fps
        )
        return {
            "status": "success",
            "output_path": request.output_path,
            "timings": timings
        }
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/generate/avatar/upload")
async def generate_with_avatar_upload(
    avatar_name: str = Form(...),
    audio: UploadFile = File(...),
    fps: int = Form(25)
):
    """Generate video from uploaded audio using pre-processed avatar."""
    if not server.is_loaded:
        raise HTTPException(status_code=503, detail="Models not loaded")

    temp_dir = tempfile.mkdtemp()
    try:
        audio_path = os.path.join(temp_dir, audio.filename)
        output_path = os.path.join(temp_dir, "output.mp4")

        with open(audio_path, "wb") as f:
            f.write(await audio.read())

        timings = server.generate_with_avatar(
            avatar_name=avatar_name,
            audio_path=audio_path,
            output_path=output_path,
            fps=fps
        )

        return FileResponse(
            output_path,
            media_type="video/mp4",
            filename="result.mp4",
            headers={"X-Timings": str(timings)}
        )
    except Exception as e:
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()

    uvicorn.run(app, host=args.host, port=args.port)
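For reference, a minimal client sketch for the endpoints above. Everything here is an assumption for illustration only, not part of the commit: the base URL, the avatar name "my_avatar", and the local file names are placeholders; it only uses the routes defined in this file (/health, /avatars/{name}/load, /generate/avatar/upload).

# Hypothetical client sketch (assumes the API server is running on localhost:8000
# and an avatar named "my_avatar" has already been preprocessed).
import requests

BASE = "http://localhost:8000"

# Check that the models are loaded, then pre-load the avatar into memory.
health = requests.get(f"{BASE}/health").json()
print("health:", health)
requests.post(f"{BASE}/avatars/my_avatar/load").raise_for_status()

# Upload an audio file and save the returned MP4.
with open("question.wav", "rb") as audio:
    resp = requests.post(
        f"{BASE}/generate/avatar/upload",
        data={"avatar_name": "my_avatar", "fps": 25},
        files={"audio": ("question.wav", audio, "audio/wav")},
    )
resp.raise_for_status()
with open("answer.mp4", "wb") as out:
    out.write(resp.content)
print("timings:", resp.headers.get("X-Timings"))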
musetalk_api_server_v3.py
ADDED
@@ -0,0 +1,651 @@
"""
MuseTalk HTTP API Server v3
Ultra-optimized with:
1. GPU-accelerated face blending (parallel processing)
2. NVENC hardware video encoding
3. Batch audio processing
"""
import os
import cv2
import copy
import torch
import glob
import shutil
import pickle
import numpy as np
import subprocess
import tempfile
import hashlib
import time
import asyncio
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from pathlib import Path
from typing import Optional, List
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from tqdm import tqdm
from omegaconf import OmegaConf
from transformers import WhisperModel
import uvicorn
import multiprocessing as mp

# MuseTalk imports
from musetalk.utils.blending import get_image
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.utils import get_file_type, datagen, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder


def blend_single_frame(args):
    """Worker function for parallel face blending."""
    i, res_frame, bbox, ori_frame, extra_margin, version, parsing_mode, fp_config = args

    x1, y1, x2, y2 = bbox
    if version == "v15":
        y2 = y2 + extra_margin
        y2 = min(y2, ori_frame.shape[0])

    try:
        res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
    except:
        return i, None

    # Create FaceParsing instance for this worker
    fp = FaceParsing(
        left_cheek_width=fp_config['left_cheek_width'],
        right_cheek_width=fp_config['right_cheek_width']
    )

    if version == "v15":
        combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
                                  mode=parsing_mode, fp=fp)
    else:
        combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp)

    return i, combine_frame


class MuseTalkServerV3:
    """Ultra-optimized server."""

    def __init__(self):
        self.device = None
        self.vae = None
        self.unet = None
        self.pe = None
        self.whisper = None
        self.audio_processor = None
        self.fp = None
        self.timesteps = None
        self.weight_dtype = None
        self.is_loaded = False

        # Avatar cache
        self.loaded_avatars = {}
        self.avatar_dir = Path("./avatars")

        # Config
        self.fps = 25
        self.batch_size = 8
        self.use_float16 = True
        self.version = "v15"
        self.extra_margin = 10
        self.parsing_mode = "jaw"
        self.left_cheek_width = 90
        self.right_cheek_width = 90
        self.audio_padding_left = 2
        self.audio_padding_right = 2

        # Thread pool for parallel blending
        self.num_workers = min(8, mp.cpu_count())
        self.thread_pool = ThreadPoolExecutor(max_workers=self.num_workers)

        # NVENC settings
        self.use_nvenc = True
        self.nvenc_preset = "p4"  # p1 (fastest) to p7 (best quality)
        self.crf = 23

    def load_models(
        self,
        gpu_id: int = 0,
        unet_model_path: str = "./models/musetalkV15/unet.pth",
        unet_config: str = "./models/musetalk/config.json",
        vae_type: str = "sd-vae",
        whisper_dir: str = "./models/whisper",
        use_float16: bool = True,
        version: str = "v15"
    ):
        if self.is_loaded:
            print("Models already loaded!")
            return

        print("=" * 50)
        print("Loading MuseTalk models (v3 Ultra-Optimized)...")
        print("=" * 50)

        start_time = time.time()
        self.device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        print(f"Parallel workers: {self.num_workers}")
        print(f"NVENC encoding: {self.use_nvenc}")

        print("Loading VAE, UNet, PE...")
        self.vae, self.unet, self.pe = load_all_model(
            unet_model_path=unet_model_path,
            vae_type=vae_type,
            unet_config=unet_config,
            device=self.device
        )
        self.timesteps = torch.tensor([0], device=self.device)

        self.use_float16 = use_float16
        if use_float16:
            print("Converting to float16...")
            self.pe = self.pe.half()
            self.vae.vae = self.vae.vae.half()
            self.unet.model = self.unet.model.half()

        self.pe = self.pe.to(self.device)
        self.vae.vae = self.vae.vae.to(self.device)
        self.unet.model = self.unet.model.to(self.device)

        print("Loading Whisper model...")
        self.audio_processor = AudioProcessor(feature_extractor_path=whisper_dir)
        self.weight_dtype = self.unet.model.dtype
        self.whisper = WhisperModel.from_pretrained(whisper_dir)
        self.whisper = self.whisper.to(device=self.device, dtype=self.weight_dtype).eval()
        self.whisper.requires_grad_(False)

        self.version = version
        if version == "v15":
            self.fp = FaceParsing(
                left_cheek_width=self.left_cheek_width,
                right_cheek_width=self.right_cheek_width
            )
        else:
            self.fp = FaceParsing()

        self.is_loaded = True
        print(f"Models loaded in {time.time() - start_time:.2f}s")
        print("=" * 50)

    def load_avatar(self, avatar_name: str) -> dict:
        if avatar_name in self.loaded_avatars:
            return self.loaded_avatars[avatar_name]

        avatar_path = self.avatar_dir / avatar_name
        if not avatar_path.exists():
            raise FileNotFoundError(f"Avatar not found: {avatar_name}")

        print(f"Loading avatar '{avatar_name}' into memory...")
        t0 = time.time()

        avatar_data = {}

        with open(avatar_path / "metadata.pkl", 'rb') as f:
            avatar_data['metadata'] = pickle.load(f)

        with open(avatar_path / "coords.pkl", 'rb') as f:
            avatar_data['coord_list'] = pickle.load(f)

        with open(avatar_path / "frames.pkl", 'rb') as f:
            avatar_data['frame_list'] = pickle.load(f)

        with open(avatar_path / "latents.pkl", 'rb') as f:
            latents_np = pickle.load(f)
            avatar_data['latent_list'] = [
                torch.from_numpy(l).to(self.device) for l in latents_np
            ]

        with open(avatar_path / "crop_info.pkl", 'rb') as f:
            avatar_data['crop_info'] = pickle.load(f)

        self.loaded_avatars[avatar_name] = avatar_data
        print(f"Avatar loaded in {time.time() - t0:.2f}s")

        return avatar_data

    def unload_avatar(self, avatar_name: str):
        if avatar_name in self.loaded_avatars:
            del self.loaded_avatars[avatar_name]
            torch.cuda.empty_cache()

    def _encode_video_nvenc(self, frames_dir: str, audio_path: str, output_path: str, fps: int) -> float:
        """Encode video using NVENC hardware acceleration."""
        t0 = time.time()
        temp_vid = frames_dir.replace('/results', '/temp.mp4')

        if self.use_nvenc:
            # NVENC H.264 encoding (much faster)
            cmd_img2video = (
                f"ffmpeg -y -v warning -r {fps} -f image2 -i {frames_dir}/%08d.png "
                f"-c:v h264_nvenc -preset {self.nvenc_preset} -cq {self.crf} "
                f"-pix_fmt yuv420p {temp_vid}"
            )
        else:
            # Fallback to CPU encoding
            cmd_img2video = (
                f"ffmpeg -y -v warning -r {fps} -f image2 -i {frames_dir}/%08d.png "
                f"-vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid}"
            )

        os.system(cmd_img2video)

        # Add audio
        cmd_combine = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid} -c:v copy -c:a aac {output_path}"
        os.system(cmd_combine)

        # Cleanup temp video
        if os.path.exists(temp_vid):
            os.remove(temp_vid)

        return time.time() - t0

    def _parallel_face_blending(self, res_frame_list, coord_list_cycle, frame_list_cycle, result_img_path) -> float:
        """Parallel face blending using thread pool."""
        t0 = time.time()

        fp_config = {
            'left_cheek_width': self.left_cheek_width,
            'right_cheek_width': self.right_cheek_width
        }

        # Prepare all tasks
        tasks = []
        for i, res_frame in enumerate(res_frame_list):
            bbox = coord_list_cycle[i % len(coord_list_cycle)]
            ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
            tasks.append((
                i, res_frame, bbox, ori_frame,
                self.extra_margin, self.version, self.parsing_mode, fp_config
            ))

        # Process in parallel
        results = list(self.thread_pool.map(blend_single_frame, tasks))

        # Sort and save results
        results.sort(key=lambda x: x[0])
        for i, combine_frame in results:
            if combine_frame is not None:
                cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)

        return time.time() - t0

    @torch.no_grad()
    def generate_with_avatar(
        self,
        avatar_name: str,
        audio_path: str,
        output_path: str,
        fps: Optional[int] = None,
        use_parallel_blending: bool = True
    ) -> dict:
        """Generate video using pre-processed avatar with all optimizations."""
        if not self.is_loaded:
            raise RuntimeError("Models not loaded!")

        fps = fps or self.fps
        timings = {}
        total_start = time.time()

        # Load avatar
        t0 = time.time()
        avatar = self.load_avatar(avatar_name)
        timings["avatar_load"] = time.time() - t0

        coord_list = avatar['coord_list']
        frame_list = avatar['frame_list']
        input_latent_list = avatar['latent_list']

        temp_dir = tempfile.mkdtemp()

        try:
            # 1. Extract audio features
            t0 = time.time()
            whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
            whisper_chunks = self.audio_processor.get_whisper_chunk(
                whisper_input_features,
                self.device,
                self.weight_dtype,
                self.whisper,
                librosa_length,
                fps=fps,
                audio_padding_length_left=self.audio_padding_left,
                audio_padding_length_right=self.audio_padding_right,
            )
            timings["whisper_features"] = time.time() - t0

            # 2. Prepare cycled lists
            frame_list_cycle = frame_list + frame_list[::-1]
            coord_list_cycle = coord_list + coord_list[::-1]
            input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

            # 3. UNet inference
            t0 = time.time()
            gen = datagen(
                whisper_chunks=whisper_chunks,
                vae_encode_latents=input_latent_list_cycle,
                batch_size=self.batch_size,
                delay_frame=0,
                device=self.device,
            )

            res_frame_list = []
            for whisper_batch, latent_batch in gen:
                audio_feature_batch = self.pe(whisper_batch)
                latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
                pred_latents = self.unet.model(
                    latent_batch, self.timesteps,
                    encoder_hidden_states=audio_feature_batch
                ).sample
                recon = self.vae.decode_latents(pred_latents)
                for res_frame in recon:
                    res_frame_list.append(res_frame)

            timings["unet_inference"] = time.time() - t0

            # 4. Face blending (parallel or sequential)
            result_img_path = os.path.join(temp_dir, "results")
            os.makedirs(result_img_path, exist_ok=True)

            if use_parallel_blending:
                timings["face_blending"] = self._parallel_face_blending(
                    res_frame_list, coord_list_cycle, frame_list_cycle, result_img_path
                )
                timings["blending_mode"] = "parallel"
            else:
                t0 = time.time()
                for i, res_frame in enumerate(res_frame_list):
                    bbox = coord_list_cycle[i % len(coord_list_cycle)]
                    ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
                    x1, y1, x2, y2 = bbox

                    if self.version == "v15":
                        y2 = y2 + self.extra_margin
                        y2 = min(y2, ori_frame.shape[0])

                    try:
                        res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
                    except:
                        continue

                    if self.version == "v15":
                        combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
                                                  mode=self.parsing_mode, fp=self.fp)
                    else:
                        combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=self.fp)

                    cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)
                timings["face_blending"] = time.time() - t0
                timings["blending_mode"] = "sequential"

            # 5. Video encoding (NVENC)
            timings["video_encoding"] = self._encode_video_nvenc(
                result_img_path, audio_path, output_path, fps
            )
            timings["encoding_mode"] = "nvenc" if self.use_nvenc else "cpu"

        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)

        timings["total"] = time.time() - total_start
        timings["frames_generated"] = len(res_frame_list)

        return timings

    @torch.no_grad()
    def generate_batch(
        self,
        avatar_name: str,
        audio_paths: List[str],
        output_dir: str,
        fps: Optional[int] = None
    ) -> dict:
        """Generate multiple videos from multiple audios efficiently."""
        if not self.is_loaded:
            raise RuntimeError("Models not loaded!")

        fps = fps or self.fps
        batch_timings = {"videos": [], "total": 0}
        total_start = time.time()

        # Load avatar once
        t0 = time.time()
        avatar = self.load_avatar(avatar_name)
        batch_timings["avatar_load"] = time.time() - t0

        coord_list = avatar['coord_list']
        frame_list = avatar['frame_list']
        input_latent_list = avatar['latent_list']

        # Prepare cycled lists once
        frame_list_cycle = frame_list + frame_list[::-1]
        coord_list_cycle = coord_list + coord_list[::-1]
        input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

        os.makedirs(output_dir, exist_ok=True)

        for idx, audio_path in enumerate(audio_paths):
            video_start = time.time()
            timings = {}

            audio_name = Path(audio_path).stem
            output_path = os.path.join(output_dir, f"{audio_name}.mp4")

            temp_dir = tempfile.mkdtemp()

            try:
                # 1. Extract audio features
                t0 = time.time()
                whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
                whisper_chunks = self.audio_processor.get_whisper_chunk(
                    whisper_input_features,
                    self.device,
                    self.weight_dtype,
                    self.whisper,
                    librosa_length,
                    fps=fps,
                    audio_padding_length_left=self.audio_padding_left,
                    audio_padding_length_right=self.audio_padding_right,
                )
                timings["whisper_features"] = time.time() - t0

                # 2. UNet inference
                t0 = time.time()
                gen = datagen(
                    whisper_chunks=whisper_chunks,
                    vae_encode_latents=input_latent_list_cycle,
                    batch_size=self.batch_size,
                    delay_frame=0,
                    device=self.device,
                )

                res_frame_list = []
                for whisper_batch, latent_batch in gen:
                    audio_feature_batch = self.pe(whisper_batch)
                    latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
                    pred_latents = self.unet.model(
                        latent_batch, self.timesteps,
                        encoder_hidden_states=audio_feature_batch
                    ).sample
                    recon = self.vae.decode_latents(pred_latents)
                    for res_frame in recon:
                        res_frame_list.append(res_frame)

                timings["unet_inference"] = time.time() - t0

                # 3. Face blending (parallel)
                result_img_path = os.path.join(temp_dir, "results")
                os.makedirs(result_img_path, exist_ok=True)
                timings["face_blending"] = self._parallel_face_blending(
                    res_frame_list, coord_list_cycle, frame_list_cycle, result_img_path
                )

                # 4. Video encoding (NVENC)
                timings["video_encoding"] = self._encode_video_nvenc(
                    result_img_path, audio_path, output_path, fps
                )

            finally:
                shutil.rmtree(temp_dir, ignore_errors=True)

            timings["total"] = time.time() - video_start
            timings["frames_generated"] = len(res_frame_list)
            timings["output_path"] = output_path
            timings["audio_path"] = audio_path

            batch_timings["videos"].append(timings)
            print(f" [{idx+1}/{len(audio_paths)}] {audio_name}: {timings['total']:.2f}s")

        batch_timings["total"] = time.time() - total_start
        batch_timings["num_videos"] = len(audio_paths)
        batch_timings["avg_per_video"] = batch_timings["total"] / len(audio_paths) if audio_paths else 0

        return batch_timings


# Global server
server = MuseTalkServerV3()

# FastAPI app
app = FastAPI(
    title="MuseTalk API v3",
    description="Ultra-optimized API with parallel blending, NVENC, and batch processing",
    version="3.0.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
async def startup_event():
    server.load_models()


@app.get("/health")
async def health_check():
    return {
        "status": "ok" if server.is_loaded else "loading",
        "models_loaded": server.is_loaded,
        "device": str(server.device) if server.device else None,
        "loaded_avatars": list(server.loaded_avatars.keys()),
        "optimizations": {
            "parallel_workers": server.num_workers,
            "nvenc_enabled": server.use_nvenc,
            "nvenc_preset": server.nvenc_preset
        }
    }


@app.get("/avatars")
async def list_avatars():
    avatars = []
    for p in server.avatar_dir.iterdir():
        if p.is_dir() and (p / "metadata.pkl").exists():
            with open(p / "metadata.pkl", 'rb') as f:
                metadata = pickle.load(f)
            metadata['loaded'] = p.name in server.loaded_avatars
            avatars.append(metadata)
    return {"avatars": avatars}


@app.post("/avatars/{avatar_name}/load")
async def load_avatar(avatar_name: str):
    try:
        server.load_avatar(avatar_name)
        return {"status": "loaded", "avatar_name": avatar_name}
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e))


@app.post("/avatars/{avatar_name}/unload")
async def unload_avatar(avatar_name: str):
    server.unload_avatar(avatar_name)
    return {"status": "unloaded", "avatar_name": avatar_name}


class GenerateRequest(BaseModel):
    avatar_name: str
    audio_path: str
    output_path: str
    fps: Optional[int] = 25
    use_parallel_blending: bool = True


@app.post("/generate/avatar")
async def generate_with_avatar(request: GenerateRequest):
    if not server.is_loaded:
        raise HTTPException(status_code=503, detail="Models not loaded")

    if not os.path.exists(request.audio_path):
        raise HTTPException(status_code=404, detail=f"Audio not found: {request.audio_path}")

    try:
        timings = server.generate_with_avatar(
            avatar_name=request.avatar_name,
            audio_path=request.audio_path,
            output_path=request.output_path,
            fps=request.fps,
            use_parallel_blending=request.use_parallel_blending
        )
        return {
            "status": "success",
            "output_path": request.output_path,
            "timings": timings
        }
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


class BatchGenerateRequest(BaseModel):
    avatar_name: str
    audio_paths: List[str]
    output_dir: str
    fps: Optional[int] = 25


@app.post("/generate/batch")
async def generate_batch(request: BatchGenerateRequest):
    """Generate multiple videos from multiple audios."""
    if not server.is_loaded:
        raise HTTPException(status_code=503, detail="Models not loaded")

    for audio_path in request.audio_paths:
        if not os.path.exists(audio_path):
            raise HTTPException(status_code=404, detail=f"Audio not found: {audio_path}")

    try:
        timings = server.generate_batch(
            avatar_name=request.avatar_name,
            audio_paths=request.audio_paths,
            output_dir=request.output_dir,
            fps=request.fps
        )
        return {
            "status": "success",
            "output_dir": request.output_dir,
            "timings": timings
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()

    uvicorn.run(app, host=args.host, port=args.port)
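A minimal sketch of a call against the /generate/batch route defined above. The base URL, avatar name, and file paths are placeholders for illustration; the response keys (status, timings, num_videos, avg_per_video) come from BatchGenerateRequest handling and generate_batch in this file.

# Hypothetical batch-generation client (assumes the v3 server on localhost:8000).
import requests

payload = {
    "avatar_name": "my_avatar",
    "audio_paths": ["/data/audio/q1.wav", "/data/audio/q2.wav"],
    "output_dir": "/data/results",
    "fps": 25,
}
resp = requests.post("http://localhost:8000/generate/batch", json=payload)
resp.raise_for_status()
timings = resp.json()["timings"]
print(f"{timings['num_videos']} videos, avg {timings['avg_per_video']:.2f}s per video")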
musetalk_api_server_v3_fixed.py
ADDED
@@ -0,0 +1,371 @@
"""
MuseTalk HTTP API Server v3 (Fixed)
Optimized with:
1. Sequential face blending (parallel had overhead)
2. NVENC hardware video encoding
3. Batch audio processing
"""
import os
import cv2
import copy
import torch
import glob
import shutil
import pickle
import numpy as np
import subprocess
import tempfile
import hashlib
import time
from pathlib import Path
from typing import Optional, List
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from tqdm import tqdm
from transformers import WhisperModel
import uvicorn

# MuseTalk imports
from musetalk.utils.blending import get_image
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.utils import get_file_type, datagen, load_all_model
from musetalk.utils.preprocessing import coord_placeholder


class MuseTalkServerV3:
    def __init__(self):
        self.device = None
        self.vae = None
        self.unet = None
        self.pe = None
        self.whisper = None
        self.audio_processor = None
        self.fp = None
        self.timesteps = None
        self.weight_dtype = None
        self.is_loaded = False

        self.loaded_avatars = {}
        self.avatar_dir = Path("./avatars")

        self.fps = 25
        self.batch_size = 8
        self.use_float16 = True
        self.version = "v15"
        self.extra_margin = 10
        self.parsing_mode = "jaw"
        self.left_cheek_width = 90
        self.right_cheek_width = 90
        self.audio_padding_left = 2
        self.audio_padding_right = 2

        # NVENC
        self.use_nvenc = True
        self.nvenc_preset = "p4"
        self.crf = 23

    def load_models(self, gpu_id: int = 0):
        if self.is_loaded:
            print("Models already loaded!")
            return

        print("=" * 50)
        print("Loading MuseTalk models (v3 Optimized)...")
        print("=" * 50)

        start_time = time.time()
        self.device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu")

        self.vae, self.unet, self.pe = load_all_model(
            unet_model_path="./models/musetalkV15/unet.pth",
            vae_type="sd-vae",
            unet_config="./models/musetalk/config.json",
            device=self.device
        )
        self.timesteps = torch.tensor([0], device=self.device)

        self.pe = self.pe.half().to(self.device)
        self.vae.vae = self.vae.vae.half().to(self.device)
        self.unet.model = self.unet.model.half().to(self.device)

        self.audio_processor = AudioProcessor(feature_extractor_path="./models/whisper")
        self.weight_dtype = self.unet.model.dtype
        self.whisper = WhisperModel.from_pretrained("./models/whisper")
        self.whisper = self.whisper.to(device=self.device, dtype=self.weight_dtype).eval()
        self.whisper.requires_grad_(False)

        self.fp = FaceParsing(
            left_cheek_width=self.left_cheek_width,
            right_cheek_width=self.right_cheek_width
        )

        self.is_loaded = True
        print(f"Models loaded in {time.time() - start_time:.2f}s")

    def load_avatar(self, avatar_name: str) -> dict:
        if avatar_name in self.loaded_avatars:
            return self.loaded_avatars[avatar_name]

        avatar_path = self.avatar_dir / avatar_name
        if not avatar_path.exists():
            raise FileNotFoundError(f"Avatar not found: {avatar_name}")

        avatar_data = {}
        with open(avatar_path / "metadata.pkl", 'rb') as f:
            avatar_data['metadata'] = pickle.load(f)
        with open(avatar_path / "coords.pkl", 'rb') as f:
            avatar_data['coord_list'] = pickle.load(f)
        with open(avatar_path / "frames.pkl", 'rb') as f:
            avatar_data['frame_list'] = pickle.load(f)
        with open(avatar_path / "latents.pkl", 'rb') as f:
            latents_np = pickle.load(f)
            avatar_data['latent_list'] = [torch.from_numpy(l).to(self.device) for l in latents_np]

        self.loaded_avatars[avatar_name] = avatar_data
        return avatar_data

    def _encode_video_nvenc(self, frames_dir: str, audio_path: str, output_path: str, fps: int) -> float:
        t0 = time.time()
        temp_vid = output_path.replace('.mp4', '_temp.mp4')

        if self.use_nvenc:
            cmd = (
                f"ffmpeg -y -v warning -r {fps} -f image2 -i {frames_dir}/%08d.png "
                f"-c:v h264_nvenc -preset {self.nvenc_preset} -cq {self.crf} -pix_fmt yuv420p {temp_vid}"
            )
        else:
            cmd = (
                f"ffmpeg -y -v warning -r {fps} -f image2 -i {frames_dir}/%08d.png "
                f"-vcodec libx264 -crf 18 -pix_fmt yuv420p {temp_vid}"
            )
        os.system(cmd)

        os.system(f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid} -c:v copy -c:a aac {output_path}")
        os.remove(temp_vid) if os.path.exists(temp_vid) else None

        return time.time() - t0

    @torch.no_grad()
    def generate_with_avatar(self, avatar_name: str, audio_path: str, output_path: str, fps: int = 25) -> dict:
        if not self.is_loaded:
            raise RuntimeError("Models not loaded!")

        timings = {}
        total_start = time.time()

        t0 = time.time()
        avatar = self.load_avatar(avatar_name)
        timings["avatar_load"] = time.time() - t0

        coord_list = avatar['coord_list']
        frame_list = avatar['frame_list']
        input_latent_list = avatar['latent_list']

        temp_dir = tempfile.mkdtemp()

        try:
            # Whisper
            t0 = time.time()
            whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
            whisper_chunks = self.audio_processor.get_whisper_chunk(
                whisper_input_features, self.device, self.weight_dtype, self.whisper,
                librosa_length, fps=fps,
                audio_padding_length_left=self.audio_padding_left,
                audio_padding_length_right=self.audio_padding_right,
            )
            timings["whisper_features"] = time.time() - t0

            # Cycle lists
            frame_list_cycle = frame_list + frame_list[::-1]
            coord_list_cycle = coord_list + coord_list[::-1]
            input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

            # UNet
            t0 = time.time()
            gen = datagen(whisper_chunks=whisper_chunks, vae_encode_latents=input_latent_list_cycle,
                          batch_size=self.batch_size, delay_frame=0, device=self.device)

            res_frame_list = []
            for whisper_batch, latent_batch in gen:
                audio_feature_batch = self.pe(whisper_batch)
                latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
                pred_latents = self.unet.model(latent_batch, self.timesteps,
                                               encoder_hidden_states=audio_feature_batch).sample
                recon = self.vae.decode_latents(pred_latents)
                res_frame_list.extend(recon)
            timings["unet_inference"] = time.time() - t0

            # Face blending (sequential - faster than parallel due to FP overhead)
            t0 = time.time()
            result_img_path = os.path.join(temp_dir, "results")
            os.makedirs(result_img_path, exist_ok=True)

            for i, res_frame in enumerate(res_frame_list):
                bbox = coord_list_cycle[i % len(coord_list_cycle)]
                ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
                x1, y1, x2, y2 = bbox
                y2 = min(y2 + self.extra_margin, ori_frame.shape[0])

                try:
                    res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
                                              mode=self.parsing_mode, fp=self.fp)
                    cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)
                except:
                    continue
            timings["face_blending"] = time.time() - t0

            # NVENC encoding
            timings["video_encoding"] = self._encode_video_nvenc(result_img_path, audio_path, output_path, fps)

        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)

        timings["total"] = time.time() - total_start
        timings["frames_generated"] = len(res_frame_list)
        return timings

    @torch.no_grad()
    def generate_batch(self, avatar_name: str, audio_paths: List[str], output_dir: str, fps: int = 25) -> dict:
        if not self.is_loaded:
            raise RuntimeError("Models not loaded!")

        batch_timings = {"videos": [], "total": 0}
        total_start = time.time()

        t0 = time.time()
        avatar = self.load_avatar(avatar_name)
        batch_timings["avatar_load"] = time.time() - t0

        coord_list = avatar['coord_list']
        frame_list = avatar['frame_list']
        input_latent_list = avatar['latent_list']
        frame_list_cycle = frame_list + frame_list[::-1]
        coord_list_cycle = coord_list + coord_list[::-1]
        input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

        os.makedirs(output_dir, exist_ok=True)

        for idx, audio_path in enumerate(audio_paths):
            video_start = time.time()
            timings = {}
            output_path = os.path.join(output_dir, f"{Path(audio_path).stem}.mp4")
            temp_dir = tempfile.mkdtemp()

            try:
                t0 = time.time()
                whisper_input_features, librosa_length = self.audio_processor.get_audio_feature(audio_path)
                whisper_chunks = self.audio_processor.get_whisper_chunk(
                    whisper_input_features, self.device, self.weight_dtype, self.whisper,
                    librosa_length, fps=fps,
                    audio_padding_length_left=self.audio_padding_left,
                    audio_padding_length_right=self.audio_padding_right,
                )
                timings["whisper"] = time.time() - t0

                t0 = time.time()
                gen = datagen(whisper_chunks=whisper_chunks, vae_encode_latents=input_latent_list_cycle,
                              batch_size=self.batch_size, delay_frame=0, device=self.device)
                res_frame_list = []
                for whisper_batch, latent_batch in gen:
                    audio_feature_batch = self.pe(whisper_batch)
                    latent_batch = latent_batch.to(dtype=self.unet.model.dtype)
                    pred_latents = self.unet.model(latent_batch, self.timesteps,
                                                   encoder_hidden_states=audio_feature_batch).sample
                    res_frame_list.extend(self.vae.decode_latents(pred_latents))
                timings["unet"] = time.time() - t0

                t0 = time.time()
                result_img_path = os.path.join(temp_dir, "results")
                os.makedirs(result_img_path, exist_ok=True)
                for i, res_frame in enumerate(res_frame_list):
                    bbox = coord_list_cycle[i % len(coord_list_cycle)]
                    ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
                    x1, y1, x2, y2 = bbox
                    y2 = min(y2 + self.extra_margin, ori_frame.shape[0])
                    try:
                        res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
                        combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2],
                                                  mode=self.parsing_mode, fp=self.fp)
                        cv2.imwrite(f"{result_img_path}/{str(i).zfill(8)}.png", combine_frame)
                    except:
                        continue
                timings["blending"] = time.time() - t0

                timings["encoding"] = self._encode_video_nvenc(result_img_path, audio_path, output_path, fps)

            finally:
                shutil.rmtree(temp_dir, ignore_errors=True)

            timings["total"] = time.time() - video_start
            timings["frames"] = len(res_frame_list)
            timings["output"] = output_path
            batch_timings["videos"].append(timings)
            print(f" [{idx+1}/{len(audio_paths)}] {Path(audio_path).stem}: {timings['total']:.2f}s")

        batch_timings["total"] = time.time() - total_start
        batch_timings["num_videos"] = len(audio_paths)
        batch_timings["avg_per_video"] = batch_timings["total"] / len(audio_paths) if audio_paths else 0
        return batch_timings


server = MuseTalkServerV3()
app = FastAPI(title="MuseTalk API v3", version="3.0.0")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])

@app.on_event("startup")
async def startup():
    server.load_models()

@app.get("/health")
async def health():
    return {"status": "ok" if server.is_loaded else "loading", "device": str(server.device),
            "avatars": list(server.loaded_avatars.keys()), "nvenc": server.use_nvenc}

@app.get("/avatars")
async def list_avatars():
    avatars = []
    for p in server.avatar_dir.iterdir():
        if p.is_dir() and (p / "metadata.pkl").exists():
            with open(p / "metadata.pkl", 'rb') as f:
                avatars.append(pickle.load(f))
    return {"avatars": avatars}

class GenReq(BaseModel):
    avatar_name: str
    audio_path: str
    output_path: str
    fps: int = 25

@app.post("/generate/avatar")
async def generate(req: GenReq):
    if not os.path.exists(req.audio_path):
        raise HTTPException(404, f"Audio not found: {req.audio_path}")
    try:
        timings = server.generate_with_avatar(req.avatar_name, req.audio_path, req.output_path, req.fps)
        return {"status": "success", "output_path": req.output_path, "timings": timings}
    except Exception as e:
        raise HTTPException(500, str(e))

class BatchReq(BaseModel):
    avatar_name: str
    audio_paths: List[str]
    output_dir: str
    fps: int = 25

@app.post("/generate/batch")
async def batch(req: BatchReq):
    for p in req.audio_paths:
        if not os.path.exists(p):
            raise HTTPException(404, f"Audio not found: {p}")
    try:
        timings = server.generate_batch(req.avatar_name, req.audio_paths, req.output_dir, req.fps)
        return {"status": "success", "output_dir": req.output_dir, "timings": timings}
    except Exception as e:
        raise HTTPException(500, str(e))

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
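Both v3 servers default to use_nvenc=True and only fall back to libx264 when that flag is flipped by hand. A small probe like the following (an assumption for illustration, not part of the commit) could be run before enabling NVENC; it only relies on the standard `ffmpeg -encoders` listing and the h264_nvenc encoder name already used above.

# Hypothetical helper: enable NVENC only if the local ffmpeg build supports it.
import subprocess

def nvenc_available() -> bool:
    """Return True if the local ffmpeg build lists the h264_nvenc encoder."""
    try:
        out = subprocess.run(
            ["ffmpeg", "-hide_banner", "-encoders"],
            capture_output=True, text=True, check=True,
        ).stdout
    except (OSError, subprocess.CalledProcessError):
        return False
    return "h264_nvenc" in out

# e.g. right after constructing the server object:
# server.use_nvenc = nvenc_available()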
run_inference.py
ADDED
@@ -0,0 +1,276 @@
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import math
|
| 4 |
+
import copy
|
| 5 |
+
import torch
|
| 6 |
+
import glob
|
| 7 |
+
import shutil
|
| 8 |
+
import pickle
|
| 9 |
+
import argparse
|
| 10 |
+
import numpy as np
|
| 11 |
+
import subprocess
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
from omegaconf import OmegaConf
|
| 14 |
+
from transformers import WhisperModel
|
| 15 |
+
import sys
|
| 16 |
+
|
| 17 |
+
from musetalk.utils.blending import get_image
|
| 18 |
+
from musetalk.utils.face_parsing import FaceParsing
|
| 19 |
+
from musetalk.utils.audio_processor import AudioProcessor
|
| 20 |
+
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
|
| 21 |
+
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
|
| 22 |
+
|
| 23 |
+
def fast_check_ffmpeg():
|
| 24 |
+
try:
|
| 25 |
+
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
|
| 26 |
+
return True
|
| 27 |
+
except:
|
| 28 |
+
return False
|
| 29 |
+
|
| 30 |
+
@torch.no_grad()
|
| 31 |
+
def main(args):
|
| 32 |
+
# Configure ffmpeg path
|
| 33 |
+
if not fast_check_ffmpeg():
|
| 34 |
+
print("Adding ffmpeg to PATH")
|
| 35 |
+
# Choose path separator based on operating system
|
| 36 |
+
path_separator = ';' if sys.platform == 'win32' else ':'
|
| 37 |
+
os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
|
| 38 |
+
if not fast_check_ffmpeg():
|
| 39 |
+
print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
|
| 40 |
+
|
| 41 |
+
# Set computing device
|
| 42 |
+
device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
|
| 43 |
+
# Load model weights
|
| 44 |
+
vae, unet, pe = load_all_model(
|
| 45 |
+
unet_model_path=args.unet_model_path,
|
| 46 |
+
vae_type=args.vae_type,
|
| 47 |
+
unet_config=args.unet_config,
|
| 48 |
+
device=device
|
| 49 |
+
)
|
| 50 |
+
timesteps = torch.tensor([0], device=device)
|
| 51 |
+
|
| 52 |
+
# Convert models to half precision if float16 is enabled
|
| 53 |
+
if args.use_float16:
|
| 54 |
+
pe = pe.half()
|
| 55 |
+
vae.vae = vae.vae.half()
|
| 56 |
+
unet.model = unet.model.half()
|
| 57 |
+
|
| 58 |
+
# Move models to specified device
|
| 59 |
+
pe = pe.to(device)
|
| 60 |
+
vae.vae = vae.vae.to(device)
|
| 61 |
+
unet.model = unet.model.to(device)
|
| 62 |
+
|
| 63 |
+
# Initialize audio processor and Whisper model
|
| 64 |
+
audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
|
| 65 |
+
weight_dtype = unet.model.dtype
|
| 66 |
+
whisper = WhisperModel.from_pretrained(args.whisper_dir)
|
| 67 |
+
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
|
| 68 |
+
whisper.requires_grad_(False)
|
| 69 |
+
|
| 70 |
+
# Initialize face parser with configurable parameters based on version
|
| 71 |
+
if args.version == "v15":
|
| 72 |
+
fp = FaceParsing(
|
| 73 |
+
left_cheek_width=args.left_cheek_width,
|
| 74 |
+
right_cheek_width=args.right_cheek_width
|
| 75 |
+
)
|
| 76 |
+
else: # v1
|
| 77 |
+
fp = FaceParsing()
|
| 78 |
+
|
| 79 |
+
# Load inference configuration
|
| 80 |
+
inference_config = OmegaConf.load(args.inference_config)
|
| 81 |
+
print("Loaded inference config:", inference_config)
|
| 82 |
+
|
| 83 |
+
# Process each task
|
| 84 |
+
for task_id in inference_config:
|
    try:
        # Get task configuration
        video_path = inference_config[task_id]["video_path"]
        audio_path = inference_config[task_id]["audio_path"]
        if "result_name" in inference_config[task_id]:
            args.output_vid_name = inference_config[task_id]["result_name"]

        # Set bbox_shift based on version
        if args.version == "v15":
            bbox_shift = 0  # v15 uses fixed bbox_shift
        else:
            bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)  # v1 uses config or default

        # Set output paths
        input_basename = os.path.basename(video_path).split('.')[0]
        audio_basename = os.path.basename(audio_path).split('.')[0]
        output_basename = f"{input_basename}_{audio_basename}"

        # Create temporary directories
        temp_dir = os.path.join(args.result_dir, f"{args.version}")
        os.makedirs(temp_dir, exist_ok=True)

        # Set result save paths
        result_img_save_path = os.path.join(temp_dir, output_basename)
        crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl")
        os.makedirs(result_img_save_path, exist_ok=True)

        # Set output video paths
        if args.output_vid_name is None:
            output_vid_name = os.path.join(temp_dir, output_basename + ".mp4")
        else:
            output_vid_name = os.path.join(temp_dir, args.output_vid_name)
        output_vid_name_concat = os.path.join(temp_dir, output_basename + "_concat.mp4")

        # Extract frames from source video
        if get_file_type(video_path) == "video":
            save_dir_full = os.path.join(temp_dir, input_basename)
            os.makedirs(save_dir_full, exist_ok=True)
            cmd = f"ffmpeg -v fatal -i {video_path} -vf fps={args.fps} -start_number 0 {save_dir_full}/%08d.png"  # PATCHED: extract at target fps
            os.system(cmd)
            input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
            fps = args.fps  # PATCHED: use target fps instead of video fps
        elif get_file_type(video_path) == "image":
            input_img_list = [video_path]
            fps = args.fps
        elif os.path.isdir(video_path):
            input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
            input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
            fps = args.fps
        else:
            raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")

        # Extract audio features
        whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
        whisper_chunks = audio_processor.get_whisper_chunk(
            whisper_input_features,
            device,
            weight_dtype,
            whisper,
            librosa_length,
            fps=fps,
            audio_padding_length_left=args.audio_padding_length_left,
            audio_padding_length_right=args.audio_padding_length_right,
        )

        # Preprocess input images
        if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
            print("Using saved coordinates")
            with open(crop_coord_save_path, 'rb') as f:
                coord_list = pickle.load(f)
            frame_list = read_imgs(input_img_list)
        else:
            print("Extracting landmarks... time-consuming operation")
            coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
            with open(crop_coord_save_path, 'wb') as f:
                pickle.dump(coord_list, f)

        print(f"Number of frames: {len(frame_list)}")

        # Process each frame
        input_latent_list = []
        for bbox, frame in zip(coord_list, frame_list):
            if bbox == coord_placeholder:
                continue
            x1, y1, x2, y2 = bbox
            if args.version == "v15":
                y2 = y2 + args.extra_margin
                y2 = min(y2, frame.shape[0])
            crop_frame = frame[y1:y2, x1:x2]
            crop_frame = cv2.resize(crop_frame, (256,256), interpolation=cv2.INTER_LANCZOS4)
            latents = vae.get_latents_for_unet(crop_frame)
            input_latent_list.append(latents)

        # Smooth first and last frames
        frame_list_cycle = frame_list + frame_list[::-1]
        coord_list_cycle = coord_list + coord_list[::-1]
        input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

        # Batch inference
        print("Starting inference")
        video_num = len(whisper_chunks)
        batch_size = args.batch_size
        gen = datagen(
            whisper_chunks=whisper_chunks,
            vae_encode_latents=input_latent_list_cycle,
            batch_size=batch_size,
            delay_frame=0,
            device=device,
        )

        res_frame_list = []
        total = int(np.ceil(float(video_num) / batch_size))

        # Execute inference
        for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total)):
            audio_feature_batch = pe(whisper_batch)
            latent_batch = latent_batch.to(dtype=unet.model.dtype)

            pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
            recon = vae.decode_latents(pred_latents)
            for res_frame in recon:
                res_frame_list.append(res_frame)

        # Pad generated images to original video size
        print("Padding generated images to original video size")
        for i, res_frame in enumerate(tqdm(res_frame_list)):
            bbox = coord_list_cycle[i%(len(coord_list_cycle))]
            ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
            x1, y1, x2, y2 = bbox
            if args.version == "v15":
                y2 = y2 + args.extra_margin
                y2 = min(y2, ori_frame.shape[0])  # clamp against ori_frame, the frame being composited (same fix as scripts/inference.py line 216)
            try:
                res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
            except:
                continue

            # Merge results with version-specific parameters
            if args.version == "v15":
                combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
            else:
                combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp)
            cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)

        # Save prediction results
        temp_vid_path = f"{temp_dir}/temp_{input_basename}_{audio_basename}.mp4"
        cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid_path}"
        print("Video generation command:", cmd_img2video)
        os.system(cmd_img2video)

        cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid_path} {output_vid_name}"
        print("Audio combination command:", cmd_combine_audio)
        os.system(cmd_combine_audio)

        # Clean up temporary files
        shutil.rmtree(result_img_save_path)
        os.remove(temp_vid_path)

        shutil.rmtree(save_dir_full)
        if not args.saved_coord:
            os.remove(crop_coord_save_path)

        print(f"Results saved to {output_vid_name}")
    except Exception as e:
        print("Error occurred during processing:", e)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
    parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
    parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
    parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
    parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth", help="Path to UNet model weights")
    parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
    parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file")
    parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
    parser.add_argument("--result_dir", default='./results', help="Directory for output results")
    parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
    parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
    parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
    parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size for inference")
    parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
    parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
    parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
    parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
    parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
    parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
    parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
    parser.add_argument("--version", type=str, default="v15", choices=["v1", "v15"], help="Model version to use")
    args = parser.parse_args()
    main(args)
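For reference, run_inference.py reads its work list from the YAML file passed via --inference_config: each top-level key is a task id whose entry provides video_path, audio_path, and optionally result_name and bbox_shift (the latter is only honored with --version v1). The sketch below is a hypothetical usage example, not part of this commit; the task key name, the media paths, and the PyYAML/subprocess invocation are assumptions.

# Hypothetical driver: write a one-task config and run the script with it.
import subprocess
import yaml

config = {
    "task_0": {
        "video_path": "data/video/my_avatar.mp4",   # placeholder input video
        "audio_path": "data/audio/my_speech.wav",    # placeholder driving audio
        "result_name": "my_avatar_speech.mp4",       # optional, becomes output_vid_name
        "bbox_shift": 0,                             # only read when --version v1
    }
}

with open("configs/inference/my_task.yaml", "w") as f:
    yaml.safe_dump(config, f)

subprocess.run(
    [
        "python", "run_inference.py",
        "--inference_config", "configs/inference/my_task.yaml",
        "--result_dir", "./results",
        "--version", "v15",
        "--fps", "25",
    ],
    check=True,
)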
scripts/inference.py
CHANGED
|
@@ -213,7 +213,7 @@ def main(args):
 213             x1, y1, x2, y2 = bbox
 214             if args.version == "v15":
 215                 y2 = y2 + args.extra_margin
-216                 y2 = min(y2, frame.shape[0])
+216                 y2 = min(y2, ori_frame.shape[0])
 217             try:
 218                 res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
 219             except:
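The change clamps y2 against ori_frame, the full-resolution frame actually being composited in the padding loop, instead of frame, a stale variable left over from the earlier latent-encoding loop. With v1.5's extra jaw margin, y2 can otherwise be clamped against the wrong frame's height. A minimal illustration of the clamp, with made-up numbers rather than code from this commit:

# Placeholder frame and bbox, chosen so the extra margin pushes y2 past the
# bottom of a 720-pixel-tall frame; the clamp pulls it back in bounds.
import numpy as np

ori_frame = np.zeros((720, 1280, 3), dtype=np.uint8)  # hypothetical video frame
extra_margin = 10
x1, y1, x2, y2 = 500, 200, 780, 715                    # hypothetical face bbox
y2 = y2 + extra_margin             # 725, past the frame bottom
y2 = min(y2, ori_frame.shape[0])   # clamped back to 720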
server.py
ADDED
|
@@ -0,0 +1,607 @@
"""
MuseTalk Real-Time Server
FastAPI server for real-time lip-sync
"""
import os
import sys
import io
import time
import json
import uuid
import queue
import pickle
import shutil
import asyncio
import threading
from pathlib import Path
from typing import Optional
import tempfile

import cv2
import glob
import copy
import torch
import numpy as np
from tqdm import tqdm
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
from pydantic import BaseModel
import uvicorn

# Suppress warnings
import warnings
warnings.filterwarnings("ignore")

# MuseTalk imports
from musetalk.utils.utils import datagen, load_all_model
from musetalk.utils.blending import get_image_prepare_material, get_image_blending
from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.preprocessing_simple import get_landmark_and_bbox, read_imgs
from transformers import WhisperModel

app = FastAPI(title="MuseTalk Real-Time Server", version="1.5")

# Global model instances
models = {}
avatars = {}

class AvatarConfig(BaseModel):
    avatar_id: str
    video_path: str
    bbox_shift: int = 0

class InferenceRequest(BaseModel):
    avatar_id: str
    fps: int = 25

def video2imgs(vid_path, save_path):
    """Extract frames from video"""
    cap = cv2.VideoCapture(vid_path)
    count = 0
    while True:
        ret, frame = cap.read()
        if ret:
            cv2.imwrite(f"{save_path}/{count:08d}.png", frame)
            count += 1
        else:
            break
    cap.release()
    return count


@app.on_event("startup")
async def load_models():
    """Load all models at startup"""
    global models

    print("Loading MuseTalk models...")
    # Force CPU if FORCE_CPU env var is set or if CUDA kernels are incompatible
    force_cpu = os.environ.get("FORCE_CPU", "0") == "1"
    if force_cpu or not torch.cuda.is_available():
        device = torch.device("cpu")
    else:
        try:
            # Test if CUDA kernels work for this GPU
            test_tensor = torch.zeros(1).cuda()
            _ = test_tensor.half()
            device = torch.device("cuda:0")
        except RuntimeError as e:
            print(f"CUDA kernel test failed: {e}")
            print("Falling back to CPU...")
            device = torch.device("cpu")
    print(f"Using device: {device}")

    # Model paths
    unet_model_path = "./models/musetalkV15/unet.pth"
    unet_config = "./models/musetalkV15/musetalk.json"
    whisper_dir = "./models/whisper"
    vae_type = "sd-vae"

    # Load models
    vae, unet, pe = load_all_model(
        unet_model_path=unet_model_path,
        vae_type=vae_type,
        unet_config=unet_config,
        device=device
    )

    # Move to device, use half precision only for GPU
    if device.type == "cuda":
        pe = pe.half().to(device)
        vae.vae = vae.vae.half().to(device)
        unet.model = unet.model.half().to(device)
    else:
        pe = pe.to(device)
        vae.vae = vae.vae.to(device)
        unet.model = unet.model.to(device)

    # Load whisper
    audio_processor = AudioProcessor(feature_extractor_path=whisper_dir)
    whisper = WhisperModel.from_pretrained(whisper_dir)
    weight_dtype = unet.model.dtype if device.type == "cuda" else torch.float32
    whisper = whisper.to(device=device, dtype=weight_dtype).eval()
    whisper.requires_grad_(False)

    # Initialize face parser
    from musetalk.utils.face_parsing import FaceParsing
    fp = FaceParsing(left_cheek_width=90, right_cheek_width=90)

    timesteps = torch.tensor([0], device=device)

    models = {
        "vae": vae,
        "unet": unet,
        "pe": pe,
        "whisper": whisper,
        "audio_processor": audio_processor,
        "fp": fp,
        "device": device,
        "timesteps": timesteps,
        "weight_dtype": weight_dtype
    }

    print("Models loaded successfully!")

@app.get("/")
async def root():
    return {"status": "ok", "message": "MuseTalk Real-Time Server"}

@app.get("/health")
async def health():
    return {
        "status": "healthy",
        "models_loaded": len(models) > 0,
        "avatars_count": len(avatars),
        "gpu_available": torch.cuda.is_available()
    }

@app.post("/avatar/prepare")
async def prepare_avatar(
    avatar_id: str = Form(...),
    video: UploadFile = File(...),
    bbox_shift: int = Form(0, description="Adjusts mouth openness: positive=more open, negative=less open (-9 to 9)"),
    extra_margin: int = Form(10, description="Extra margin for jaw movement"),
    parsing_mode: str = Form("jaw", description="Parsing mode: 'jaw' (v1.5) or 'raw' (v1.0)"),
    left_cheek_width: int = Form(90, description="Left cheek width"),
    right_cheek_width: int = Form(90, description="Right cheek width")
):
    """Prepare an avatar from video for real-time inference"""
    global avatars

    if not models:
        raise HTTPException(status_code=503, detail="Models not loaded")

    # Save uploaded video
    avatar_path = f"./results/v15/avatars/{avatar_id}"
    full_imgs_path = f"{avatar_path}/full_imgs"
    mask_out_path = f"{avatar_path}/mask"

    os.makedirs(avatar_path, exist_ok=True)
    os.makedirs(full_imgs_path, exist_ok=True)
    os.makedirs(mask_out_path, exist_ok=True)

    # Save video
    video_path = f"{avatar_path}/source_video{Path(video.filename).suffix}"
    with open(video_path, "wb") as f:
        content = await video.read()
        f.write(content)

    # Extract frames
    print(f"Extracting frames from video...")
    frame_count = video2imgs(video_path, full_imgs_path)
    print(f"Extracted {frame_count} frames")

    input_img_list = sorted(glob.glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]')))

    print("Extracting landmarks...")
    # bbox_shift controls mouth openness: positive=more open, negative=less open
    coord_list_raw, frame_list_raw = get_landmark_and_bbox(input_img_list, upperbondrange=bbox_shift)

    # Generate latents - filter out frames without detected faces
    input_latent_list = []
    valid_coord_list = []
    valid_frame_list = []
    coord_placeholder = (0.0, 0.0, 0.0, 0.0)

    vae = models["vae"]

    # Create FaceParsing with custom cheek widths for this avatar
    from musetalk.utils.face_parsing import FaceParsing
    fp_avatar = FaceParsing(left_cheek_width=left_cheek_width, right_cheek_width=right_cheek_width)

    for bbox, frame in zip(coord_list_raw, frame_list_raw):
        if bbox == coord_placeholder:
            continue
        x1, y1, x2, y2 = bbox
        # Validate bbox dimensions
        if x2 <= x1 or y2 <= y1:
            continue
        # Add extra margin for jaw movement (v1.5 feature)
        y2 = min(y2 + extra_margin, frame.shape[0])

        # Store valid frame and coordinates
        valid_coord_list.append([x1, y1, x2, y2])
        valid_frame_list.append(frame)

        crop_frame = frame[y1:y2, x1:x2]
        if crop_frame.size == 0:
            valid_coord_list.pop()
            valid_frame_list.pop()
            continue
        resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
        latents = vae.get_latents_for_unet(resized_crop_frame)
        input_latent_list.append(latents)

    print(f"Valid frames with detected faces: {len(valid_frame_list)}/{len(frame_list_raw)}")

    if len(valid_frame_list) == 0:
        raise HTTPException(status_code=400, detail="No faces detected in video. Please use a video with a clear frontal face.")

    # Create cycles from valid frames only
    frame_list_cycle = valid_frame_list + valid_frame_list[::-1]
    coord_list_cycle = valid_coord_list + valid_coord_list[::-1]
    input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

    # Generate masks
    mask_list_cycle = []
    mask_coords_list_cycle = []

    print(f"Generating masks with mode={parsing_mode}...")
    for i, frame in enumerate(tqdm(frame_list_cycle)):
        x1, y1, x2, y2 = coord_list_cycle[i]
        mask, crop_box = get_image_prepare_material(frame, [x1, y1, x2, y2], fp=fp_avatar, mode=parsing_mode)
        cv2.imwrite(f"{mask_out_path}/{str(i).zfill(8)}.png", mask)
        mask_coords_list_cycle.append(crop_box)
        mask_list_cycle.append(mask)

    # Save preprocessed data
    with open(f"{avatar_path}/coords.pkl", 'wb') as f:
        pickle.dump(coord_list_cycle, f)

    with open(f"{avatar_path}/mask_coords.pkl", 'wb') as f:
        pickle.dump(mask_coords_list_cycle, f)

    # Save quality settings
    quality_settings = {
        "bbox_shift": bbox_shift,
        "extra_margin": extra_margin,
        "parsing_mode": parsing_mode,
        "left_cheek_width": left_cheek_width,
        "right_cheek_width": right_cheek_width
    }
    with open(f"{avatar_path}/quality_settings.json", 'w') as f:
        json.dump(quality_settings, f)

    torch.save(input_latent_list_cycle, f"{avatar_path}/latents.pt")

    # Store in memory - keep latents on CPU to save GPU memory
    input_latent_list_cpu = [lat.cpu() for lat in input_latent_list_cycle]

    avatars[avatar_id] = {
        "path": avatar_path,
        "frame_list_cycle": frame_list_cycle,
        "coord_list_cycle": coord_list_cycle,
        "input_latent_list_cycle": input_latent_list_cpu,
        "mask_list_cycle": mask_list_cycle,
        "mask_coords_list_cycle": mask_coords_list_cycle,
        "quality_settings": quality_settings
    }

    # Clear GPU cache after preparation
    import gc
    gc.collect()
    torch.cuda.empty_cache()

    return {
        "status": "success",
        "avatar_id": avatar_id,
        "frame_count": len(frame_list_cycle),
        "quality_settings": quality_settings
    }

@app.post("/avatar/load/{avatar_id}")
async def load_avatar(avatar_id: str):
    """Load a previously prepared avatar"""
    global avatars

    avatar_path = f"./results/v15/avatars/{avatar_id}"

    if not os.path.exists(avatar_path):
        raise HTTPException(status_code=404, detail=f"Avatar {avatar_id} not found")

    full_imgs_path = f"{avatar_path}/full_imgs"
    mask_out_path = f"{avatar_path}/mask"

    # Load preprocessed data
    input_latent_list_cycle = torch.load(f"{avatar_path}/latents.pt")

    with open(f"{avatar_path}/coords.pkl", 'rb') as f:
        coord_list_cycle = pickle.load(f)

    with open(f"{avatar_path}/mask_coords.pkl", 'rb') as f:
        mask_coords_list_cycle = pickle.load(f)

    # Load quality settings (with defaults for backwards compatibility)
    quality_settings_path = f"{avatar_path}/quality_settings.json"
    if os.path.exists(quality_settings_path):
        with open(quality_settings_path, 'r') as f:
            quality_settings = json.load(f)
    else:
        quality_settings = {
            "bbox_shift": 0,
            "extra_margin": 10,
            "parsing_mode": "jaw",
            "left_cheek_width": 90,
            "right_cheek_width": 90
        }

    # Load frames
    input_img_list = sorted(glob.glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]')))
    frame_list_cycle = read_imgs(input_img_list)

    # Load masks
    input_mask_list = sorted(glob.glob(os.path.join(mask_out_path, '*.[jpJP][pnPN]*[gG]')))
    mask_list_cycle = read_imgs(input_mask_list)

    # Keep latents on CPU to save GPU memory
    input_latent_list_cpu = [lat.cpu() if hasattr(lat, 'cpu') else lat for lat in input_latent_list_cycle]

    avatars[avatar_id] = {
        "path": avatar_path,
        "frame_list_cycle": frame_list_cycle,
        "coord_list_cycle": coord_list_cycle,
        "input_latent_list_cycle": input_latent_list_cpu,
        "mask_list_cycle": mask_list_cycle,
        "mask_coords_list_cycle": mask_coords_list_cycle,
        "quality_settings": quality_settings
    }

    # Clear GPU cache
    import gc
    gc.collect()
    torch.cuda.empty_cache()

    return {
        "status": "success",
        "avatar_id": avatar_id,
        "frame_count": len(frame_list_cycle),
        "quality_settings": quality_settings
    }

@app.get("/avatars")
async def list_avatars():
    """List all available avatars"""
    avatar_dir = "./results/v15/avatars"
    if not os.path.exists(avatar_dir):
        return {"avatars": [], "loaded": list(avatars.keys())}

    available = [d for d in os.listdir(avatar_dir) if os.path.isdir(os.path.join(avatar_dir, d))]
    return {"avatars": available, "loaded": list(avatars.keys())}

@app.post("/inference")
async def inference(
    avatar_id: str = Form(...),
    audio: UploadFile = File(...),
    fps: int = Form(25)
):
    """Run inference with uploaded audio and return video"""

    if avatar_id not in avatars:
        raise HTTPException(status_code=404, detail=f"Avatar {avatar_id} not loaded. Use /avatar/load first")

    if not models:
        raise HTTPException(status_code=503, detail="Models not loaded")

    avatar = avatars[avatar_id]
    device = models["device"]

    # Save audio temporarily
    with tempfile.NamedTemporaryFile(suffix=Path(audio.filename).suffix, delete=False) as tmp:
        content = await audio.read()
        tmp.write(content)
        audio_path = tmp.name

    try:
        start_time = time.time()

        # Extract audio features
        audio_processor = models["audio_processor"]
        whisper = models["whisper"]
        weight_dtype = models["weight_dtype"]

        whisper_input_features, librosa_length = audio_processor.get_audio_feature(
            audio_path, weight_dtype=weight_dtype
        )
        whisper_chunks = audio_processor.get_whisper_chunk(
            whisper_input_features,
            device,
            weight_dtype,
            whisper,
            librosa_length,
            fps=fps,
            audio_padding_length_left=2,
            audio_padding_length_right=2,
        )

        print(f"Audio processing: {(time.time() - start_time)*1000:.0f}ms")

        # Inference
        vae = models["vae"]
        unet = models["unet"]
        pe = models["pe"]
        timesteps = models["timesteps"]

        video_num = len(whisper_chunks)
        batch_size = 4  # Reduced batch size to save GPU memory

        gen = datagen(whisper_chunks, avatar["input_latent_list_cycle"], batch_size)

        result_frames = []
        inference_start = time.time()

        for i, (whisper_batch, latent_batch) in enumerate(gen):
            audio_feature_batch = pe(whisper_batch.to(device))
            latent_batch = latent_batch.to(device=device, dtype=unet.model.dtype)

            pred_latents = unet.model(
                latent_batch,
                timesteps,
                encoder_hidden_states=audio_feature_batch
            ).sample

            pred_latents = pred_latents.to(device=device, dtype=vae.vae.dtype)
            recon = vae.decode_latents(pred_latents)

            for idx_in_batch, res_frame in enumerate(recon):
                frame_idx = i * batch_size + idx_in_batch
                if frame_idx >= video_num:
                    break

                bbox = avatar["coord_list_cycle"][frame_idx % len(avatar["coord_list_cycle"])]
                ori_frame = copy.deepcopy(avatar["frame_list_cycle"][frame_idx % len(avatar["frame_list_cycle"])])
                x1, y1, x2, y2 = bbox

                res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
                mask = avatar["mask_list_cycle"][frame_idx % len(avatar["mask_list_cycle"])]
                mask_crop_box = avatar["mask_coords_list_cycle"][frame_idx % len(avatar["mask_coords_list_cycle"])]

                combine_frame = get_image_blending(ori_frame, res_frame, bbox, mask, mask_crop_box)
                result_frames.append(combine_frame)

        print(f"Inference: {(time.time() - inference_start)*1000:.0f}ms for {video_num} frames")
        print(f"FPS: {video_num / (time.time() - inference_start):.1f}")

        # Create video
        output_path = tempfile.mktemp(suffix=".mp4")
        h, w = result_frames[0].shape[:2]

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))

        for frame in result_frames:
            out.write(frame)
        out.release()

        # Combine with audio using ffmpeg
        final_output = tempfile.mktemp(suffix=".mp4")
        os.system(f"ffmpeg -y -v warning -i {audio_path} -i {output_path} -c:v libx264 -c:a aac {final_output}")

        os.unlink(output_path)
        os.unlink(audio_path)

        total_time = time.time() - start_time
        print(f"Total time: {total_time*1000:.0f}ms")

        return FileResponse(
            final_output,
            media_type="video/mp4",
            filename=f"output_{avatar_id}.mp4",
            headers={"X-Processing-Time": f"{total_time:.2f}s"}
        )

    except Exception as e:
        if os.path.exists(audio_path):
            os.unlink(audio_path)
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/inference/frames")
async def inference_frames(
    avatar_id: str = Form(...),
    audio: UploadFile = File(...),
    fps: int = Form(25)
):
    """Run inference and return frames as JSON (for streaming)"""

    if avatar_id not in avatars:
        raise HTTPException(status_code=404, detail=f"Avatar {avatar_id} not loaded")

    avatar = avatars[avatar_id]
    device = models["device"]

    # Save audio temporarily
    with tempfile.NamedTemporaryFile(suffix=Path(audio.filename).suffix, delete=False) as tmp:
        content = await audio.read()
        tmp.write(content)
        audio_path = tmp.name

    try:
        # Extract audio features
        audio_processor = models["audio_processor"]
        whisper = models["whisper"]
        weight_dtype = models["weight_dtype"]

        whisper_input_features, librosa_length = audio_processor.get_audio_feature(
            audio_path, weight_dtype=weight_dtype
        )
        whisper_chunks = audio_processor.get_whisper_chunk(
            whisper_input_features,
            device,
            weight_dtype,
            whisper,
            librosa_length,
            fps=fps,
        )

        # Inference
        vae = models["vae"]
        unet = models["unet"]
        pe = models["pe"]
        timesteps = models["timesteps"]

        video_num = len(whisper_chunks)
        batch_size = 4  # Reduced batch size to save GPU memory

        gen = datagen(whisper_chunks, avatar["input_latent_list_cycle"], batch_size)

        frames_data = []

        for i, (whisper_batch, latent_batch) in enumerate(gen):
            audio_feature_batch = pe(whisper_batch.to(device))
            latent_batch = latent_batch.to(device=device, dtype=unet.model.dtype)

            pred_latents = unet.model(
                latent_batch,
                timesteps,
                encoder_hidden_states=audio_feature_batch
            ).sample

            pred_latents = pred_latents.to(device=device, dtype=vae.vae.dtype)
            recon = vae.decode_latents(pred_latents)

            for idx_in_batch, res_frame in enumerate(recon):
                frame_idx = i * batch_size + idx_in_batch
                if frame_idx >= video_num:
                    break

                bbox = avatar["coord_list_cycle"][frame_idx % len(avatar["coord_list_cycle"])]
                ori_frame = copy.deepcopy(avatar["frame_list_cycle"][frame_idx % len(avatar["frame_list_cycle"])])
                x1, y1, x2, y2 = bbox

                res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
                mask = avatar["mask_list_cycle"][frame_idx % len(avatar["mask_list_cycle"])]
                mask_crop_box = avatar["mask_coords_list_cycle"][frame_idx % len(avatar["mask_coords_list_cycle"])]

                combine_frame = get_image_blending(ori_frame, res_frame, bbox, mask, mask_crop_box)

                # Encode frame as JPEG
                _, buffer = cv2.imencode('.jpg', combine_frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
                import base64
                frame_b64 = base64.b64encode(buffer).decode('utf-8')
                frames_data.append(frame_b64)

        os.unlink(audio_path)

        return {
            "frames": frames_data,
            "fps": fps,
            "total_frames": len(frames_data)
        }

    except Exception as e:
        if os.path.exists(audio_path):
            os.unlink(audio_path)
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
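A minimal client sketch for the endpoints above, assuming the server is running locally on port 8000 and that the requests package is installed; the avatar id and file paths are placeholders, not part of this commit.

# Hypothetical client: prepare an avatar once, then drive it with audio.
import requests

BASE = "http://localhost:8000"

# 1. Prepare an avatar from a source video (one-time, slow step).
with open("my_avatar.mp4", "rb") as f:
    r = requests.post(
        f"{BASE}/avatar/prepare",
        data={"avatar_id": "demo", "bbox_shift": 0, "parsing_mode": "jaw"},
        files={"video": f},
    )
r.raise_for_status()
print(r.json())  # {"status": "success", "avatar_id": "demo", ...}

# 2. On later runs, reload the preprocessed avatar instead of re-preparing it.
requests.post(f"{BASE}/avatar/load/demo").raise_for_status()

# 3. Drive the avatar with an audio clip and save the returned MP4.
with open("speech.wav", "rb") as f:
    r = requests.post(
        f"{BASE}/inference",
        data={"avatar_id": "demo", "fps": 25},
        files={"audio": f},
    )
r.raise_for_status()
with open("output_demo.mp4", "wb") as out:
    out.write(r.content)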