Ringg-TTS-v1.0 / app.py
utkarshshukla2912's picture
reduced text size
9baf492
raw
history blame
18.5 kB
import gradio as gr
import json
import os
from pathlib import Path
import uuid
import fcntl
import time
import tempfile
from vertex_client import get_vertex_client
# gr.NO_RELOAD = False
# Counter persistence file
# JSON object of the form {"count": <int>, "last_updated": <time.time()>}.
# The path is relative, so it resolves against the process's current working
# directory when the file is opened.
COUNTER_FILE = Path("generation_counter.json")
# Example texts
# Loaded into the text box by the "English/Hindi/Mixed Example" buttons.
EXAMPLE_TEXT_ENGLISH = "Welcome to Ringg TTS! This is a text to speech system that can convert your text into natural-sounding audio. Try it out with your own content!"
EXAMPLE_TEXT_HINDI = "नमस्ते! मैं रिंग टीटीएस हूँ। मैं आपके टेक्स्ट को प्राकृतिक आवाज़ में बदल सकता हूँ। कृपया अपना टेक्स्ट यहाँ लिखें और सुनें।"
EXAMPLE_TEXT_MIXED = "Hello दोस्तों! Welcome to Ringg TTS. यह एक बहुत ही शानदार text to speech system है जो Hindi और English दोनों languages को support करता है।"
def load_counter():
    """Load the universal generation counter from ``COUNTER_FILE``.

    A shared ``flock`` is held while reading, so a writer holding the
    exclusive lock cannot be observed mid-write. If the lock cannot be
    acquired (``flock`` raises ``OSError``), the file is read without it.

    Returns:
        The stored ``"count"`` value, or 0 when the file is missing,
        unreadable, not valid JSON, or not a JSON object.
    """
    try:
        with open(COUNTER_FILE, "r") as f:
            try:
                fcntl.flock(f.fileno(), fcntl.LOCK_SH)
                locked = True
            except OSError:
                # Locking unavailable: fall back to an unlocked read.
                locked = False
            try:
                data = json.load(f)
            finally:
                # Release the lock even if the JSON is malformed.
                if locked:
                    fcntl.flock(f.fileno(), fcntl.LOCK_UN)
    except FileNotFoundError:
        # No generations recorded yet.
        return 0
    except Exception as e:
        print(f"Error loading counter: {e}")
        return 0
    if not isinstance(data, dict):
        return 0
    return data.get("count", 0)
def save_counter(count):
    """Persist ``count`` to ``COUNTER_FILE`` under an exclusive file lock.

    The file is opened (and created if needed) *without* truncation. It is
    emptied only after the lock is held. Opening with mode ``"w"`` would
    truncate before locking and let a concurrent reader or incrementer see
    an empty file. If the lock cannot be acquired (``flock`` raises
    ``OSError``), the write proceeds unlocked. Errors are printed, not raised.

    Args:
        count: Value stored under ``"count"``, alongside a ``time.time()``
            timestamp under ``"last_updated"``.
    """
    try:
        # O_CREAT without O_TRUNC: create if missing, but keep the contents
        # until the lock is held. 0o666 matches open()'s default (umask applies).
        fd = os.open(COUNTER_FILE, os.O_RDWR | os.O_CREAT, 0o666)
        with os.fdopen(fd, "r+") as f:
            try:
                fcntl.flock(f.fileno(), fcntl.LOCK_EX)
                locked = True
            except OSError:
                # Locking unavailable: write without it.
                locked = False
            try:
                f.seek(0)
                f.truncate()
                json.dump({"count": count, "last_updated": time.time()}, f)
                f.flush()
                os.fsync(f.fileno())
            finally:
                if locked:
                    fcntl.flock(f.fileno(), fcntl.LOCK_UN)
    except Exception as e:
        print(f"Error saving counter: {e}")
def increment_counter():
    """Increment the persisted generation counter and return the new value.

    The read-modify-write runs while an exclusive ``flock`` is held, so
    callers that obtain the lock are serialized. The file is opened with
    ``O_CREAT`` rather than after a separate ``exists()`` check. That check
    raced with other writers and could truncate a counter another process
    had just created. If the lock cannot be acquired (``flock`` raises
    ``OSError``), the update proceeds unlocked.

    An empty, corrupt, or non-object file is treated as a count of 0.

    Returns:
        The incremented count, or 0 if the file could not be updated.
    """
    try:
        fd = os.open(COUNTER_FILE, os.O_RDWR | os.O_CREAT, 0o666)
        with os.fdopen(fd, "r+") as f:
            try:
                fcntl.flock(f.fileno(), fcntl.LOCK_EX)
                locked = True
            except OSError:
                # Locking unavailable: update without it.
                locked = False
            try:
                f.seek(0)
                try:
                    data = json.load(f)
                    current_count = data.get("count", 0)
                except Exception:
                    current_count = 0
                new_count = current_count + 1
                # Rewrite from the start so no stale bytes remain.
                f.seek(0)
                f.truncate()
                json.dump({"count": new_count, "last_updated": time.time()}, f)
                f.flush()
                os.fsync(f.fileno())
            finally:
                if locked:
                    fcntl.flock(f.fileno(), fcntl.LOCK_UN)
            return new_count
    except Exception as e:
        print(f"Error incrementing counter: {e}")
        return 0
def get_voices():
    """Return Vertex AI voices as ``(display_name, voice_id)`` pairs.

    Each display name reads ``"<name> (<gender>)"``. A missing name falls
    back to ``"Unknown"`` and a missing gender to ``"N/A"``. The list is
    sorted by display name. Any failure yields an empty list.
    """
    try:
        client = get_vertex_client()
        ok, payload = client.get_voices()
        if not (ok and payload):
            print("❌ Failed to fetch voices from Vertex AI")
            return []
        print("✅ Fetched voices from Vertex AI")
        catalog = payload.get("voices", {})
        pairs = [
            (f"{info.get('name', 'Unknown')} ({info.get('gender', 'N/A')})", vid)
            for vid, info in catalog.items()
        ]
        return sorted(pairs, key=lambda pair: pair[0])
    except Exception as e:
        print(f"❌ Error fetching voices from Vertex AI: {e}")
        return []
def synthesize_speech(text, voice_id):
    """Synthesize ``text`` with ``voice_id`` through the Vertex AI client.

    Returns:
        An 8-tuple ``(audio_path, status, total_time, rtf, wav_duration,
        vocoder_time, no_vocoder_time, rtf_no_vocoder)``.

        On success the audio bytes are written to a fresh
        ``ringg_<uuid>.wav`` in the system temp directory, ``status`` is
        empty, and the metric fields are formatted strings (all empty when
        the client returns no metrics).

        On any failure ``audio_path`` is ``None``, ``status`` holds the
        message, and every metric field is empty.
    """
    blank_metrics = ("",) * 6

    if not text or not text.strip():
        return (None, "⚠️ Please enter some text") + blank_metrics
    if not voice_id:
        return (None, "⚠️ Please select a voice") + blank_metrics

    print(f"Input text length: {len(text)} characters")

    try:
        client = get_vertex_client()
        ok, audio_bytes, metrics = client.synthesize(text, voice_id, timeout=60)
        if not (ok and audio_bytes):
            return (None, "❌ Failed to generate audio") + blank_metrics

        print("✅ Synthesized audio using Vertex AI")
        # A unique file name per request, in the system temp directory.
        out_path = os.path.join(
            tempfile.gettempdir(), f"ringg_{str(uuid.uuid4())}.wav"
        )
        with open(out_path, "wb") as sink:
            sink.write(audio_bytes)

        if metrics:
            formatted = (
                f"{metrics.get('t', 0):.3f}s",
                f"{metrics.get('rtf', 0):.4f}",
                f"{metrics.get('wav_seconds', 0):.2f}s",
                f"{metrics.get('t_vocoder', 0):.3f}s",
                f"{metrics.get('t_no_vocoder', 0):.3f}s",
                f"{metrics.get('rtf_no_vocoder', 0):.4f}",
            )
        else:
            formatted = blank_metrics
        return (out_path, "") + formatted
    except Exception as e:
        print(f"❌ Vertex AI synthesis failed: {e}")
        return (None, f"❌ Error: {str(e)}") + blank_metrics
# Load initial counter value
initial_counter = load_counter()

# Maximum input length. The Textbox enforces it in the browser; on_generate
# re-checks it because requests sent directly to the API bypass the UI.
MAX_CHARS = 300
# One label for every render of the counter (initial page, after each
# generation, on page load), so the header wording never changes.
COUNTER_LABEL = "**🌍 Generations since last commit:**"
# Model result columns, in the order their components appear in the
# outputs of generate_btn.click.
MODEL_TYPES = ("base", "distill")
PENDING_STATUS = "⏳ Loading..."


def _char_count_text(count):
    """Format the character-count label shown under the text input."""
    return f"Character count: {count} / {MAX_CHARS}"


def _counter_markdown(count):
    """Format the generation-counter Markdown with the shared label."""
    return f"{COUNTER_LABEL} {count}"


def _format_metrics(metrics):
    """Render synthesis timing metrics as an indented JSON string.

    Reads the keys ``t``, ``rtf``, ``wav_seconds``, ``t_vocoder``,
    ``t_no_vocoder`` and ``rtf_no_vocoder``; any missing key shows as 0.
    """
    return json.dumps(
        {
            "total_time": f"{metrics.get('t', 0):.3f}s",
            "rtf": f"{metrics.get('rtf', 0):.4f}",
            "audio_duration": f"{metrics.get('wav_seconds', 0):.2f}s",
            "vocoder_time": f"{metrics.get('t_vocoder', 0):.3f}s",
            "no_vocoder_time": f"{metrics.get('t_no_vocoder', 0):.3f}s",
            "rtf_no_vocoder": f"{metrics.get('rtf_no_vocoder', 0):.4f}",
        },
        indent=2,
    )


def _outputs(results, counter_text):
    """Build the 9-value tuple matching ``generate_btn.click``'s outputs.

    For each model in ``MODEL_TYPES`` order the tuple holds:
    - the audio path,
    - the status Markdown,
    - the metrics-header update,
    - the metrics-code update.

    Both metrics updates are visible only when the result has metrics.
    The generation-counter Markdown comes last.
    """
    values = []
    for model in MODEL_TYPES:
        result = results[model]
        show_metrics = result.get("has_metrics", False)
        values.extend(
            [
                result["audio"],
                result["status"],
                gr.update(visible=show_metrics),
                gr.update(value=result["metrics"], visible=show_metrics),
            ]
        )
    values.append(counter_text)
    return tuple(values)


# Create Gradio interface
with gr.Blocks(
    theme=gr.themes.Base(
        font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"]
    ),
    css=".gradio-container {max-width: none !important;}",
) as demo:
    # Title with Health Status
    with gr.Row():
        with gr.Column(scale=4):
            audio_image = gr.HTML(
                value="""
<div style="display: flex; align-items: center; gap: 10px;">
<img style="width: 50px; height: 50px; background-color: white; border-radius: 10%;" src="https://storage.googleapis.com/desivocal-prod/desi-vocal/ringg.svg" alt="Logo">
<h1 style="margin: 0;">Ringg Squirrel TTS v1.0 🐿️</h1>
</div>
"""
            )
        with gr.Column(scale=1):
            generation_counter = gr.Markdown(
                _counter_markdown(initial_counter),
                elem_id="counter",
            )

    # Best Practices Section
    gr.Markdown("""
## 📝 Best Practices for Best Results
- **Supported Languages:** Hindi and English only
- **Check spelling carefully:** Misspelled words may be mispronounced
- **Punctuation matters:** Use proper punctuation for natural pauses and intonation
- **Technical terms:** Extremely rare or specialized technical terms might be mispronounced
- **Numbers & dates:** Write numbers as words for better pronunciation (e.g., "twenty-five" instead of "25")
""")

    # Input Section - Text, Voice, and Character Count grouped together
    with gr.Group():
        # Text Input
        text_input = gr.Textbox(
            label=f"Text (max {MAX_CHARS} characters)",
            placeholder=f"Type or paste your text here (max {MAX_CHARS} characters)...",
            lines=6,
            max_lines=10,
            max_length=MAX_CHARS,
        )
        # Voice Selection: display name -> voice id. Duplicate display names
        # collapse to one entry, so count choices from the dict.
        voices = get_voices()
        voice_choices = {display: vid for display, vid in voices}
        voice_dropdown = gr.Dropdown(
            choices=list(voice_choices),
            label="Choose a voice style",
            info=f"{len(voice_choices)} voices available",
            value=next(iter(voice_choices), None),
            show_label=False,
        )
        # Character count display
        char_count = gr.Code(
            _char_count_text(0),
            show_line_numbers=False,
            show_label=False,
        )

    # Side-by-side comparison of Base and Distill models
    gr.Markdown("### 🎧 Audio Results Comparison")
    with gr.Row():
        with gr.Column(scale=1):
            # gr.Markdown("#### Base Model")
            audio_output_base = gr.Audio(label="Base Model Audio", type="filepath")
            status_base = gr.Markdown("", visible=True)
            metrics_header_base = gr.Markdown("**📊 Metrics**", visible=False)
            metrics_output_base = gr.Code(
                label="Base Metrics", language="json", interactive=False, visible=False
            )
        with gr.Column(scale=1):
            # gr.Markdown("#### Distill Model")
            audio_output_distill = gr.Audio(
                label="Distill Model Audio", type="filepath"
            )
            status_distill = gr.Markdown("", visible=True)
            metrics_header_distill = gr.Markdown("**📊 Metrics**", visible=False)
            metrics_output_distill = gr.Code(
                label="Distill Metrics",
                language="json",
                interactive=False,
                visible=False,
            )

    generate_btn = gr.Button("🎬 Generate Speech", variant="primary", size="lg")
    with gr.Row():
        example_btn1 = gr.Button("English Example", size="sm")
        example_btn2 = gr.Button("Hindi Example", size="sm")
        example_btn3 = gr.Button("Mixed Example", size="sm")

    # Footer
    gr.Markdown("---")
    gr.Markdown("# 🙏 Acknowledgements")
    # gr.Markdown("- Based on [ZipVoice](https://github.com/k2-fsa/ZipVoice)")
    gr.Markdown(
        "- Special thanks to [@jeremylee12](https://huggingface.co/jeremylee12) for his contributions"
    )

    # Event Handlers
    def update_char_count(text):
        """Return the character-count label for the current textbox value."""
        return _char_count_text(len(text) if text else 0)

    def load_example_text(example_text):
        """Return ``example_text`` together with its character-count label."""
        return example_text, _char_count_text(len(example_text))

    def clear_text():
        """Return an empty textbox value and a zero character count."""
        return "", _char_count_text(0)

    def on_generate(text, voice_display):
        """Generate speech with the base and distill models in parallel.

        Generator handler: every ``yield`` is a 9-tuple matching the
        ``outputs`` of ``generate_btn.click`` (built by ``_outputs``).

        Sequence of yields:
        1. A loading state.
        2. One update per result streamed from
           ``vertex_client.synthesize_parallel``.
        3. A final update, only if some model never reported a result
           (including when the client raised); those columns are marked
           failed instead of staying in the loading state.

        Args:
            text: Text to synthesize; at most ``MAX_CHARS`` characters.
            voice_display: Dropdown display name, mapped to a voice id via
                ``voice_choices``.
        """

        def _error_state(message):
            # Same message in both columns, no audio, metrics hidden.
            failed = {"audio": None, "status": message, "metrics": None}
            return _outputs(
                {model: failed for model in MODEL_TYPES},
                _counter_markdown(load_counter()),
            )

        # Validate inputs
        if not text or not text.strip():
            yield _error_state("⚠️ Please enter some text")
            return
        # The Textbox max_length is a browser-side limit only; API requests
        # can carry longer text, so enforce the limit on the server too.
        if len(text) > MAX_CHARS:
            yield _error_state(f"⚠️ Text exceeds {MAX_CHARS} characters")
            return
        voice_id = voice_choices.get(voice_display)
        if not voice_id:
            yield _error_state("⚠️ Please select a voice")
            return

        # Show loading state initially
        results = {
            model: {"audio": None, "status": PENDING_STATUS, "metrics": None}
            for model in MODEL_TYPES
        }
        yield _outputs(results, _counter_markdown(load_counter()))

        counter_incremented = False
        new_count = 0
        try:
            vertex_client = get_vertex_client()
            for (
                model_type,
                success,
                audio_bytes,
                metrics,
            ) in vertex_client.synthesize_parallel(text, voice_id):
                if success and audio_bytes:
                    # Save audio file in system temp directory
                    audio_file = os.path.join(
                        tempfile.gettempdir(),
                        f"ringg_{model_type}_{str(uuid.uuid4())}.wav",
                    )
                    with open(audio_file, "wb") as f:
                        f.write(audio_bytes)
                    # Count each request once, on its first successful result.
                    if not counter_incremented:
                        new_count = increment_counter()
                        counter_incremented = True
                    else:
                        new_count = load_counter()
                    results[model_type] = {
                        "audio": audio_file,
                        "status": "",
                        "metrics": _format_metrics(metrics) if metrics else "",
                        "has_metrics": bool(metrics),
                    }
                else:
                    results[model_type] = {
                        "audio": None,
                        "status": "❌ Failed to generate",
                        "metrics": "",
                        "has_metrics": False,
                    }
                yield _outputs(
                    results,
                    _counter_markdown(
                        new_count if counter_incremented else load_counter()
                    ),
                )
            failure_status = "❌ Failed to generate"
        except Exception as e:
            print(f"❌ Vertex AI parallel synthesis failed: {e}")
            failure_status = f"❌ Error: {e}"

        # Any model that never reported back would otherwise stay "Loading".
        pending = [m for m in MODEL_TYPES if results[m]["status"] == PENDING_STATUS]
        if pending:
            for model in pending:
                results[model] = {
                    "audio": None,
                    "status": failure_status,
                    "metrics": "",
                    "has_metrics": False,
                }
            yield _outputs(
                results,
                _counter_markdown(
                    new_count if counter_incremented else load_counter()
                ),
            )

    def refresh_counter_on_load():
        """Render the persisted generation counter when the page (re)loads."""
        return _counter_markdown(load_counter())

    # Update character count on text input change
    text_input.change(fn=update_char_count, inputs=[text_input], outputs=[char_count])

    # Example button clicks
    example_btn1.click(
        fn=lambda: load_example_text(EXAMPLE_TEXT_ENGLISH),
        inputs=None,
        outputs=[text_input, char_count],
    )
    example_btn2.click(
        fn=lambda: load_example_text(EXAMPLE_TEXT_HINDI),
        inputs=None,
        outputs=[text_input, char_count],
    )
    example_btn3.click(
        fn=lambda: load_example_text(EXAMPLE_TEXT_MIXED),
        inputs=None,
        outputs=[text_input, char_count],
    )

    generate_btn.click(
        fn=on_generate,
        inputs=[text_input, voice_dropdown],
        outputs=[
            audio_output_base,
            status_base,
            metrics_header_base,
            metrics_output_base,
            audio_output_distill,
            status_distill,
            metrics_header_distill,
            metrics_output_distill,
            generation_counter,
        ],
        concurrency_limit=2,
        concurrency_id="synthesis",
    )

    # Refresh global generation counter on page load/refresh
    demo.load(fn=refresh_counter_on_load, inputs=None, outputs=[generation_counter])
if __name__ == "__main__":
    # Enable request queuing: default concurrency limit of 2 per event and
    # at most 20 queued requests.
    demo.queue(default_concurrency_limit=2, max_size=20)
    launch_options = {
        "share": False,
        "server_name": "0.0.0.0",  # bind on all network interfaces
        "server_port": 7860,
        "debug": True,
    }
    demo.launch(**launch_options)