ecuenca40 committed on
Commit 5b67aa1 · verified · 1 Parent(s): 62b18d1

Update app.py

Files changed (1)
  1. app.py +39 -4
app.py CHANGED
@@ -1,14 +1,49 @@
 import gradio as gr
-from transformers import AutoProcessor, AutoModelForCausalLM
+from transformers import AutoProcessor, AutoModelForImageTextToText
 from PIL import Image
+import requests
 import torch
 
 model_id = "google/medgemma-4b-it"
 
-processor = AutoProcessor.from_pretrained(model_id)
-model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to(
-    "cuda" if torch.cuda.is_available() else "cpu"
+model = AutoModelForImageTextToText.from_pretrained(
+    model_id,
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
 )
+processor = AutoProcessor.from_pretrained(model_id)
+
+# Image attribution: Stillwaterising, CC0, via Wikimedia Commons
+image_url = "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png"
+image = Image.open(requests.get(image_url, headers={"User-Agent": "example"}, stream=True).raw)
+
+messages = [
+    {
+        "role": "system",
+        "content": [{"type": "text", "text": "You are an expert radiologist."}]
+    },
+    {
+        "role": "user",
+        "content": [
+            {"type": "text", "text": "Describe this X-ray"},
+            {"type": "image", "image": image}
+        ]
+    }
+]
+
+inputs = processor.apply_chat_template(
+    messages, add_generation_prompt=True, tokenize=True,
+    return_dict=True, return_tensors="pt"
+).to(model.device, dtype=torch.bfloat16)
+
+input_len = inputs["input_ids"].shape[-1]
+
+with torch.inference_mode():
+    generation = model.generate(**inputs, max_new_tokens=200, do_sample=False)
+    generation = generation[0][input_len:]
+
+decoded = processor.decode(generation, skip_special_tokens=True)
+print(decoded)
 
 def generate_report(image):
     if image is None:
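
The diff's context cuts off at the top of generate_report, so the updated function body is not part of this commit view. As a minimal sketch, assuming the handler reuses the chat-template pattern introduced above on the Gradio-supplied PIL image (the prompt text, the fallback message, and the gr.Interface wiring at the end are assumptions, not taken from the commit):

# Sketch only: one plausible generate_report built from the pattern in this commit.
# `model` and `processor` are the module-level objects loaded above.
def generate_report(image):
    if image is None:
        return "Please upload an image first."  # assumed fallback message
    messages = [
        {"role": "system",
         "content": [{"type": "text", "text": "You are an expert radiologist."}]},
        {"role": "user",
         "content": [{"type": "text", "text": "Describe this X-ray"},  # assumed prompt
                     {"type": "image", "image": image}]},
    ]
    inputs = processor.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True,
        return_dict=True, return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generation = model.generate(**inputs, max_new_tokens=200, do_sample=False)
    # Strip the prompt tokens and decode only the newly generated report text.
    return processor.decode(generation[0][input_len:], skip_special_tokens=True)

# Assumed Gradio wiring, consistent with `import gradio as gr` in the file.
demo = gr.Interface(fn=generate_report, inputs=gr.Image(type="pil"), outputs="text")
demo.launch()

Note the design choice the commit itself makes: MedGemma is a multimodal (image-text-to-text) model, so AutoModelForImageTextToText with processor.apply_chat_template replaces the previous AutoModelForCausalLM loading path, and device_map="auto" with torch_dtype=torch.bfloat16 replaces the manual .to("cuda"/"cpu") placement.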