Gapeleon committed
Commit 9933c70 · verified · 1 Parent(s): 07c4f24

Create app.py

Files changed (1)
app.py +332 -0
app.py ADDED
@@ -0,0 +1,332 @@
+ import torch
+ import gradio as gr
+ from transformers import AutoTokenizer, AutoModel, AutoImageProcessor
+ from PIL import Image
+ import gc
+ import os
+ import spaces
+
+ # Model configuration
+ MODEL_PATH = "nvidia/Llama-Nemotron-Nano-VL-8B-V1"
+
+ # Load model globally
+ print("Loading model...")
+ model = AutoModel.from_pretrained(
+     MODEL_PATH,
+     torch_dtype=torch.bfloat16,
+     low_cpu_mem_usage=True,
+     trust_remote_code=True,
+ ).eval()
+
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+ image_processor = AutoImageProcessor.from_pretrained(
+     MODEL_PATH,
+     trust_remote_code=True
+ )
+ print("Model loaded successfully!")
+
+ def move_to_device(obj, device):
+     """Recursively move tensors to device"""
+     if torch.is_tensor(obj):
+         return obj.to(device)
+     elif isinstance(obj, dict):
+         return {k: move_to_device(v, device) for k, v in obj.items()}
+     elif isinstance(obj, list):
+         return [move_to_device(v, device) for v in obj]
+     elif isinstance(obj, tuple):
+         return tuple(move_to_device(v, device) for v in obj)
+     elif hasattr(obj, 'to'):
+         return obj.to(device)
+     else:
+         return obj
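+
+ # The hasattr(obj, 'to') fallback also covers dict-like containers such as
+ # transformers.BatchFeature (a UserDict, so the isinstance(obj, dict) branch
+ # does not catch it), which implements .to(device) itself.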
+
+ @spaces.GPU(duration=60)
+ def chat_text_only(message):
+     try:
+         device = "cuda"
+
+         # Move entire model to GPU
+         model.to(device)
+
+         generation_config = dict(
+             max_new_tokens=512,
+             do_sample=True,
+             temperature=0.7,
+             eos_token_id=tokenizer.eos_token_id
+         )
+
+         # Generate; model.chat tokenizes the message itself, so no manual
+         # tokenization is needed (None = no image input)
+         with torch.no_grad():
+             response, _ = model.chat(
+                 tokenizer,
+                 None,
+                 message,
+                 generation_config,
+                 history=None,
+                 return_history=True
+             )
+
+         # Move model back to CPU and release GPU memory
+         model.to("cpu")
+         torch.cuda.empty_cache()
+         gc.collect()
+
+         return response
+
+     except Exception as e:
+         # Ensure model is back on CPU even if an error occurs
+         model.to("cpu")
+         torch.cuda.empty_cache()
+         gc.collect()
+         return f"Error: {str(e)}"
+
+ @spaces.GPU(duration=60)
+ def chat_with_image(image, message):
+     if image is None:
+         return "Please upload an image."
+
+     try:
+         device = "cuda"
+
+         # Move entire model to GPU
+         model.to(device)
+
+         generation_config = dict(
+             max_new_tokens=512,
+             do_sample=True,
+             temperature=0.7,
+             eos_token_id=tokenizer.eos_token_id
+         )
+
+         # Process image
+         image_features = image_processor(image)
+
+         # Move all image features to GPU
+         image_features = move_to_device(image_features, device)
+
+         # Add image token to message if not present
+         if "<image>" not in message:
+             message = f"<image>\n{message}"
+
+         # Generate
+         with torch.no_grad():
+             response = model.chat(
+                 tokenizer=tokenizer,
+                 question=message,
+                 generation_config=generation_config,
+                 **image_features
+             )
+
+         # Move model back to CPU
+         model.to("cpu")
+         torch.cuda.empty_cache()
+         gc.collect()
+
+         return response
+
+     except Exception as e:
+         # Ensure model is back on CPU even if an error occurs
+         model.to("cpu")
+         torch.cuda.empty_cache()
+         gc.collect()
+         return f"Error: {str(e)}"
+
+ @spaces.GPU(duration=60)
+ def chat_with_two_images(image1, image2, message):
+     if image1 is None or image2 is None:
+         return "Please upload both images."
+
+     try:
+         device = "cuda"
+
+         # Move entire model to GPU
+         model.to(device)
+
+         generation_config = dict(
+             max_new_tokens=512,
+             do_sample=True,
+             temperature=0.7,
+             eos_token_id=tokenizer.eos_token_id
+         )
+
+         # Process both images
+         image_features = image_processor([image1, image2])
+
+         # Move all image features to GPU
+         image_features = move_to_device(image_features, device)
+
+         # Format message for two images
+         if "<image-1>" not in message and "<image-2>" not in message:
+             message = f"<image-1>: <image>\n<image-2>: <image>\n{message}"
+
+         # Generate
+         with torch.no_grad():
+             response = model.chat(
+                 tokenizer=tokenizer,
+                 question=message,
+                 generation_config=generation_config,
+                 **image_features
+             )
+
+         # Move model back to CPU
+         model.to("cpu")
+         torch.cuda.empty_cache()
+         gc.collect()
+
+         return response
+
+     except Exception as e:
+         # Ensure model is back on CPU even if an error occurs
+         model.to("cpu")
+         torch.cuda.empty_cache()
+         gc.collect()
+         return f"Error: {str(e)}"
+
+ # Create Gradio interface
+ def create_interface():
+     with gr.Blocks(title="Llama Nemotron Nano VL 8B", theme=gr.themes.Soft()) as demo:
+         gr.Markdown("# 🦙 Llama Nemotron Nano VL 8B Vision-Language Model")
+         gr.Markdown("Chat with a powerful vision-language model that can understand both text and images!")
+
+         with gr.Tabs():
+             # Text-only chat tab
+             with gr.TabItem("💬 Text Chat"):
+                 gr.Markdown("### Chat with the model using text only")
+
+                 with gr.Row():
+                     with gr.Column():
+                         text_input = gr.Textbox(
+                             label="Your message",
+                             placeholder="Ask me anything...",
+                             lines=3
+                         )
+                         text_submit = gr.Button("Send", variant="primary")
+
+                     with gr.Column():
+                         text_output = gr.Textbox(
+                             label="Model Response",
+                             lines=10,
+                             max_lines=20
+                         )
+
+                 text_submit.click(
+                     chat_text_only,
+                     inputs=[text_input],
+                     outputs=[text_output]
+                 )
+
+                 # Example questions
+                 gr.Examples(
+                     examples=[
+                         ["What is artificial intelligence?"],
+                         ["Explain quantum computing in simple terms."],
+                         ["What happened in 1969?"],
+                         ["Write a short story about a robot."]
+                     ],
+                     inputs=[text_input]
+                 )
+
+             # Single image chat tab
+             with gr.TabItem("🖼️ Image + Text Chat"):
+                 gr.Markdown("### Upload an image and ask questions about it")
+
+                 with gr.Row():
+                     with gr.Column():
+                         image_input = gr.Image(
+                             label="Upload Image",
+                             type="pil"
+                         )
+                         image_text_input = gr.Textbox(
+                             label="Your question about the image",
+                             placeholder="What do you see in this image?",
+                             lines=3
+                         )
+                         image_submit = gr.Button("Analyze", variant="primary")
+
+                     with gr.Column():
+                         image_output = gr.Textbox(
+                             label="Model Response",
+                             lines=10,
+                             max_lines=20
+                         )
+
+                 image_submit.click(
+                     chat_with_image,
+                     inputs=[image_input, image_text_input],
+                     outputs=[image_output]
+                 )
+
+                 # Example prompts
+                 gr.Examples(
+                     examples=[
+                         ["Describe what you see in this image."],
+                         ["What objects are in this image?"],
+                         ["Extract any text from this image."],
+                         ["What is the main subject of this image?"]
+                     ],
+                     inputs=[image_text_input]
+                 )
+
+             # Two images comparison tab
+             with gr.TabItem("🖼️🖼️ Compare Two Images"):
+                 gr.Markdown("### Upload two images and ask the model to compare them")
+
+                 with gr.Row():
+                     with gr.Column():
+                         image1_input = gr.Image(
+                             label="First Image",
+                             type="pil"
+                         )
+                         image2_input = gr.Image(
+                             label="Second Image",
+                             type="pil"
+                         )
+                         two_images_text_input = gr.Textbox(
+                             label="Your question about both images",
+                             placeholder="Compare these two images...",
+                             lines=3
+                         )
+                         two_images_submit = gr.Button("Compare", variant="primary")
+
+                     with gr.Column():
+                         two_images_output = gr.Textbox(
+                             label="Model Response",
+                             lines=10,
+                             max_lines=20
+                         )
+
+                 two_images_submit.click(
+                     chat_with_two_images,
+                     inputs=[image1_input, image2_input, two_images_text_input],
+                     outputs=[two_images_output]
+                 )
+
+                 # Example prompts
+                 gr.Examples(
+                     examples=[
+                         ["What are the main differences between these two images?"],
+                         ["Describe both images briefly."],
+                         ["Which image is more colorful?"],
+                         ["Compare the subjects in these images."]
+                     ],
+                     inputs=[two_images_text_input]
+                 )
+
+         # Footer
+         gr.Markdown("---")
+         gr.Markdown("⚡ Powered by NVIDIA Llama Nemotron Nano VL 8B")
+
+     return demo
+
+ # Create and launch the interface
+ if __name__ == "__main__":
+     demo = create_interface()
+     demo.queue()  # Enable queuing for Zero GPU
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860
+     )
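
Note: running this file as a Space also needs a requirements.txt. A minimal sketch inferred from the imports above (names only, unpinned, and partly an assumption):

    torch
    transformers
    accelerate  # needed for low_cpu_mem_usage=True
    gradio
    pillow

The spaces package is preinstalled on ZeroGPU hardware, and the checkpoint's remote code may pull in further dependencies, so treat this list as a starting point rather than a definitive manifest.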