# Ringg-TTS-v1.0 / vertex_client.py
"""
Vertex AI client for TTS synthesis using Google Cloud AI Platform.
"""
import os
import json
import logging
import pathlib
import requests
from typing import Optional, Dict, Any, Tuple, Generator
from concurrent.futures import ThreadPoolExecutor, as_completed
from google.cloud import aiplatform
from google.oauth2 import service_account
from dotenv import load_dotenv
# Load environment variables from .env file (for local development)
load_dotenv()
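# For local development, .env is expected to define auth_string (the full
# service-account JSON, consumed by VertexAIClient._load_credentials below).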
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class VertexAIClient:
"""Client for interacting with Vertex AI TTS endpoint."""
def __init__(self):
"""Initialize the Vertex AI client."""
self.endpoint = None
self.endpoint_distill = None
self.credentials = None
self.initialized = False
def _load_credentials(self) -> Optional[service_account.Credentials]:
"""
Load credentials from auth_string environment variable.
Returns:
Credentials object or None if failed
"""
try:
auth_string = os.environ.get("auth_string")
if not auth_string:
logger.warning("auth_string environment variable not found")
return None
# Parse JSON credentials
credentials_dict = json.loads(auth_string)
credentials = service_account.Credentials.from_service_account_info(
credentials_dict
)
logger.info("Successfully loaded credentials from auth_string")
return credentials
except json.JSONDecodeError as e:
logger.error(f"Failed to parse auth_string JSON: {e}")
return None
except Exception as e:
logger.error(f"Failed to load credentials: {e}")
return None
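    # Illustrative shape of the service-account JSON carried in auth_string.
    # These are the standard Google service-account key fields; all values
    # below are placeholders, not real credentials:
    #   {
    #     "type": "service_account",
    #     "project_id": "...",
    #     "private_key_id": "...",
    #     "private_key": "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n",
    #     "client_email": "...@....iam.gserviceaccount.com",
    #     "token_uri": "https://oauth2.googleapis.com/token"
    #   }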
def initialize(self) -> bool:
"""
Initialize Vertex AI and find the zipvoice and zipvoice_base_distill endpoints.
Returns:
True if initialization successful, False otherwise
"""
if self.initialized:
return True
try:
# Load credentials
self.credentials = self._load_credentials()
if not self.credentials:
logger.error("Cannot initialize without credentials")
return False
# Initialize Vertex AI
aiplatform.init(
project="desivocalprod01",
location="asia-south1",
credentials=self.credentials,
)
logger.info("Vertex AI initialized for project desivocalprod01")
# Find both endpoints
for endpoint in aiplatform.Endpoint.list():
if endpoint.display_name == "zipvoice":
self.endpoint = endpoint
logger.info(f"Found zipvoice endpoint: {endpoint.resource_name}")
elif endpoint.display_name == "zipvoice_base_distill":
self.endpoint_distill = endpoint
logger.info(f"Found zipvoice_base_distill endpoint: {endpoint.resource_name}")
# Check if at least the base endpoint is found
if not self.endpoint:
logger.error("zipvoice endpoint not found in Vertex AI")
return False
# Warn if distill endpoint is not found but continue
if not self.endpoint_distill:
logger.warning("zipvoice_base_distill endpoint not found - distill model will not be available")
self.initialized = True
return True
except Exception as e:
logger.error(f"Failed to initialize Vertex AI: {e}")
return False
def get_voices(self) -> Tuple[bool, Optional[Dict[str, Any]]]:
"""
Get available voices from local configuration file.
        Note: the Vertex AI endpoint doesn't expose a separate /voices API;
        voices are configured in voices_config.json.
Returns:
Tuple of (success, voices_dict)
voices_dict format: {"voices": {"voice_id": {"name": "...", "gender": "..."}}}
"""
        try:
            # Load voices_config.json from the directory containing this module
            config_path = pathlib.Path(__file__).parent / "voices_config.json"
if config_path.exists():
logger.info(f"Loading voices from {config_path}")
with open(config_path, "r") as f:
voices_data = json.load(f)
logger.info(f"Successfully loaded {len(voices_data.get('voices', {}))} voices from config")
return True, voices_data
else:
logger.warning(f"voices_config.json not found at {config_path}")
# Return empty voices list
return True, {"voices": {}}
except Exception as e:
logger.error(f"Failed to load voices config: {e}")
return False, None
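    # Illustrative voices_config.json shape, matching the format documented in
    # get_voices (IDs and field values are placeholders):
    #   {
    #     "voices": {
    #       "voice_id": {"name": "...", "gender": "..."}
    #     }
    #   }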
    def _synthesize_on_endpoint(
        self,
        endpoint,
        payload: Dict[str, Any],
        label: str,
        timeout: int,
    ) -> Tuple[bool, Optional[bytes], Optional[Dict[str, Any]]]:
        """
        Send one synthesis request to a Vertex AI endpoint and download the audio.
        Shared by synthesize() and synthesize_distill().
        Args:
            endpoint: Vertex AI endpoint to call
            payload: JSON-serializable request body
            label: Endpoint name used in log messages ("base" or "distill")
            timeout: Request timeout in seconds for the audio download
        Returns:
            Tuple of (success, audio_bytes, metrics)
        """
        try:
            response = endpoint.raw_predict(
                body=json.dumps(payload),
                headers={"Content-Type": "application/json"},
            )
            # Parse JSON response
            result = json.loads(response.text) if hasattr(response, "text") else response
            logger.info(f"Vertex AI {label} response: {result}")
            # Check if synthesis was successful
            if not result.get("success"):
                logger.error(f"Synthesis failed: {result.get('message', 'Unknown error')}")
                return False, None, None
            audio_url = result.get("audio_url")
            metrics = result.get("metrics")
            if not audio_url:
                logger.error("No audio_url in successful response")
                return False, None, None
            # Download audio from URL
            logger.info(f"Downloading audio from: {audio_url}")
            audio_response = requests.get(audio_url, timeout=timeout)
            if audio_response.status_code != 200:
                logger.error(f"Failed to download audio: HTTP {audio_response.status_code}")
                return False, None, None
            audio_data = audio_response.content
            logger.info(f"Successfully downloaded audio ({len(audio_data)} bytes)")
            return True, audio_data, metrics
        except Exception as e:
            logger.error(f"Failed to synthesize speech with Vertex AI ({label}): {e}")
            return False, None, None
    def synthesize(self, text: str, voice_id: str, timeout: int = 60) -> Tuple[bool, Optional[bytes], Optional[Dict[str, Any]]]:
        """
        Synthesize speech from text using the base Vertex AI endpoint.
        Args:
            text: Text to synthesize
            voice_id: Voice ID to use
            timeout: Request timeout in seconds
        Returns:
            Tuple of (success, audio_bytes, metrics)
        """
        if not self.initialized:
            if not self.initialize():
                return False, None, None
        logger.info(f"Synthesizing text (length: {len(text)}) with voice {voice_id}")
        return self._synthesize_on_endpoint(
            self.endpoint,
            {"text": text, "voice_id": voice_id},
            "base",
            timeout,
        )
    def synthesize_distill(self, text: str, voice_id: str, timeout: int = 60) -> Tuple[bool, Optional[bytes], Optional[Dict[str, Any]]]:
        """
        Synthesize speech from text using the Vertex AI distill endpoint.
        Args:
            text: Text to synthesize
            voice_id: Voice ID to use
            timeout: Request timeout in seconds
        Returns:
            Tuple of (success, audio_bytes, metrics)
        """
        if not self.initialized:
            if not self.initialize():
                return False, None, None
        if not self.endpoint_distill:
            logger.error("Distill endpoint not available")
            return False, None, None
        logger.info(f"Synthesizing text (length: {len(text)}) with voice {voice_id} using distill model")
        return self._synthesize_on_endpoint(
            self.endpoint_distill,
            {"text": text, "voice_id": voice_id, "model_type": "distill"},
            "distill",
            timeout,
        )
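    # For reference, the endpoint response handled by _synthesize_on_endpoint is
    # expected to look like (shape inferred from the parsing above; values are
    # illustrative):
    #   {"success": true, "audio_url": "https://...", "metrics": {...}}
    # or, on failure:
    #   {"success": false, "message": "..."}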
def synthesize_parallel(self, text: str, voice_id: str, timeout: int = 60) -> Generator[Tuple[str, bool, Optional[bytes], Optional[Dict[str, Any]]], None, None]:
"""
Synthesize speech from text using both base and distill endpoints in parallel.
Yields results as they arrive (doesn't wait for both to complete).
Args:
text: Text to synthesize
voice_id: Voice ID to use
timeout: Request timeout in seconds
Yields:
Tuple of (model_type, success, audio_bytes, metrics)
model_type is either "base" or "distill"
"""
if not self.initialized:
if not self.initialize():
logger.error("Failed to initialize client for parallel synthesis")
return
# Create executor for parallel execution
with ThreadPoolExecutor(max_workers=2) as executor:
# Submit both tasks
futures = {}
# Always submit base model
futures[executor.submit(self.synthesize, text, voice_id, timeout)] = "base"
# Submit distill model if available
if self.endpoint_distill:
futures[executor.submit(self.synthesize_distill, text, voice_id, timeout)] = "distill"
# Yield results as they complete
for future in as_completed(futures):
model_type = futures[future]
try:
success, audio_bytes, metrics = future.result()
yield model_type, success, audio_bytes, metrics
except Exception as e:
logger.error(f"Error in parallel synthesis for {model_type}: {e}")
yield model_type, False, None, None
# Global instance
_vertex_client: Optional[VertexAIClient] = None
def get_vertex_client() -> VertexAIClient:
"""Get or create the global Vertex AI client instance."""
global _vertex_client
if _vertex_client is None:
_vertex_client = VertexAIClient()
return _vertex_client
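# Minimal usage sketch, not exercised by the app itself. Assumes auth_string is
# set to valid service-account JSON and that "example_voice" is replaced with a
# real voice ID from voices_config.json; the audio container format depends on
# the endpoint (a .wav extension is assumed here).
if __name__ == "__main__":
    client = get_vertex_client()
    ok, voices = client.get_voices()
    if ok and voices:
        print("Available voices:", list(voices.get("voices", {})))
    # Base model synthesis: write the downloaded bytes to disk
    success, audio, metrics = client.synthesize("Hello from Ringg TTS.", "example_voice")
    if success and audio:
        with open("output_base.wav", "wb") as f:
            f.write(audio)
        print("Base model metrics:", metrics)
    # Race the base and distill models, handling whichever finishes first
    for model_type, success, audio, metrics in client.synthesize_parallel(
        "Hello from Ringg TTS.", "example_voice"
    ):
        print(f"{model_type}: {'ok' if success else 'failed'}, metrics={metrics}")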