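"""Gradio demo app for Ringg Squirrel TTS: text-to-speech served by a Vertex AI backend."""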
import gradio as gr
import json
import os
from pathlib import Path
import uuid
import fcntl
import time
import tempfile

from vertex_client import get_vertex_client
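
# NOTE: fcntl is POSIX-only; the advisory file locking used for the counter below
# works on Linux/macOS hosts (e.g. Hugging Face Spaces) but not on Windows.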
# gr.NO_RELOAD = False

# Counter persistence file
COUNTER_FILE = Path("generation_counter.json")

# Example texts
EXAMPLE_TEXT_ENGLISH = "Welcome to Ringg TTS! This is a text to speech system that can convert your text into natural-sounding audio. Try it out with your own content!"
EXAMPLE_TEXT_HINDI = "नमस्ते! मैं रिंग टीटीएस हूँ। मैं आपके टेक्स्ट को प्राकृतिक आवाज़ में बदल सकता हूँ। कृपया अपना टेक्स्ट यहाँ लिखें और सुनें।"
EXAMPLE_TEXT_MIXED = "Hello दोस्तों! Welcome to Ringg TTS. यह एक बहुत ही शानदार text to speech system है जो Hindi और English दोनों languages को support करता है।"
def load_counter():
    """Load universal generation counter from file (thread-safe)"""
    try:
        if COUNTER_FILE.exists():
            with open(COUNTER_FILE, "r") as f:
                # Try to acquire shared lock for reading
                try:
                    fcntl.flock(f.fileno(), fcntl.LOCK_SH)
                    data = json.load(f)
                    fcntl.flock(f.fileno(), fcntl.LOCK_UN)
                    return data.get("count", 0)
                except Exception:
                    # If locking fails, just read without lock
                    f.seek(0)
                    data = json.load(f)
                    return data.get("count", 0)
    except Exception as e:
        print(f"Error loading counter: {e}")
    return 0
def save_counter(count):
    """Save universal generation counter to file (thread-safe)"""
    try:
        # Use file locking to prevent race conditions with multiple users
        with open(COUNTER_FILE, "w") as f:
            try:
                fcntl.flock(f.fileno(), fcntl.LOCK_EX)
                json.dump({"count": count, "last_updated": time.time()}, f)
                f.flush()
                os.fsync(f.fileno())
                fcntl.flock(f.fileno(), fcntl.LOCK_UN)
            except Exception:
                # If locking fails, just write without lock
                json.dump({"count": count, "last_updated": time.time()}, f)
                f.flush()
    except Exception as e:
        print(f"Error saving counter: {e}")
def increment_counter():
    """Atomically increment and return the new counter value"""
    try:
        # Read current value, increment, and save atomically
        with open(COUNTER_FILE, "r+" if COUNTER_FILE.exists() else "w+") as f:
            try:
                fcntl.flock(f.fileno(), fcntl.LOCK_EX)
                # Read current count
                f.seek(0)
                try:
                    data = json.load(f)
                    current_count = data.get("count", 0)
                except Exception:
                    current_count = 0
                # Increment
                new_count = current_count + 1
                # Write back
                f.seek(0)
                f.truncate()
                json.dump({"count": new_count, "last_updated": time.time()}, f)
                f.flush()
                os.fsync(f.fileno())
                fcntl.flock(f.fileno(), fcntl.LOCK_UN)
                return new_count
            except Exception:
                # Fallback without locking
                f.seek(0)
                try:
                    data = json.load(f)
                    current_count = data.get("count", 0)
                except Exception:
                    current_count = 0
                new_count = current_count + 1
                f.seek(0)
                f.truncate()
                json.dump({"count": new_count, "last_updated": time.time()}, f)
                f.flush()
                return new_count
    except Exception as e:
        print(f"Error incrementing counter: {e}")
        return 0
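
# A quick, optional sanity check for the counter helpers (assumes the working
# directory is writable); not part of the app flow:
#     load_counter()       # -> 0 on a fresh checkout
#     increment_counter()  # -> 1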
def get_voices():
    """Fetch available voices from Vertex AI"""
    try:
        vertex_client = get_vertex_client()
        success, voices_response = vertex_client.get_voices()
        if success and voices_response:
            print("✅ Fetched voices from Vertex AI")
            voices_data = voices_response.get("voices", {})
            # Create a list of tuples (display_name, voice_id)
            voices = []
            for voice_id, voice_info in voices_data.items():
                name = voice_info.get("name", "Unknown")
                gender = voice_info.get("gender", "N/A")
                display_name = f"{name} ({gender})"
                voices.append((display_name, voice_id))
            return sorted(voices, key=lambda x: x[0])
        else:
            print("❌ Failed to fetch voices from Vertex AI")
            return []
    except Exception as e:
        print(f"❌ Error fetching voices from Vertex AI: {e}")
        return []
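
# For reference, get_voices() above expects a voices payload shaped roughly like
# the following (illustrative only, inferred from the keys the code reads):
#     {"voices": {"<voice_id>": {"name": "<display name>", "gender": "<gender>"}}}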
def synthesize_speech(text, voice_id):
    """Synthesize speech from text using Vertex AI"""
    if not text or not text.strip():
        return None, "⚠️ Please enter some text", "", "", "", "", "", ""
    if not voice_id:
        return None, "⚠️ Please select a voice", "", "", "", "", "", ""

    # Print input text length
    text_length = len(text)
    print(f"Input text length: {text_length} characters")

    try:
        vertex_client = get_vertex_client()
        success, audio_bytes, metrics = vertex_client.synthesize(
            text, voice_id, timeout=60
        )
        if success and audio_bytes:
            print("✅ Synthesized audio using Vertex AI")
            # Save binary audio to a temp file in the system temp directory
            temp_dir = tempfile.gettempdir()
            audio_file = os.path.join(temp_dir, f"ringg_{str(uuid.uuid4())}.wav")
            with open(audio_file, "wb") as f:
                f.write(audio_bytes)

            # Format metrics if available
            if metrics:
                total_time = f"{metrics.get('t', 0):.3f}s"
                rtf = f"{metrics.get('rtf', 0):.4f}"
                wav_duration = f"{metrics.get('wav_seconds', 0):.2f}s"
                vocoder_time = f"{metrics.get('t_vocoder', 0):.3f}s"
                no_vocoder_time = f"{metrics.get('t_no_vocoder', 0):.3f}s"
                rtf_no_vocoder = f"{metrics.get('rtf_no_vocoder', 0):.4f}"
            else:
                total_time = rtf = wav_duration = vocoder_time = no_vocoder_time = rtf_no_vocoder = ""

            status_msg = ""
            return (
                audio_file,
                status_msg,
                total_time,
                rtf,
                wav_duration,
                vocoder_time,
                no_vocoder_time,
                rtf_no_vocoder,
            )
        else:
            return None, "❌ Failed to generate audio", "", "", "", "", "", ""
    except Exception as e:
        print(f"❌ Vertex AI synthesis failed: {e}")
        return None, f"❌ Error: {str(e)}", "", "", "", "", "", ""
# Load initial counter value
initial_counter = load_counter()

# Create Gradio interface
with gr.Blocks(
    theme=gr.themes.Base(
        font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"]
    ),
    css=".gradio-container {max-width: none !important;}",
) as demo:
    # Title row with logo and generation counter
    with gr.Row():
        with gr.Column(scale=4):
            audio_image = gr.HTML(
                value="""
                <div style="display: flex; align-items: center; gap: 10px;">
                    <img style="width: 50px; height: 50px; background-color: white; border-radius: 10%;" src="https://storage.googleapis.com/desivocal-prod/desi-vocal/ringg.svg" alt="Logo">
                    <h1 style="margin: 0;">Ringg Squirrel TTS v1.0 🐿️</h1>
                </div>
                """
            )
        with gr.Column(scale=1):
            generation_counter = gr.Markdown(
                f"**🌍 Generations since last commit:** {initial_counter}",
                elem_id="counter",
            )

    # Best Practices Section
    gr.Markdown("""
    ## 📝 Best Practices for Best Results
    - **Supported Languages:** Hindi and English only
    - **Check spelling carefully:** Misspelled words may be mispronounced
    - **Punctuation matters:** Use proper punctuation for natural pauses and intonation
    - **Technical terms:** Extremely rare or specialized technical terms might be mispronounced
    - **Numbers & dates:** Write numbers as words for better pronunciation (e.g., "twenty-five" instead of "25")
    """)
    # Input Section - Text, Voice, and Character Count grouped together
    with gr.Group():
        # Text Input
        text_input = gr.Textbox(
            label="Text (max 300 characters)",
            placeholder="Type or paste your text here (max 300 characters)...",
            lines=6,
            max_lines=10,
            max_length=300,
        )

        # Voice Selection
        voices = get_voices()
        voice_choices = {display: vid for display, vid in voices}
        voice_dropdown = gr.Dropdown(
            choices=list(voice_choices.keys()),
            label="Choose a voice style",
            info=f"{len(voices)} voices available",
            value=list(voice_choices.keys())[0] if voices else None,
            show_label=False,
        )

        # Character count display
        char_count = gr.Code(
            "Character count: 0 / 300",
            show_line_numbers=False,
            show_label=False,
        )
    # Audio output section
    gr.Markdown("### 🎧 Audio Result")
    audio_output = gr.Audio(label="Generated Audio", type="filepath")
    status = gr.Markdown("", visible=True)
    metrics_header = gr.Markdown("**📊 Metrics**", visible=False)
    metrics_output = gr.Code(
        label="Performance Metrics",
        language="json",
        interactive=False,
        visible=False,
    )

    generate_btn = gr.Button("🎬 Generate Speech", variant="primary", size="lg")

    with gr.Row():
        example_btn1 = gr.Button("English Example", size="sm")
        example_btn2 = gr.Button("Hindi Example", size="sm")
        example_btn3 = gr.Button("Mixed Example", size="sm")
    # Footer
    gr.Markdown("---")
    gr.Markdown("# 🙏 Acknowledgements")
    # gr.Markdown("- Based on [ZipVoice](https://github.com/k2-fsa/ZipVoice)")
    gr.Markdown(
        "- Special thanks to [@jeremylee12](https://huggingface.co/jeremylee12) for his contributions"
    )
    # Event Handlers
    def update_char_count(text):
        """Update character count as user types"""
        count = len(text) if text else 0
        return f"Character count: {count} / 300"

    def load_example_text(example_text):
        """Load example text and update character count"""
        count = len(example_text)
        return example_text, f"Character count: {count} / 300"

    def clear_text():
        """Clear text input"""
        return "", "Character count: 0 / 300"
    def on_generate(text, voice_display):
        """Generate speech using the distill model."""
        # Validate inputs
        if not text or not text.strip():
            error_msg = "⚠️ Please enter some text"
            yield (
                None,
                error_msg,
                gr.update(visible=False),
                gr.update(visible=False),
                f"**🌍 Generations:** {load_counter()}",
            )
            return

        voice_id = voice_choices.get(voice_display)
        if not voice_id:
            error_msg = "⚠️ Please select a voice"
            yield (
                None,
                error_msg,
                gr.update(visible=False),
                gr.update(visible=False),
                f"**🌍 Generations:** {load_counter()}",
            )
            return

        # Show loading state initially
        yield (
            None,
            "⏳ Loading...",
            gr.update(visible=False),
            gr.update(visible=False),
            f"**🌍 Generations:** {load_counter()}",
        )

        # Synthesize speech
        vertex_client = get_vertex_client()
        success, audio_bytes, metrics = vertex_client.synthesize(text, voice_id)

        if success and audio_bytes:
            # Save audio file in system temp directory
            temp_dir = tempfile.gettempdir()
            audio_file = os.path.join(temp_dir, f"ringg_{str(uuid.uuid4())}.wav")
            with open(audio_file, "wb") as f:
                f.write(audio_bytes)

            # Increment counter
            new_count = increment_counter()

            # Format metrics
            metrics_json = ""
            has_metrics = False
            if metrics:
                has_metrics = True
                metrics_json = json.dumps(
                    {
                        "total_time": f"{metrics.get('t', 0):.3f}s",
                        "rtf": f"{metrics.get('rtf', 0):.4f}",
                        "audio_duration": f"{metrics.get('wav_seconds', 0):.2f}s",
                        "vocoder_time": f"{metrics.get('t_vocoder', 0):.3f}s",
                        "no_vocoder_time": f"{metrics.get('t_no_vocoder', 0):.3f}s",
                        "rtf_no_vocoder": f"{metrics.get('rtf_no_vocoder', 0):.4f}",
                    },
                    indent=2,
                )

            # Yield success result
            yield (
                audio_file,
                "",
                gr.update(visible=has_metrics),
                gr.update(value=metrics_json, visible=has_metrics),
                f"**🌍 Generations:** {new_count}",
            )
        else:
            # Yield failure result
            yield (
                None,
                "❌ Failed to generate",
                gr.update(visible=False),
                gr.update(visible=False),
                f"**🌍 Generations:** {load_counter()}",
            )
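
    # on_generate is a generator, so each yield pushes an intermediate UI state
    # (the "⏳ Loading..." placeholder) before the final audio/metrics update.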
    def refresh_counter_on_load():
        """Refresh the universal generation counter when the UI loads/reloads"""
        return f"**🌍 Generations since last reload:** {load_counter()}"
    # Update character count on text input change
    text_input.change(fn=update_char_count, inputs=[text_input], outputs=[char_count])

    # Example button clicks
    example_btn1.click(
        fn=lambda: load_example_text(EXAMPLE_TEXT_ENGLISH),
        inputs=None,
        outputs=[text_input, char_count],
    )
    example_btn2.click(
        fn=lambda: load_example_text(EXAMPLE_TEXT_HINDI),
        inputs=None,
        outputs=[text_input, char_count],
    )
    example_btn3.click(
        fn=lambda: load_example_text(EXAMPLE_TEXT_MIXED),
        inputs=None,
        outputs=[text_input, char_count],
    )

    generate_btn.click(
        fn=on_generate,
        inputs=[text_input, voice_dropdown],
        outputs=[
            audio_output,
            status,
            metrics_header,
            metrics_output,
            generation_counter,
        ],
        concurrency_limit=2,
        concurrency_id="synthesis",
    )

    # Refresh global generation counter on page load/refresh
    demo.load(fn=refresh_counter_on_load, inputs=None, outputs=[generation_counter])
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=2, max_size=20)
    demo.launch(share=False, server_name="0.0.0.0", server_port=7860, debug=True)