utkarshshukla2912 committed on
Commit
4806882
·
1 Parent(s): 9baf492

remove base inference

Browse files
Files changed (3) hide show
  1. app.py +63 -126
  2. generation_counter.json +1 -1
  3. vertex_client.py +10 -122
app.py CHANGED
@@ -258,31 +258,17 @@ with gr.Blocks(
258
  show_label=False,
259
  )
260
 
261
- # Side-by-side comparison of Base and Distill models
262
- gr.Markdown("### 🎧 Audio Results Comparison")
263
- with gr.Row():
264
- with gr.Column(scale=1):
265
- # gr.Markdown("#### Base Model")
266
- audio_output_base = gr.Audio(label="Base Model Audio", type="filepath")
267
- status_base = gr.Markdown("", visible=True)
268
- metrics_header_base = gr.Markdown("**📊 Metrics**", visible=False)
269
- metrics_output_base = gr.Code(
270
- label="Base Metrics", language="json", interactive=False, visible=False
271
- )
272
-
273
- with gr.Column(scale=1):
274
- # gr.Markdown("#### Distill Model")
275
- audio_output_distill = gr.Audio(
276
- label="Distill Model Audio", type="filepath"
277
- )
278
- status_distill = gr.Markdown("", visible=True)
279
- metrics_header_distill = gr.Markdown("**📊 Metrics**", visible=False)
280
- metrics_output_distill = gr.Code(
281
- label="Distill Metrics",
282
- language="json",
283
- interactive=False,
284
- visible=False,
285
- )
286
 
287
  generate_btn = gr.Button("🎬 Generate Speech", variant="primary", size="lg")
288
 
@@ -315,15 +301,11 @@ with gr.Blocks(
315
  return "", "Character count: 0 / 300"
316
 
317
  def on_generate(text, voice_display):
318
- """Generate speech using both base and distill models in parallel."""
319
  # Validate inputs
320
  if not text or not text.strip():
321
  error_msg = "⚠️ Please enter some text"
322
  yield (
323
- None,
324
- error_msg,
325
- gr.update(visible=False),
326
- gr.update(visible=False),
327
  None,
328
  error_msg,
329
  gr.update(visible=False),
@@ -336,10 +318,6 @@ with gr.Blocks(
336
  if not voice_id:
337
  error_msg = "⚠️ Please select a voice"
338
  yield (
339
- None,
340
- error_msg,
341
- gr.update(visible=False),
342
- gr.update(visible=False),
343
  None,
344
  error_msg,
345
  gr.update(visible=False),
@@ -348,101 +326,64 @@ with gr.Blocks(
348
  )
349
  return
350
 
351
- # Initialize state for both models
352
- results = {
353
- "base": {"audio": None, "status": "⏳ Loading...", "metrics": None},
354
- "distill": {"audio": None, "status": "⏳ Loading...", "metrics": None},
355
- }
356
-
357
  # Show loading state initially
358
  yield (
359
  None,
360
- results["base"]["status"],
361
- gr.update(visible=False),
362
- gr.update(visible=False),
363
- None,
364
- results["distill"]["status"],
365
  gr.update(visible=False),
366
  gr.update(visible=False),
367
  f"**🌍 Generations:** {load_counter()}",
368
  )
369
 
370
- # Use parallel synthesis
371
  vertex_client = get_vertex_client()
372
- counter_incremented = False
373
-
374
- for (
375
- model_type,
376
- success,
377
- audio_bytes,
378
- metrics,
379
- ) in vertex_client.synthesize_parallel(text, voice_id):
380
- if success and audio_bytes:
381
- # Save audio file in system temp directory
382
- temp_dir = tempfile.gettempdir()
383
- audio_file = os.path.join(
384
- temp_dir, f"ringg_{model_type}_{str(uuid.uuid4())}.wav"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
  )
386
- with open(audio_file, "wb") as f:
387
- f.write(audio_bytes)
388
-
389
- # Increment counter only once (for the first successful result)
390
- if not counter_incremented:
391
- new_count = increment_counter()
392
- counter_incremented = True
393
- else:
394
- new_count = load_counter()
395
-
396
- # Format metrics
397
- metrics_json = ""
398
- has_metrics = False
399
- if metrics:
400
- has_metrics = True
401
- metrics_json = json.dumps(
402
- {
403
- "total_time": f"{metrics.get('t', 0):.3f}s",
404
- "rtf": f"{metrics.get('rtf', 0):.4f}",
405
- "audio_duration": f"{metrics.get('wav_seconds', 0):.2f}s",
406
- "vocoder_time": f"{metrics.get('t_vocoder', 0):.3f}s",
407
- "no_vocoder_time": f"{metrics.get('t_no_vocoder', 0):.3f}s",
408
- "rtf_no_vocoder": f"{metrics.get('rtf_no_vocoder', 0):.4f}",
409
- },
410
- indent=2,
411
- )
412
-
413
- # Update the corresponding model result
414
- results[model_type] = {
415
- "audio": audio_file,
416
- "status": "",
417
- "metrics": metrics_json,
418
- "has_metrics": has_metrics,
419
- }
420
- else:
421
- # Update failed model
422
- results[model_type] = {
423
- "audio": None,
424
- "status": "❌ Failed to generate",
425
- "metrics": "",
426
- "has_metrics": False,
427
- }
428
-
429
- # Yield updated state for both models
430
  yield (
431
- results["base"]["audio"],
432
- results["base"]["status"],
433
- gr.update(visible=results["base"].get("has_metrics", False)),
434
- gr.update(
435
- value=results["base"]["metrics"],
436
- visible=results["base"].get("has_metrics", False),
437
- ),
438
- results["distill"]["audio"],
439
- results["distill"]["status"],
440
- gr.update(visible=results["distill"].get("has_metrics", False)),
441
- gr.update(
442
- value=results["distill"]["metrics"],
443
- visible=results["distill"].get("has_metrics", False),
444
- ),
445
- f"**🌍 Generations:** {new_count if counter_incremented else load_counter()}",
446
  )
447
 
448
  def refresh_counter_on_load():
@@ -475,14 +416,10 @@ with gr.Blocks(
475
  fn=on_generate,
476
  inputs=[text_input, voice_dropdown],
477
  outputs=[
478
- audio_output_base,
479
- status_base,
480
- metrics_header_base,
481
- metrics_output_base,
482
- audio_output_distill,
483
- status_distill,
484
- metrics_header_distill,
485
- metrics_output_distill,
486
  generation_counter,
487
  ],
488
  concurrency_limit=2,
 
258
  show_label=False,
259
  )
260
 
261
+ # Audio output section
262
+ gr.Markdown("### 🎧 Audio Result")
263
+ audio_output = gr.Audio(label="Generated Audio", type="filepath")
264
+ status = gr.Markdown("", visible=True)
265
+ metrics_header = gr.Markdown("**📊 Metrics**", visible=False)
266
+ metrics_output = gr.Code(
267
+ label="Performance Metrics",
268
+ language="json",
269
+ interactive=False,
270
+ visible=False,
271
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
  generate_btn = gr.Button("🎬 Generate Speech", variant="primary", size="lg")
274
 
 
301
  return "", "Character count: 0 / 300"
302
 
303
  def on_generate(text, voice_display):
304
+ """Generate speech using the distill model."""
305
  # Validate inputs
306
  if not text or not text.strip():
307
  error_msg = "⚠️ Please enter some text"
308
  yield (
 
 
 
 
309
  None,
310
  error_msg,
311
  gr.update(visible=False),
 
318
  if not voice_id:
319
  error_msg = "⚠️ Please select a voice"
320
  yield (
 
 
 
 
321
  None,
322
  error_msg,
323
  gr.update(visible=False),
 
326
  )
327
  return
328
 
 
 
 
 
 
 
329
  # Show loading state initially
330
  yield (
331
  None,
332
+ "⏳ Loading...",
 
 
 
 
333
  gr.update(visible=False),
334
  gr.update(visible=False),
335
  f"**🌍 Generations:** {load_counter()}",
336
  )
337
 
338
+ # Synthesize speech
339
  vertex_client = get_vertex_client()
340
+ success, audio_bytes, metrics = vertex_client.synthesize(text, voice_id)
341
+
342
+ if success and audio_bytes:
343
+ # Save audio file in system temp directory
344
+ temp_dir = tempfile.gettempdir()
345
+ audio_file = os.path.join(
346
+ temp_dir, f"ringg_{str(uuid.uuid4())}.wav"
347
+ )
348
+ with open(audio_file, "wb") as f:
349
+ f.write(audio_bytes)
350
+
351
+ # Increment counter
352
+ new_count = increment_counter()
353
+
354
+ # Format metrics
355
+ metrics_json = ""
356
+ has_metrics = False
357
+ if metrics:
358
+ has_metrics = True
359
+ metrics_json = json.dumps(
360
+ {
361
+ "total_time": f"{metrics.get('t', 0):.3f}s",
362
+ "rtf": f"{metrics.get('rtf', 0):.4f}",
363
+ "audio_duration": f"{metrics.get('wav_seconds', 0):.2f}s",
364
+ "vocoder_time": f"{metrics.get('t_vocoder', 0):.3f}s",
365
+ "no_vocoder_time": f"{metrics.get('t_no_vocoder', 0):.3f}s",
366
+ "rtf_no_vocoder": f"{metrics.get('rtf_no_vocoder', 0):.4f}",
367
+ },
368
+ indent=2,
369
  )
370
+
371
+ # Yield success result
372
+ yield (
373
+ audio_file,
374
+ "",
375
+ gr.update(visible=has_metrics),
376
+ gr.update(value=metrics_json, visible=has_metrics),
377
+ f"**🌍 Generations:** {new_count}",
378
+ )
379
+ else:
380
+ # Yield failure result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
  yield (
382
+ None,
383
+ "❌ Failed to generate",
384
+ gr.update(visible=False),
385
+ gr.update(visible=False),
386
+ f"**🌍 Generations:** {load_counter()}",
 
 
 
 
 
 
 
 
 
 
387
  )
388
 
389
  def refresh_counter_on_load():
 
416
  fn=on_generate,
417
  inputs=[text_input, voice_dropdown],
418
  outputs=[
419
+ audio_output,
420
+ status,
421
+ metrics_header,
422
+ metrics_output,
 
 
 
 
423
  generation_counter,
424
  ],
425
  concurrency_limit=2,
generation_counter.json CHANGED
@@ -1 +1 @@
1
- {"count": 10, "last_updated": 1762780862.430711}
 
1
+ {"count": 11, "last_updated": 1763749917.869355}
vertex_client.py CHANGED
@@ -5,8 +5,7 @@ import os
5
  import json
6
  import logging
7
  import requests
8
- from typing import Optional, Dict, Any, Tuple, Generator
9
- from concurrent.futures import ThreadPoolExecutor, as_completed
10
  from google.cloud import aiplatform
11
  from google.oauth2 import service_account
12
  from dotenv import load_dotenv
@@ -25,7 +24,6 @@ class VertexAIClient:
25
  def __init__(self):
26
  """Initialize the Vertex AI client."""
27
  self.endpoint = None
28
- self.endpoint_distill = None
29
  self.credentials = None
30
  self.initialized = False
31
 
@@ -59,7 +57,7 @@ class VertexAIClient:
59
 
60
  def initialize(self) -> bool:
61
  """
62
- Initialize Vertex AI and find the zipvoice and zipvoice_base_distill endpoints.
63
 
64
  Returns:
65
  True if initialization successful, False otherwise
@@ -82,24 +80,18 @@ class VertexAIClient:
82
  )
83
  logger.info("Vertex AI initialized for project desivocalprod01")
84
 
85
- # Find both endpoints
86
  for endpoint in aiplatform.Endpoint.list():
87
- if endpoint.display_name == "zipvoice":
88
  self.endpoint = endpoint
89
- logger.info(f"Found zipvoice endpoint: {endpoint.resource_name}")
90
- elif endpoint.display_name == "zipvoice_base_distill":
91
- self.endpoint_distill = endpoint
92
  logger.info(f"Found zipvoice_base_distill endpoint: {endpoint.resource_name}")
 
93
 
94
- # Check if at least the base endpoint is found
95
  if not self.endpoint:
96
- logger.error("zipvoice endpoint not found in Vertex AI")
97
  return False
98
 
99
- # Warn if distill endpoint is not found but continue
100
- if not self.endpoint_distill:
101
- logger.warning("zipvoice_base_distill endpoint not found - distill model will not be available")
102
-
103
  self.initialized = True
104
  return True
105
 
@@ -139,65 +131,6 @@ class VertexAIClient:
139
  return False, None
140
 
141
  def synthesize(self, text: str, voice_id: str, timeout: int = 60) -> Tuple[bool, Optional[bytes], Optional[Dict[str, Any]]]:
142
- """
143
- Synthesize speech from text using Vertex AI endpoint.
144
-
145
- Args:
146
- text: Text to synthesize
147
- voice_id: Voice ID to use
148
- timeout: Request timeout in seconds
149
-
150
- Returns:
151
- Tuple of (success, audio_bytes, metrics)
152
- """
153
- if not self.initialized:
154
- if not self.initialize():
155
- return False, None, None
156
-
157
- try:
158
- logger.info(f"Synthesizing text (length: {len(text)}) with voice {voice_id}")
159
- response = self.endpoint.raw_predict(
160
- body=json.dumps({
161
- "text": text,
162
- "voice_id": voice_id,
163
- }),
164
- headers={"Content-Type": "application/json"},
165
- )
166
-
167
- # Parse JSON response
168
- result = json.loads(response.text) if hasattr(response, 'text') else response
169
- logger.info(f"Vertex AI response: {result}")
170
-
171
- # Check if synthesis was successful
172
- if result.get("success"):
173
- audio_url = result.get("audio_url")
174
- metrics = result.get("metrics")
175
-
176
- if not audio_url:
177
- logger.error("No audio_url in successful response")
178
- return False, None, None
179
-
180
- # Download audio from URL
181
- logger.info(f"Downloading audio from: {audio_url}")
182
- audio_response = requests.get(audio_url, timeout=timeout)
183
-
184
- if audio_response.status_code == 200:
185
- audio_data = audio_response.content
186
- logger.info(f"Successfully downloaded audio ({len(audio_data)} bytes)")
187
- return True, audio_data, metrics
188
- else:
189
- logger.error(f"Failed to download audio: HTTP {audio_response.status_code}")
190
- return False, None, None
191
- else:
192
- error_msg = result.get("message", "Unknown error")
193
- logger.error(f"Synthesis failed: {error_msg}")
194
- return False, None, None
195
-
196
- except Exception as e:
197
- logger.error(f"Failed to synthesize speech with Vertex AI: {e}")
198
- return False, None, None
199
-
200
- def synthesize_distill(self, text: str, voice_id: str, timeout: int = 60) -> Tuple[bool, Optional[bytes], Optional[Dict[str, Any]]]:
201
  """
202
  Synthesize speech from text using Vertex AI distill endpoint.
203
 
@@ -213,13 +146,9 @@ class VertexAIClient:
213
  if not self.initialize():
214
  return False, None, None
215
 
216
- if not self.endpoint_distill:
217
- logger.error("Distill endpoint not available")
218
- return False, None, None
219
-
220
  try:
221
  logger.info(f"Synthesizing text (length: {len(text)}) with voice {voice_id} using distill model")
222
- response = self.endpoint_distill.raw_predict(
223
  body=json.dumps({
224
  "text": text,
225
  "voice_id": voice_id,
@@ -230,7 +159,7 @@ class VertexAIClient:
230
 
231
  # Parse JSON response
232
  result = json.loads(response.text) if hasattr(response, 'text') else response
233
- logger.info(f"Vertex AI distill response: {result}")
234
 
235
  # Check if synthesis was successful
236
  if result.get("success"):
@@ -258,50 +187,9 @@ class VertexAIClient:
258
  return False, None, None
259
 
260
  except Exception as e:
261
- logger.error(f"Failed to synthesize speech with Vertex AI distill: {e}")
262
  return False, None, None
263
 
264
- def synthesize_parallel(self, text: str, voice_id: str, timeout: int = 60) -> Generator[Tuple[str, bool, Optional[bytes], Optional[Dict[str, Any]]], None, None]:
265
- """
266
- Synthesize speech from text using both base and distill endpoints in parallel.
267
-
268
- Yields results as they arrive (doesn't wait for both to complete).
269
-
270
- Args:
271
- text: Text to synthesize
272
- voice_id: Voice ID to use
273
- timeout: Request timeout in seconds
274
-
275
- Yields:
276
- Tuple of (model_type, success, audio_bytes, metrics)
277
- model_type is either "base" or "distill"
278
- """
279
- if not self.initialized:
280
- if not self.initialize():
281
- logger.error("Failed to initialize client for parallel synthesis")
282
- return
283
-
284
- # Create executor for parallel execution
285
- with ThreadPoolExecutor(max_workers=2) as executor:
286
- # Submit both tasks
287
- futures = {}
288
-
289
- # Always submit base model
290
- futures[executor.submit(self.synthesize, text, voice_id, timeout)] = "base"
291
-
292
- # Submit distill model if available
293
- if self.endpoint_distill:
294
- futures[executor.submit(self.synthesize_distill, text, voice_id, timeout)] = "distill"
295
-
296
- # Yield results as they complete
297
- for future in as_completed(futures):
298
- model_type = futures[future]
299
- try:
300
- success, audio_bytes, metrics = future.result()
301
- yield model_type, success, audio_bytes, metrics
302
- except Exception as e:
303
- logger.error(f"Error in parallel synthesis for {model_type}: {e}")
304
- yield model_type, False, None, None
305
 
306
 
307
  # Global instance
 
5
  import json
6
  import logging
7
  import requests
8
+ from typing import Optional, Dict, Any, Tuple
 
9
  from google.cloud import aiplatform
10
  from google.oauth2 import service_account
11
  from dotenv import load_dotenv
 
24
  def __init__(self):
25
  """Initialize the Vertex AI client."""
26
  self.endpoint = None
 
27
  self.credentials = None
28
  self.initialized = False
29
 
 
57
 
58
  def initialize(self) -> bool:
59
  """
60
+ Initialize Vertex AI and find the zipvoice_base_distill endpoint.
61
 
62
  Returns:
63
  True if initialization successful, False otherwise
 
80
  )
81
  logger.info("Vertex AI initialized for project desivocalprod01")
82
 
83
+ # Find distill endpoint
84
  for endpoint in aiplatform.Endpoint.list():
85
+ if endpoint.display_name == "zipvoice_base_distill":
86
  self.endpoint = endpoint
 
 
 
87
  logger.info(f"Found zipvoice_base_distill endpoint: {endpoint.resource_name}")
88
+ break
89
 
90
+ # Check if endpoint is found
91
  if not self.endpoint:
92
+ logger.error("zipvoice_base_distill endpoint not found in Vertex AI")
93
  return False
94
 
 
 
 
 
95
  self.initialized = True
96
  return True
97
 
 
131
  return False, None
132
 
133
  def synthesize(self, text: str, voice_id: str, timeout: int = 60) -> Tuple[bool, Optional[bytes], Optional[Dict[str, Any]]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  """
135
  Synthesize speech from text using Vertex AI distill endpoint.
136
 
 
146
  if not self.initialize():
147
  return False, None, None
148
 
 
 
 
 
149
  try:
150
  logger.info(f"Synthesizing text (length: {len(text)}) with voice {voice_id} using distill model")
151
+ response = self.endpoint.raw_predict(
152
  body=json.dumps({
153
  "text": text,
154
  "voice_id": voice_id,
 
159
 
160
  # Parse JSON response
161
  result = json.loads(response.text) if hasattr(response, 'text') else response
162
+ logger.info(f"Vertex AI response: {result}")
163
 
164
  # Check if synthesis was successful
165
  if result.get("success"):
 
187
  return False, None, None
188
 
189
  except Exception as e:
190
+ logger.error(f"Failed to synthesize speech with Vertex AI: {e}")
191
  return False, None, None
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194
 
195
  # Global instance