# Ringg-TTS-v1.0 / vertex_client.py
# Author: utkarshshukla2912
# Commit 4806882: "remove base inference"
"""
Vertex AI client for TTS synthesis using Google Cloud AI Platform.
"""
import os
import json
import logging
import requests
from typing import Optional, Dict, Any, Tuple
from google.cloud import aiplatform
from google.oauth2 import service_account
from dotenv import load_dotenv
# Pull variables from a local .env file when one exists; in deployed
# environments the variables are expected to already be set.
load_dotenv()

# Route INFO-and-above records through the root handler and expose a
# module-scoped logger for the client below.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class VertexAIClient:
    """Client for interacting with a Vertex AI TTS endpoint.

    Construction is cheap and performs no network I/O. Credentials are loaded
    and the endpoint is resolved lazily by ``initialize()``, which
    ``synthesize()`` calls on first use.
    """

    def __init__(
        self,
        project: str = "desivocalprod01",
        location: str = "asia-south1",
        endpoint_display_name: str = "zipvoice_base_distill",
    ):
        """Initialize the Vertex AI client.

        Args:
            project: Google Cloud project that hosts the endpoint.
            location: Vertex AI region passed to ``aiplatform.init``.
            endpoint_display_name: Display name used to find the endpoint
                among ``aiplatform.Endpoint.list()`` results.
        """
        self.project = project
        self.location = location
        self.endpoint_display_name = endpoint_display_name
        self.endpoint = None
        self.credentials = None
        self.initialized = False

    def _load_credentials(self) -> Optional[service_account.Credentials]:
        """
        Load service-account credentials from the ``auth_string`` env var.

        The variable is expected to hold the service-account key as a JSON
        string.

        Returns:
            Credentials object, or None if the variable is missing or invalid
            (the reason is logged).
        """
        try:
            auth_string = os.environ.get("auth_string")
            if not auth_string:
                logger.warning("auth_string environment variable not found")
                return None
            # Parse JSON credentials
            credentials_dict = json.loads(auth_string)
            credentials = service_account.Credentials.from_service_account_info(
                credentials_dict
            )
            logger.info("Successfully loaded credentials from auth_string")
            return credentials
        except json.JSONDecodeError as e:
            # Deliberately no traceback or payload: the input is a secret.
            logger.error(f"Failed to parse auth_string JSON: {e}")
            return None
        except Exception as e:
            logger.error(f"Failed to load credentials: {e}")
            return None

    def _find_endpoint(self) -> Optional["aiplatform.Endpoint"]:
        """Return the first endpoint whose display name matches, else None."""
        for endpoint in aiplatform.Endpoint.list():
            if endpoint.display_name == self.endpoint_display_name:
                logger.info(
                    f"Found {self.endpoint_display_name} endpoint: {endpoint.resource_name}"
                )
                return endpoint
        return None

    def initialize(self) -> bool:
        """
        Initialize Vertex AI and locate the configured endpoint.

        Safe to call repeatedly. After the first success it returns True
        immediately. After a failure, the next call retries from scratch.

        Returns:
            True if initialization succeeded, False otherwise (reason logged).
        """
        if self.initialized:
            return True
        try:
            self.credentials = self._load_credentials()
            if not self.credentials:
                logger.error("Cannot initialize without credentials")
                return False

            aiplatform.init(
                project=self.project,
                location=self.location,
                credentials=self.credentials,
            )
            logger.info(f"Vertex AI initialized for project {self.project}")

            self.endpoint = self._find_endpoint()
            if not self.endpoint:
                logger.error(
                    f"{self.endpoint_display_name} endpoint not found in Vertex AI"
                )
                return False

            self.initialized = True
            return True
        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception(f"Failed to initialize Vertex AI: {e}")
            return False

    def get_voices(self) -> Tuple[bool, Optional[Dict[str, Any]]]:
        """
        Get available voices from the local configuration file.

        Note: the Vertex AI endpoint has no separate /voices API, so voices
        are read from ``voices_config.json`` located next to this module
        (not the process's current working directory).

        Returns:
            Tuple of (success, voices_dict).
            voices_dict format: {"voices": {"voice_id": {"name": "...", "gender": "..."}}}
            A missing config file is not an error: it yields (True, {"voices": {}}).
            An unreadable or malformed file yields (False, None).
        """
        import pathlib

        try:
            config_path = pathlib.Path(__file__).parent / "voices_config.json"
            if not config_path.exists():
                logger.warning(f"voices_config.json not found at {config_path}")
                return True, {"voices": {}}

            logger.info(f"Loading voices from {config_path}")
            # Explicit UTF-8: voice names may be non-ASCII, and the platform
            # default encoding is not guaranteed to be UTF-8.
            with open(config_path, "r", encoding="utf-8") as f:
                voices_data = json.load(f)

            if not isinstance(voices_data, dict):
                logger.error(
                    f"Failed to load voices config: expected a JSON object, got {type(voices_data).__name__}"
                )
                return False, None

            logger.info(f"Successfully loaded {len(voices_data.get('voices', {}))} voices from config")
            return True, voices_data
        except Exception as e:
            logger.exception(f"Failed to load voices config: {e}")
            return False, None

    def synthesize(self, text: str, voice_id: str, timeout: int = 60) -> Tuple[bool, Optional[bytes], Optional[Dict[str, Any]]]:
        """
        Synthesize speech from text using the Vertex AI distill endpoint.

        The endpoint answers with JSON containing ``success``, ``audio_url``
        and ``metrics``. The audio itself is then downloaded from
        ``audio_url``.

        Args:
            text: Text to synthesize. Empty or whitespace-only text is rejected.
            voice_id: Voice ID to use.
            timeout: Timeout in seconds for the audio download request. The
                prediction call itself uses the aiplatform library's default.

        Returns:
            Tuple of (success, audio_bytes, metrics). On any failure this is
            (False, None, None) and the reason is logged.
        """
        if not self.initialized and not self.initialize():
            return False, None, None

        try:
            if not text or not text.strip():
                logger.error("Cannot synthesize empty text")
                return False, None, None

            logger.info(f"Synthesizing text (length: {len(text)}) with voice {voice_id} using distill model")
            payload = json.dumps({
                "text": text,
                "voice_id": voice_id,
                "model_type": "distill",
            }).encode("utf-8")  # raw_predict documents its body as bytes
            response = self.endpoint.raw_predict(
                body=payload,
                headers={"Content-Type": "application/json"},
            )

            # raw_predict returns the HTTP response without raising on error
            # statuses, so check explicitly to log the real failure reason.
            status = getattr(response, "status_code", None)
            if status is not None and status >= 400:
                body_preview = getattr(response, "text", "")[:500]
                logger.error(f"Vertex AI prediction failed: HTTP {status}: {body_preview}")
                return False, None, None

            # Parse JSON response
            result = json.loads(response.text) if hasattr(response, "text") else response
            logger.info(f"Vertex AI response: {result}")

            if not isinstance(result, dict):
                logger.error(f"Unexpected Vertex AI response type: {type(result).__name__}")
                return False, None, None

            if not result.get("success"):
                error_msg = result.get("message", "Unknown error")
                logger.error(f"Synthesis failed: {error_msg}")
                return False, None, None

            audio_url = result.get("audio_url")
            metrics = result.get("metrics")
            if not audio_url:
                logger.error("No audio_url in successful response")
                return False, None, None

            # Download audio from URL
            logger.info(f"Downloading audio from: {audio_url}")
            audio_response = requests.get(audio_url, timeout=timeout)
            if audio_response.status_code != 200:
                logger.error(f"Failed to download audio: HTTP {audio_response.status_code}")
                return False, None, None

            audio_data = audio_response.content
            logger.info(f"Successfully downloaded audio ({len(audio_data)} bytes)")
            return True, audio_data, metrics
        except Exception as e:
            logger.exception(f"Failed to synthesize speech with Vertex AI: {e}")
            return False, None, None
# Shared, lazily-created client used by get_vertex_client().
_vertex_client = None


def get_vertex_client() -> VertexAIClient:
    """Return the process-wide VertexAIClient, creating it on first call."""
    global _vertex_client
    if _vertex_client is not None:
        return _vertex_client
    _vertex_client = VertexAIClient()
    return _vertex_client