# ---------------------------------------------------------
# Nava Ultra-Fast CPU Inference (4-bit Quant + Caching)
# ---------------------------------------------------------
import gradio as gr
import torch
import soundfile as sf
from pathlib import Path
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
from peft import PeftModel
from snac import SNAC

# ---------------------------------------------------------
# CONFIG
# ---------------------------------------------------------
MODEL_NAME = "rahul7star/nava1.0"
LORA_NAME = "rahul7star/nava-audio"
SNAC_MODEL_NAME = "rahul7star/nava-snac"

# Generation token budget passed to model.generate(max_new_tokens=...).
# NOTE(review): 240000 new tokens is an enormous budget for CPU decoding;
# in practice generation stops at the speech EOS token (128258) long before
# this. Confirm the intended value — kept as-is for backward compatibility.
SEQ_LEN = 240000

# FIX: was 240000 (an apparent copy-paste of SEQ_LEN). SNAC vocoders decode
# 24 kHz audio; writing the WAV header at 240 kHz would make the output play
# back 10x too fast and at the wrong pitch.
TARGET_SR = 24000

OUT_ROOT = Path("/tmp/data")
OUT_ROOT.mkdir(exist_ok=True, parents=True)

DEFAULT_TEXT = (
    "राजनीतिज्ञों ने कहा कि उन्होंने निर्णायक मत को अनावश्यक रूप से "
    "निर्धारित करने के लिए अफ़गान संविधान में काफी अस्पष्टता पाई थी"
)

DEVICE = "cpu"

# Special-token ids used to frame the TTS prompt (model-specific vocabulary).
SOH_ID = 128259          # start of "human" turn
EOH_ID = 128260          # end of "human" turn
SOA_ID = 128261          # start of assistant turn
SOS_ID = 128257          # start of speech-token stream
EOT_ID = 128009          # end of text
EOS_ID = 128258          # end of speech-token stream
SNAC_MIN, SNAC_MAX = 128266, 156937  # vocab range holding SNAC codebook tokens
CODEBOOK_SIZE = 4096     # entries per SNAC codebook layer
FRAME_TOKENS = 7         # SNAC tokens per audio frame (1 coarse + 2 mid + 4 fine)
TRIM_SAMPLES = 2048      # leading samples dropped to remove start-up crackle

# ---------------------------------------------------------
# QUANT CONFIG (4-BIT)
# ---------------------------------------------------------
# NOTE(review): bitsandbytes 4-bit loading is GPU-oriented; confirm the
# installed bitsandbytes version actually supports load_in_4bit on CPU.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# ---------------------------------------------------------
# LOAD TOKENIZER (cached)
# ---------------------------------------------------------
print("🔄 Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True
)

# ---------------------------------------------------------
# LOAD BASE MODEL (4-bit CPU quant)
# ---------------------------------------------------------
print("🔄 Loading base model in 4-bit…")
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quant_config,
    device_map={"": DEVICE},
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
)

# ---------------------------------------------------------
# LOAD LORA (merged on top)
# ---------------------------------------------------------
print("🔄 Loading LoRA weights…")
model = PeftModel.from_pretrained(
    base_model,
    LORA_NAME,
    device_map={"": DEVICE}
).eval()

# ---------------------------------------------------------
# LOAD SNAC ONCE ONLY
# ---------------------------------------------------------
print("🔄 Loading SNAC…")
snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(DEVICE)


# =========================================================
# HELPERS
# =========================================================
def _build_prompt(text):
    """Wrap *text* in the model's chat/speech framing special tokens.

    Layout: SOH + BOS + text + EOT + EOH + SOA + SOS — the model is then
    expected to continue with SNAC speech tokens until EOS_ID.
    """
    soh = tokenizer.decode([SOH_ID])
    eoh = tokenizer.decode([EOH_ID])
    soa = tokenizer.decode([SOA_ID])
    sos = tokenizer.decode([SOS_ID])
    eot = tokenizer.decode([EOT_ID])
    return soh + tokenizer.bos_token + text + eot + eoh + soa + sos


def _redistribute_codes(snac_tokens):
    """Fold a flat SNAC token stream into the 3 hierarchical codebooks.

    Each FRAME_TOKENS-sized frame maps as: index 0 -> layer 1; indices
    1, 4 -> layer 2; indices 2, 3, 5, 6 -> layer 3. Token ids are shifted
    down by SNAC_MIN and wrapped into the per-layer codebook range.
    Returns three tensors shaped (1, n_frames * k) for k = 1, 2, 4.
    """
    l1, l2, l3 = [], [], []
    n_frames = len(snac_tokens) // FRAME_TOKENS  # drop any trailing partial frame
    for i in range(n_frames):
        frame = snac_tokens[i * FRAME_TOKENS:(i + 1) * FRAME_TOKENS]
        c = [(t - SNAC_MIN) % CODEBOOK_SIZE for t in frame]
        l1.append(c[0])
        l2.extend([c[1], c[4]])
        l3.extend([c[2], c[3], c[5], c[6]])
    return [
        torch.tensor(l1).unsqueeze(0),
        torch.tensor(l2).unsqueeze(0),
        torch.tensor(l3).unsqueeze(0),
    ]


# =========================================================
# INFERENCE FUNCTION
# =========================================================
def generate_audio_cpu_lora(text):
    """Synthesize speech for *text* with the 4-bit LoRA model.

    Returns (wav_path, wav_path, log_text) — the path twice so Gradio can
    feed both the Audio player and the File download component.
    """
    logs = ["⚡ Running fast 4-bit CPU inference…"]

    inputs = tokenizer(_build_prompt(text), return_tensors="pt").to(DEVICE)

    # -----------------------------------------------------
    # GENERATE SNAC TOKENS (FAST 4-bit)
    # -----------------------------------------------------
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=SEQ_LEN,
            temperature=0.4,
            top_p=0.9,
            repetition_penalty=1.1,
            do_sample=True,
            eos_token_id=EOS_ID,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Strip the prompt; keep only the newly generated ids.
    gen_ids = outputs[0, inputs["input_ids"].shape[1]:].tolist()

    # Keep only SNAC-range tokens emitted before the speech EOS marker.
    eos_idx = gen_ids.index(EOS_ID) if EOS_ID in gen_ids else len(gen_ids)
    snac_tokens = [t for t in gen_ids[:eos_idx] if SNAC_MIN <= t <= SNAC_MAX]

    # -----------------------------------------------------
    # DECODE SNAC → AUDIO
    # -----------------------------------------------------
    codes = _redistribute_codes(snac_tokens)
    with torch.inference_mode():
        z = snac_model.quantizer.from_codes(codes)
        audio = snac_model.decoder(z)[0, 0].cpu().numpy()

    # Drop the first samples to remove start-up crackle.
    if len(audio) > TRIM_SAMPLES:
        audio = audio[TRIM_SAMPLES:]

    # Save WAV at the vocoder's native sample rate.
    out = OUT_ROOT / "tts_output_cpu_lora.wav"
    sf.write(out, audio, TARGET_SR)

    logs.append("🎧 Audio generated successfully")
    return str(out), str(out), "\n".join(logs)


# =========================================================
# GRADIO UI
# =========================================================
with gr.Blocks() as demo:
    # FIX: header previously said "Maya TTS", contradicting the Nava
    # branding used by the file header and every model id above.
    gr.Markdown("## ⚡ Nava TTS — Ultra-Fast 4-bit CPU Inference")
    txt = gr.Textbox(label="Enter text", value=DEFAULT_TEXT)
    btn = gr.Button("Generate Audio")
    audio = gr.Audio(label="Audio", type="filepath")
    file = gr.File(label="Download")
    logs = gr.Textbox(label="Logs")
    btn.click(generate_audio_cpu_lora, [txt], [audio, file, logs])

if __name__ == "__main__":
    demo.launch()