Spaces:
Running
Running
| """ | |
| Vertex AI client for TTS synthesis using Google Cloud AI Platform. | |
| """ | |
| import os | |
| import json | |
| import logging | |
| import requests | |
| from typing import Optional, Dict, Any, Tuple | |
| from google.cloud import aiplatform | |
| from google.oauth2 import service_account | |
| from dotenv import load_dotenv | |
# Pull variables from a local .env file into the process environment
# (intended for local development).
load_dotenv()

# Set up root logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class VertexAIClient:
    """Client for interacting with a Vertex AI TTS endpoint.

    Constructing the client performs no I/O. Credentials are loaded and the
    endpoint is looked up by :meth:`initialize`, which :meth:`synthesize`
    calls on demand.

    Attributes:
        project: Project passed to ``aiplatform.init``.
        location: Location passed to ``aiplatform.init``.
        endpoint_display_name: Display name used to select the endpoint.
        endpoint: The selected endpoint, or None until one is found.
        credentials: Loaded service-account credentials, or None.
        initialized: True once :meth:`initialize` has succeeded.
    """

    def __init__(
        self,
        project: str = "desivocalprod01",
        location: str = "asia-south1",
        endpoint_display_name: str = "zipvoice_base_distill",
    ):
        """Initialize the Vertex AI client.

        Args:
            project: Project passed to ``aiplatform.init``.
            location: Location passed to ``aiplatform.init``.
            endpoint_display_name: Display name of the endpoint to use.
        """
        self.project = project
        self.location = location
        self.endpoint_display_name = endpoint_display_name
        self.endpoint = None
        self.credentials = None
        self.initialized = False

    def _load_credentials(self) -> Optional[service_account.Credentials]:
        """Load credentials from the ``auth_string`` environment variable.

        The variable is expected to hold service-account info as JSON text.

        Returns:
            Credentials object, or None if the variable is missing or its
            content cannot be parsed/converted.
        """
        auth_string = os.environ.get("auth_string")
        if not auth_string:
            logger.warning("auth_string environment variable not found")
            return None

        try:
            credentials = service_account.Credentials.from_service_account_info(
                json.loads(auth_string)
            )
        except json.JSONDecodeError as e:
            logger.error("Failed to parse auth_string JSON: %s", e)
            return None
        except Exception as e:
            logger.error("Failed to load credentials: %s", e)
            return None

        logger.info("Successfully loaded credentials from auth_string")
        return credentials

    def initialize(self) -> bool:
        """Initialize Vertex AI and locate the configured endpoint.

        Does nothing when the client is already initialized.

        Returns:
            True if initialization succeeded, False otherwise.
        """
        if self.initialized:
            return True

        try:
            self.credentials = self._load_credentials()
            if not self.credentials:
                logger.error("Cannot initialize without credentials")
                return False

            aiplatform.init(
                project=self.project,
                location=self.location,
                credentials=self.credentials,
            )
            logger.info("Vertex AI initialized for project %s", self.project)

            # Pick the first listed endpoint whose display name matches.
            match = next(
                (
                    candidate
                    for candidate in aiplatform.Endpoint.list()
                    if candidate.display_name == self.endpoint_display_name
                ),
                None,
            )
            if match is not None:
                self.endpoint = match
                logger.info(
                    "Found %s endpoint: %s",
                    self.endpoint_display_name,
                    match.resource_name,
                )

            if not self.endpoint:
                logger.error(
                    "%s endpoint not found in Vertex AI",
                    self.endpoint_display_name,
                )
                return False

            self.initialized = True
            return True
        except Exception as e:
            # Keep the traceback: failures here come from remote calls and
            # are hard to diagnose from the message alone.
            logger.exception("Failed to initialize Vertex AI: %s", e)
            return False

    def get_voices(self) -> Tuple[bool, Optional[Dict[str, Any]]]:
        """Get available voices from the local configuration file.

        Voices are read from ``voices_config.json`` located next to this
        module; no request is made to the Vertex AI endpoint.

        Returns:
            Tuple of (success, voices_dict).
            voices_dict format:
            {"voices": {"voice_id": {"name": "...", "gender": "..."}}}
            A missing file yields ``(True, {"voices": {}})``; any other
            failure yields ``(False, None)``.
        """
        try:
            import pathlib

            config_path = pathlib.Path(__file__).parent / "voices_config.json"
            # Open directly instead of checking exists() first (no race), and
            # decode as UTF-8 rather than the locale default.
            with open(config_path, "r", encoding="utf-8") as f:
                logger.info("Loading voices from %s", config_path)
                voices_data = json.load(f)
            logger.info(
                "Successfully loaded %s voices from config",
                len(voices_data.get("voices", {})),
            )
            return True, voices_data
        except FileNotFoundError:
            logger.warning("voices_config.json not found at %s", config_path)
            # Return empty voices list
            return True, {"voices": {}}
        except Exception as e:
            logger.error("Failed to load voices config: %s", e)
            return False, None

    def synthesize(
        self, text: str, voice_id: str, timeout: int = 60
    ) -> Tuple[bool, Optional[bytes], Optional[Dict[str, Any]]]:
        """Synthesize speech from text using the Vertex AI endpoint.

        The endpoint is called with ``model_type`` set to ``"distill"``. On
        success its response carries an ``audio_url``, which is then
        downloaded.

        Args:
            text: Text to synthesize.
            voice_id: Voice ID to use.
            timeout: Timeout in seconds for the audio download. It is not
                applied to the prediction request itself.

        Returns:
            Tuple of (success, audio_bytes, metrics). On any failure the
            result is ``(False, None, None)``.
        """
        if not self.initialized and not self.initialize():
            return False, None, None

        try:
            logger.info(
                "Synthesizing text (length: %s) with voice %s using distill model",
                len(text),
                voice_id,
            )
            payload = json.dumps(
                {
                    "text": text,
                    "voice_id": voice_id,
                    "model_type": "distill",
                }
            )
            response = self.endpoint.raw_predict(
                body=payload,
                headers={"Content-Type": "application/json"},
            )

            # Parse JSON response
            result = json.loads(response.text) if hasattr(response, "text") else response
            logger.info("Vertex AI response: %s", result)

            # A payload without .get (e.g. a JSON list or string) cannot
            # carry the expected fields; report it explicitly.
            if not callable(getattr(result, "get", None)):
                logger.error(
                    "Unexpected Vertex AI response type: %s",
                    type(result).__name__,
                )
                return False, None, None

            if not result.get("success"):
                logger.error(
                    "Synthesis failed: %s", result.get("message", "Unknown error")
                )
                return False, None, None

            audio_url = result.get("audio_url")
            metrics = result.get("metrics")
            if not audio_url:
                logger.error("No audio_url in successful response")
                return False, None, None

            # Download audio from URL
            logger.info("Downloading audio from: %s", audio_url)
            audio_response = requests.get(audio_url, timeout=timeout)
            if audio_response.status_code != 200:
                logger.error(
                    "Failed to download audio: HTTP %s", audio_response.status_code
                )
                return False, None, None

            audio_data = audio_response.content
            logger.info("Successfully downloaded audio (%s bytes)", len(audio_data))
            return True, audio_data, metrics
        except Exception as e:
            # Keep the traceback so network and parsing failures can be told
            # apart in the logs.
            logger.exception("Failed to synthesize speech with Vertex AI: %s", e)
            return False, None, None
# Module-level singleton; stays None until the first accessor call.
_vertex_client = None


def get_vertex_client() -> VertexAIClient:
    """Return the shared Vertex AI client, creating it on first use."""
    global _vertex_client
    if _vertex_client is not None:
        return _vertex_client
    _vertex_client = VertexAIClient()
    return _vertex_client