# Ringg-TTS-v1.0 / app.py
# Latest commit 4806882 by utkarshshukla2912: "remove base inference" (file size 15.7 kB)
import gradio as gr
import json
import os
from pathlib import Path
import uuid
import fcntl
import time
import tempfile
from vertex_client import get_vertex_client
# gr.NO_RELOAD = False
# NOTE(review): commented-out Gradio toggle kept from the original; delete if unused.

# JSON file (resolved against the current working directory) that stores the
# shared generation count.
COUNTER_FILE = Path("generation_counter.json")

# Sample inputs loaded by the three example buttons: English, Hindi, and
# code-mixed Hindi/English.
EXAMPLE_TEXT_ENGLISH = (
    "Welcome to Ringg TTS! This is a text to speech system that can convert "
    "your text into natural-sounding audio. Try it out with your own content!"
)
EXAMPLE_TEXT_HINDI = "नमस्ते! मैं रिंग टीटीएस हूँ। मैं आपके टेक्स्ट को प्राकृतिक आवाज़ में बदल सकता हूँ। कृपया अपना टेक्स्ट यहाँ लिखें और सुनें।"
EXAMPLE_TEXT_MIXED = "Hello दोस्तों! Welcome to Ringg TTS. यह एक बहुत ही शानदार text to speech system है जो Hindi और English दोनों languages को support करता है।"
def load_counter():
    """Return the persisted universal generation count.

    Reads ``COUNTER_FILE`` while holding a shared ``fcntl.flock`` so that a
    writer holding the exclusive lock is not observed mid-write. If the lock
    cannot be taken, the file is read without it (best effort).

    Returns:
        The stored ``"count"`` value, or 0 when the file is missing,
        unreadable, malformed, or not a JSON object.
    """
    try:
        if not COUNTER_FILE.exists():
            return 0
        with open(COUNTER_FILE, "r") as f:
            try:
                fcntl.flock(f.fileno(), fcntl.LOCK_SH)
                locked = True
            except OSError:
                # Locking unavailable: fall back to an unlocked read.
                locked = False
            try:
                # Parse exactly once. A corrupt file is reported below rather
                # than being re-parsed through the "no lock" fallback.
                data = json.load(f)
            finally:
                if locked:
                    fcntl.flock(f.fileno(), fcntl.LOCK_UN)
        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")
        return data.get("count", 0)
    except Exception as e:
        print(f"Error loading counter: {e}")
        return 0
def save_counter(count):
    """Overwrite the persisted generation counter with *count*.

    The file is opened (created if missing) WITHOUT truncation. It is emptied
    and rewritten only while the exclusive ``fcntl.flock`` is held, so readers
    holding the shared lock never see a half-written or empty file. If locking
    is unavailable, the write still happens, just without that guarantee.
    Errors are printed and swallowed (best effort).

    Args:
        count: Value stored under ``"count"``, alongside a ``"last_updated"``
            ``time.time()`` timestamp.
    """
    try:
        # open(..., "w") would truncate before the lock is acquired.
        fd = os.open(COUNTER_FILE, os.O_RDWR | os.O_CREAT, 0o666)
        with os.fdopen(fd, "r+") as f:
            try:
                fcntl.flock(f.fileno(), fcntl.LOCK_EX)
                locked = True
            except OSError:
                locked = False  # best effort: write without the lock
            try:
                # Written exactly once. Previously, a failure after a
                # successful locked dump fell into a fallback that appended a
                # second JSON object and corrupted the file.
                f.seek(0)
                f.truncate()
                json.dump({"count": count, "last_updated": time.time()}, f)
                f.flush()
                os.fsync(f.fileno())
            finally:
                if locked:
                    fcntl.flock(f.fileno(), fcntl.LOCK_UN)
    except Exception as e:
        print(f"Error saving counter: {e}")
def increment_counter():
    """Add one to the persisted generation counter and return the new value.

    The file is opened (created if missing) without truncation. The counter
    is then read, incremented, and rewritten while holding an exclusive
    ``fcntl.flock``, so concurrent increments serialize instead of losing
    updates. If locking is unavailable, the update still happens, just
    without that guarantee. An empty, malformed, or non-object file is
    treated as a count of 0.

    Returns:
        The new count, or 0 if the file could not be read or written.
    """
    try:
        # O_CREAT without O_TRUNC. Checking exists() and then opening with
        # "w+" races with other processes and can wipe a file they just wrote.
        fd = os.open(COUNTER_FILE, os.O_RDWR | os.O_CREAT, 0o666)
        with os.fdopen(fd, "r+") as f:
            try:
                fcntl.flock(f.fileno(), fcntl.LOCK_EX)
                locked = True
            except OSError:
                locked = False  # best effort: update without the lock
            try:
                # Single read-modify-write. A failure here is reported rather
                # than retried, so the counter is never incremented twice or
                # reset from a half-written file.
                f.seek(0)
                try:
                    current_count = json.load(f).get("count", 0)
                except (ValueError, AttributeError):
                    # Empty/corrupt JSON (ValueError) or a non-object (AttributeError).
                    current_count = 0
                new_count = current_count + 1
                f.seek(0)
                f.truncate()
                json.dump({"count": new_count, "last_updated": time.time()}, f)
                f.flush()
                os.fsync(f.fileno())
            finally:
                if locked:
                    fcntl.flock(f.fileno(), fcntl.LOCK_UN)
        return new_count
    except Exception as e:
        print(f"Error incrementing counter: {e}")
        return 0
def get_voices():
    """List the voices offered by the Vertex AI client.

    Returns:
        ``(display_name, voice_id)`` pairs sorted by display name, where the
        display name is ``"<name> (<gender>)"`` ("Unknown"/"N/A" fill in
        missing fields). Returns an empty list when the fetch reports failure,
        returns an empty response, or raises.
    """
    try:
        ok, response = get_vertex_client().get_voices()
        if not (ok and response):
            print("❌ Failed to fetch voices from Vertex AI")
            return []

        print("✅ Fetched voices from Vertex AI")
        catalogue = response.get("voices", {})
        pairs = [
            (f"{info.get('name', 'Unknown')} ({info.get('gender', 'N/A')})", vid)
            for vid, info in catalogue.items()
        ]
        pairs.sort(key=lambda pair: pair[0])
        return pairs
    except Exception as e:
        print(f"❌ Error fetching voices from Vertex AI: {e}")
        return []
def synthesize_speech(text, voice_id):
    """Synthesize *text* with the Vertex AI voice *voice_id*.

    Returns:
        An 8-tuple ``(audio_path, status, total_time, rtf, wav_duration,
        vocoder_time, no_vocoder_time, rtf_no_vocoder)``. On success,
        ``audio_path`` is a new ``ringg_<uuid>.wav`` file in the system temp
        directory and ``status`` is ``""``. The six metric strings are ``""``
        when the client returned no metrics. On any failure, ``audio_path`` is
        None, ``status`` holds the message, and every metric field is ``""``.
    """
    no_metrics = ("",) * 6
    if not (text and text.strip()):
        return (None, "⚠️ Please enter some text") + no_metrics
    if not voice_id:
        return (None, "⚠️ Please select a voice") + no_metrics

    print(f"Input text length: {len(text)} characters")

    # (metrics key, format spec, unit suffix), in output-tuple order.
    metric_layout = (
        ("t", ".3f", "s"),
        ("rtf", ".4f", ""),
        ("wav_seconds", ".2f", "s"),
        ("t_vocoder", ".3f", "s"),
        ("t_no_vocoder", ".3f", "s"),
        ("rtf_no_vocoder", ".4f", ""),
    )
    try:
        client = get_vertex_client()
        ok, audio_bytes, metrics = client.synthesize(text, voice_id, timeout=60)
        if not (ok and audio_bytes):
            return (None, "❌ Failed to generate audio") + no_metrics

        print("✅ Synthesized audio using Vertex AI")
        audio_file = os.path.join(tempfile.gettempdir(), f"ringg_{uuid.uuid4()}.wav")
        with open(audio_file, "wb") as out:
            out.write(audio_bytes)

        if metrics:
            formatted = tuple(
                format(metrics.get(key, 0), spec) + suffix
                for key, spec, suffix in metric_layout
            )
        else:
            formatted = no_metrics
        return (audio_file, "") + formatted
    except Exception as e:
        print(f"❌ Vertex AI synthesis failed: {e}")
        return (None, f"❌ Error: {str(e)}") + no_metrics
# Load initial counter value
initial_counter = load_counter()

# Maximum characters per request. The Textbox ``max_length`` below is a
# client-side limit that direct API calls may bypass, so on_generate re-checks it.
MAX_TEXT_CHARS = 300


def _counter_markdown(count):
    """Render the generation-counter label (one wording for every update path)."""
    return f"**🌍 Generations since last commit:** {count}"


def _char_count_label(count):
    """Render the character-count line shown under the text box."""
    return f"Character count: {count} / {MAX_TEXT_CHARS}"


def _unique_voice_choices(voices):
    """Map dropdown label -> voice id, keeping every voice selectable.

    ``get_voices`` labels voices as "name (gender)", which need not be unique.
    A repeated label gets its voice id appended instead of silently replacing
    the earlier voice in the dict.
    """
    choices = {}
    for display, vid in voices:
        label = display if display not in choices else f"{display} [{vid}]"
        choices[label] = vid
    return choices


def _status_only(message):
    """Outputs for generate_btn when no audio is produced: just a status message."""
    return (
        None,
        message,
        gr.update(visible=False),
        gr.update(visible=False),
        _counter_markdown(load_counter()),
    )


def _format_metrics(metrics):
    """Pretty-print synthesis metrics as JSON, or return "" when there are none."""
    if not metrics:
        return ""
    return json.dumps(
        {
            "total_time": f"{metrics.get('t', 0):.3f}s",
            "rtf": f"{metrics.get('rtf', 0):.4f}",
            "audio_duration": f"{metrics.get('wav_seconds', 0):.2f}s",
            "vocoder_time": f"{metrics.get('t_vocoder', 0):.3f}s",
            "no_vocoder_time": f"{metrics.get('t_no_vocoder', 0):.3f}s",
            "rtf_no_vocoder": f"{metrics.get('rtf_no_vocoder', 0):.4f}",
        },
        indent=2,
    )


# Create Gradio interface
with gr.Blocks(
    theme=gr.themes.Base(
        font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"]
    ),
    css=".gradio-container {max-width: none !important;}",
) as demo:
    # Title and generation counter
    with gr.Row():
        with gr.Column(scale=4):
            audio_image = gr.HTML(
                value="""
<div style="display: flex; align-items: center; gap: 10px;">
<img style="width: 50px; height: 50px; background-color: white; border-radius: 10%;" src="https://storage.googleapis.com/desivocal-prod/desi-vocal/ringg.svg" alt="Logo">
<h1 style="margin: 0;">Ringg Squirrel TTS v1.0 🐿️</h1>
</div>
"""
            )
        with gr.Column(scale=1):
            generation_counter = gr.Markdown(
                _counter_markdown(initial_counter),
                elem_id="counter",
            )

    # Best Practices Section
    gr.Markdown("""
## 📝 Best Practices for Best Results
- **Supported Languages:** Hindi and English only
- **Check spelling carefully:** Misspelled words may be mispronounced
- **Punctuation matters:** Use proper punctuation for natural pauses and intonation
- **Technical terms:** Extremely rare or specialized technical terms might be mispronounced
- **Numbers & dates:** Write numbers as words for better pronunciation (e.g., "twenty-five" instead of "25")
""")

    # Input Section - Text, Voice, and Character Count grouped together
    with gr.Group():
        # Text Input
        text_input = gr.Textbox(
            label=f"Text (max {MAX_TEXT_CHARS} characters)",
            placeholder=f"Type or paste your text here (max {MAX_TEXT_CHARS} characters)...",
            lines=6,
            max_lines=10,
            max_length=MAX_TEXT_CHARS,
        )
        # Voice Selection
        voices = get_voices()
        voice_choices = _unique_voice_choices(voices)
        voice_dropdown = gr.Dropdown(
            choices=list(voice_choices.keys()),
            label="Choose a voice style",
            info=f"{len(voices)} voices available",
            value=next(iter(voice_choices), None),
            show_label=False,
        )
        # Character count display
        char_count = gr.Code(
            _char_count_label(0),
            show_line_numbers=False,
            show_label=False,
        )

    # Audio output section
    gr.Markdown("### 🎧 Audio Result")
    audio_output = gr.Audio(label="Generated Audio", type="filepath")
    status = gr.Markdown("", visible=True)
    metrics_header = gr.Markdown("**📊 Metrics**", visible=False)
    metrics_output = gr.Code(
        label="Performance Metrics",
        language="json",
        interactive=False,
        visible=False,
    )
    generate_btn = gr.Button("🎬 Generate Speech", variant="primary", size="lg")
    with gr.Row():
        example_btn1 = gr.Button("English Example", size="sm")
        example_btn2 = gr.Button("Hindi Example", size="sm")
        example_btn3 = gr.Button("Mixed Example", size="sm")

    # Footer
    gr.Markdown("---")
    gr.Markdown("# 🙏 Acknowledgements")
    # gr.Markdown("- Based on [ZipVoice](https://github.com/k2-fsa/ZipVoice)")
    gr.Markdown(
        "- Special thanks to [@jeremylee12](https://huggingface.co/jeremylee12) for his contributions"
    )

    # Event Handlers
    def update_char_count(text):
        """Return the character-count label for the current textbox value."""
        return _char_count_label(len(text) if text else 0)

    def load_example_text(example_text):
        """Return *example_text* together with its character-count label."""
        return example_text, _char_count_label(len(example_text))

    def clear_text():
        """Return an empty textbox value and a zero character count."""
        # NOTE(review): not wired to any event in this file.
        return "", _char_count_label(0)

    def on_generate(text, voice_display):
        """Synthesize *text* in the voice labelled *voice_display*.

        A generator: after validation it yields a loading state, then the
        final result. Every yield is a 5-tuple matching the generate_btn
        outputs: (audio path or None, status markdown, metrics-header update,
        metrics-code update, counter markdown).
        """
        if not text or not text.strip():
            yield _status_only("⚠️ Please enter some text")
            return
        if len(text) > MAX_TEXT_CHARS:
            yield _status_only(
                f"⚠️ Text is too long ({len(text)} / {MAX_TEXT_CHARS} characters)"
            )
            return
        voice_id = voice_choices.get(voice_display)
        if not voice_id:
            yield _status_only("⚠️ Please select a voice")
            return

        # Show loading state initially
        yield _status_only("⏳ Loading...")

        audio_file = None
        metrics_json = ""
        try:
            vertex_client = get_vertex_client()
            success, audio_bytes, metrics = vertex_client.synthesize(text, voice_id)
            if success and audio_bytes:
                audio_file = os.path.join(
                    tempfile.gettempdir(), f"ringg_{uuid.uuid4()}.wav"
                )
                with open(audio_file, "wb") as f:
                    f.write(audio_bytes)
                metrics_json = _format_metrics(metrics)
        except Exception as e:
            # Without this, the exception escapes the generator and the status
            # keeps showing the "⏳ Loading..." message yielded above.
            print(f"❌ Vertex AI synthesis failed: {e}")
            yield _status_only(f"❌ Error: {e}")
            return

        if audio_file is None:
            yield _status_only("❌ Failed to generate")
            return

        new_count = increment_counter()
        has_metrics = bool(metrics_json)
        yield (
            audio_file,
            "",
            gr.update(visible=has_metrics),
            gr.update(value=metrics_json, visible=has_metrics),
            _counter_markdown(new_count),
        )

    def refresh_counter_on_load():
        """Re-read the shared generation counter when a page (re)loads the UI."""
        return _counter_markdown(load_counter())

    # Update character count on text input change
    text_input.change(fn=update_char_count, inputs=[text_input], outputs=[char_count])

    # Example button clicks
    example_btn1.click(
        fn=lambda: load_example_text(EXAMPLE_TEXT_ENGLISH),
        inputs=None,
        outputs=[text_input, char_count],
    )
    example_btn2.click(
        fn=lambda: load_example_text(EXAMPLE_TEXT_HINDI),
        inputs=None,
        outputs=[text_input, char_count],
    )
    example_btn3.click(
        fn=lambda: load_example_text(EXAMPLE_TEXT_MIXED),
        inputs=None,
        outputs=[text_input, char_count],
    )

    generate_btn.click(
        fn=on_generate,
        inputs=[text_input, voice_dropdown],
        outputs=[
            audio_output,
            status,
            metrics_header,
            metrics_output,
            generation_counter,
        ],
        concurrency_limit=2,
        concurrency_id="synthesis",
    )

    # Refresh global generation counter on page load/refresh
    demo.load(fn=refresh_counter_on_load, inputs=None, outputs=[generation_counter])
if __name__ == "__main__":
    # Queue requests: default concurrency limit of 2 per event, at most 20 waiting.
    demo.queue(max_size=20, default_concurrency_limit=2)
    demo.launch(
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
        share=False,
        debug=True,
    )