rahul7star commited on
Commit
166e332
·
verified ·
1 Parent(s): c18c42d

Create app_quant.py

Browse files
Files changed (1) hide show
  1. app_quant.py +182 -0
app_quant.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ---------------------------------------------------------
2
+ # Nava Ultra-Fast CPU Inference (4-bit Quant + Caching)
3
+ # ---------------------------------------------------------
4
+ import gradio as gr
5
+ import torch
6
+ import soundfile as sf
7
+ from pathlib import Path
8
+
9
+ from transformers import (
10
+ AutoTokenizer,
11
+ AutoModelForCausalLM,
12
+ BitsAndBytesConfig
13
+ )
14
+ from peft import PeftModel
15
+ from snac import SNAC
16
+
17
+ # ---------------------------------------------------------
18
+ # CONFIG
19
+ # ---------------------------------------------------------
20
+ MODEL_NAME = "rahul7star/nava1.0"
21
+ LORA_NAME = "rahul7star/nava-audio"
22
+ SNAC_MODEL_NAME = "rahul7star/nava-snac"
23
+
24
+ SEQ_LEN = 240000
25
+ TARGET_SR = 240000
26
+ OUT_ROOT = Path("/tmp/data")
27
+ OUT_ROOT.mkdir(exist_ok=True, parents=True)
28
+
29
+ DEFAULT_TEXT = (
30
+ "राजनीतिज्ञों ने कहा कि उन्होंने निर्णायक मत को अनावश्यक रूप से "
31
+ "निर्धारित करने के लिए अफ़गान संविधान में काफी अस्पष्टता पाई थी"
32
+ )
33
+
34
+ DEVICE = "cpu"
35
+
36
+ # ---------------------------------------------------------
37
+ # QUANT CONFIG (4-BIT)
38
+ # ---------------------------------------------------------
39
+ quant_config = BitsAndBytesConfig(
40
+ load_in_4bit=True,
41
+ bnb_4bit_quant_type="nf4",
42
+ bnb_4bit_use_double_quant=True,
43
+ bnb_4bit_compute_dtype=torch.bfloat16,
44
+ )
45
+
46
+ # ---------------------------------------------------------
47
+ # LOAD TOKENIZER (cached)
48
+ # ---------------------------------------------------------
49
+ print("🔄 Loading tokenizer...")
50
+ tokenizer = AutoTokenizer.from_pretrained(
51
+ MODEL_NAME,
52
+ trust_remote_code=True
53
+ )
54
+
55
+ # ---------------------------------------------------------
56
+ # LOAD BASE MODEL (4-bit CPU quant)
57
+ # ---------------------------------------------------------
58
+ print("🔄 Loading base model in 4-bit…")
59
+ base_model = AutoModelForCausalLM.from_pretrained(
60
+ MODEL_NAME,
61
+ quantization_config=quant_config,
62
+ device_map={"": DEVICE},
63
+ torch_dtype=torch.bfloat16,
64
+ trust_remote_code=True
65
+ )
66
+
67
+ # ---------------------------------------------------------
68
+ # LOAD LORA (merged on top)
69
+ # ---------------------------------------------------------
70
+ print("🔄 Loading LoRA weights…")
71
+ model = PeftModel.from_pretrained(
72
+ base_model,
73
+ LORA_NAME,
74
+ device_map={"": DEVICE}
75
+ ).eval()
76
+
77
+ # ---------------------------------------------------------
78
+ # LOAD SNAC ONCE ONLY
79
+ # ---------------------------------------------------------
80
+ print("🔄 Loading SNAC…")
81
+ snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(DEVICE)
82
+
83
+
84
+ # =========================================================
85
+ # INFERENCE FUNCTION
86
+ # =========================================================
87
+ def generate_audio_cpu_lora(text):
88
+
89
+ logs = []
90
+ logs.append("⚡ Running fast 4-bit CPU inference…")
91
+
92
+ # Tokens
93
+ soh = tokenizer.decode([128259])
94
+ eoh = tokenizer.decode([128260])
95
+ soa = tokenizer.decode([128261])
96
+ sos = tokenizer.decode([128257])
97
+ eot = tokenizer.decode([128009])
98
+ bos = tokenizer.bos_token
99
+
100
+ prompt = soh + bos + text + eot + eoh + soa + sos
101
+ inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
102
+
103
+ # -----------------------------------------------------
104
+ # GENERATE SNAC TOKENS (FAST 4-bit)
105
+ # -----------------------------------------------------
106
+ with torch.inference_mode():
107
+ outputs = model.generate(
108
+ **inputs,
109
+ max_new_tokens=SEQ_LEN,
110
+ temperature=0.4,
111
+ top_p=0.9,
112
+ repetition_penalty=1.1,
113
+ do_sample=True,
114
+ eos_token_id=128258,
115
+ pad_token_id=tokenizer.pad_token_id
116
+ )
117
+
118
+ # Strip prompt
119
+ gen_ids = outputs[0, inputs['input_ids'].shape[1]:].tolist()
120
+
121
+ # Extract valid SNAC tokens
122
+ snac_min, snac_max = 128266, 156937
123
+ eos_id = 128258
124
+ eos_idx = gen_ids.index(eos_id) if eos_id in gen_ids else len(gen_ids)
125
+
126
+ snac_tokens = [t for t in gen_ids[:eos_idx] if snac_min <= t <= snac_max]
127
+
128
+ # -----------------------------------------------------
129
+ # DECODE SNAC → AUDIO
130
+ # -----------------------------------------------------
131
+ l1, l2, l3 = [], [], []
132
+ frames = len(snac_tokens) // 7
133
+ snac_tokens = snac_tokens[:frames * 7]
134
+
135
+ for i in range(frames):
136
+ s = snac_tokens[i * 7:(i + 1) * 7]
137
+ l1.append((s[0] - snac_min) % 4096)
138
+ l2.extend([(s[1]-snac_min)%4096, (s[4]-snac_min)%4096])
139
+ l3.extend([(s[2]-snac_min)%4096, (s[3]-snac_min)%4096,
140
+ (s[5]-snac_min)%4096, (s[6]-snac_min)%4096])
141
+
142
+ codes = [
143
+ torch.tensor(l1).unsqueeze(0),
144
+ torch.tensor(l2).unsqueeze(0),
145
+ torch.tensor(l3).unsqueeze(0)
146
+ ]
147
+
148
+ with torch.inference_mode():
149
+ z = snac_model.quantizer.from_codes(codes)
150
+ audio = snac_model.decoder(z)[0, 0].cpu().numpy()
151
+
152
+ # Remove crackles
153
+ if len(audio) > 2048:
154
+ audio = audio[2048:]
155
+
156
+ # Save WAV
157
+ out = OUT_ROOT / "tts_output_cpu_lora.wav"
158
+ sf.write(out, audio, TARGET_SR)
159
+
160
+ logs.append("🎧 Audio generated successfully")
161
+
162
+ return str(out), str(out), "\n".join(logs)
163
+
164
+
165
+ # =========================================================
166
+ # GRADIO UI
167
+ # =========================================================
168
+ with gr.Blocks() as demo:
169
+ gr.Markdown("## ⚡ Maya TTS — Ultra-Fast 4-bit CPU Inference")
170
+
171
+ txt = gr.Textbox(label="Enter text", value=DEFAULT_TEXT)
172
+ btn = gr.Button("Generate Audio")
173
+
174
+ audio = gr.Audio(label="Audio", type="filepath")
175
+ file = gr.File(label="Download")
176
+ logs = gr.Textbox(label="Logs")
177
+
178
+ btn.click(generate_audio_cpu_lora, [txt], [audio, file, logs])
179
+
180
+
181
+ if __name__ == "__main__":
182
+ demo.launch()