# Spaces status: Runtime error
import clip
import gradio as gr
import torch

# Run CLIP on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the ViT-B/32 CLIP checkpoint once, at import time. `model` scores
# image/text pairs; `preprocess` turns an input image into the tensor the
# model consumes (see its use in `predict`).
model, preprocess = clip.load("ViT-B/32", device=device)
def predict(image, labels):
    """Classify ``image`` zero-shot against user-supplied class names with CLIP.

    Each class name is inserted into the prompt ``"a photo of a {c}"``, the
    prompts and the image are scored by the module-level CLIP ``model``, and
    the image's logits are softmax-normalised over the candidate classes.

    Args:
        image: Image to classify, in a form accepted by the module-level
            ``preprocess`` transform (the Gradio input uses ``type="pil"``).
        labels: Comma-separated class names, e.g. ``"cat, dog, ball"``.
            Whitespace around each name is stripped, empty entries are
            ignored and repeated names are kept once.

    Returns:
        dict mapping each class name to its probability. The values come
        from a single softmax over the classes, so they sum to ~1.

    Raises:
        ValueError: if no image is given or ``labels`` names no classes.
    """
    if image is None:
        raise ValueError("Please upload an image to classify.")

    # "cat, dog" must yield "dog", not " dog", both in the prompt and in the
    # returned key; stray or trailing commas must not create empty classes.
    names = (name.strip() for name in (labels or "").split(","))
    # dict.fromkeys de-duplicates while keeping first-seen order. A repeated
    # name would otherwise collapse into one dict key and the displayed
    # probabilities would no longer sum to 1.
    classes = list(dict.fromkeys(name for name in names if name))
    if not classes:
        raise ValueError("Please enter at least one class, separated by ','.")

    image_input = preprocess(image).unsqueeze(0).to(device)  # batch of one
    text_input = clip.tokenize([f"a photo of a {c}" for c in classes]).to(device)
    with torch.inference_mode():
        logits_per_image, _ = model(image_input, text_input)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()
    # Row 0 holds the scores of the single image against every class prompt.
    return {c: float(p) for c, p in zip(classes, probs[0])}
# Example usage (needs `from PIL import Image`, which this file does not import):
# probs = predict(Image.open("../CLIP/CLIP.png"), "cat, dog, ball")
# print(probs)
# Web UI: an image plus a free-text list of candidate classes in, a ranked
# label display out.
# The `gr.inputs.*` namespace was deprecated in Gradio 3 and removed in
# Gradio 4 (AttributeError at startup), so the components are built from the
# top-level `gr` namespace instead (requires Gradio >= 3).
# The legacy `theme="grass"` name is not one of current Gradio's built-in
# themes, so it is omitted and the default theme is used.
# Binding the app to `demo` lets the `gradio` CLI / reload mode find it;
# `launch()` still runs unconditionally, as before.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label="Image to classify.", type="pil"),
        gr.Textbox(
            lines=1,
            label="Comma separated classes",
            placeholder="Enter your classes separated by ','",
        ),
    ],
    outputs="label",
    description="Zero Shot Image classification..",
)
demo.launch()