import torch
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
from decord import VideoReader, cpu
import gradio as gr
# -------------------------------
# Device
# -------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
# -------------------------------
# Load processor and model
# -------------------------------
processor = VideoMAEImageProcessor.from_pretrained(
    "MCG-NJU/videomae-small-finetuned-ssv2"
)
model = VideoMAEForVideoClassification.from_pretrained(
    "MCG-NJU/videomae-small-finetuned-ssv2",
    num_labels=14,
    ignore_mismatched_sizes=True
)
# Load the fine-tuned checkpoint (expects a dict containing "model_state_dict")
checkpoint = torch.load("videomae_best.pth", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()
# -------------------------------
# Class mapping
# -------------------------------
id2class = {
    0: "AFGHANISTAN",
    1: "AFRICA",
    2: "ANDHRA_PRADESH",
    3: "ARGENTINA",
    4: "DELHI",
    5: "DENMARK",
    6: "ENGLAND",
    7: "GANGTOK",
    8: "GOA",
    9: "GUJARAT",
    10: "HARYANA",
    11: "HIMACHAL_PRADESH",
    12: "JAIPUR",
    13: "JAMMU_AND_KASHMIR"
}
# -------------------------------
# Video preprocessing
# -------------------------------
def preprocess_video(video_file, processor, num_frames=16):
    """
    Preprocess a video (filepath or file-like object) into pixel values for VideoMAE.
    Samples `num_frames` frames uniformly across the clip.
    """
    # gr.File may pass a filepath string (newer Gradio) or a tempfile-like
    # object with a .name attribute (older Gradio); handle both.
    video_path = video_file if isinstance(video_file, str) else video_file.name
    vr = VideoReader(video_path, ctx=cpu(0))
    total_frames = len(vr)
    if total_frames < num_frames:
        # Clip is shorter than num_frames: repeat frames cyclically
        indices = [i % total_frames for i in range(num_frames)]
    else:
        # Sample num_frames indices uniformly across the clip
        indices = torch.linspace(0, total_frames - 1, num_frames).long().tolist()
    video = vr.get_batch(indices).asnumpy()  # (num_frames, H, W, C)
    inputs = processor(list(video), return_tensors="pt")
    return inputs["pixel_values"][0]  # (num_frames, C, H, W)
# -------------------------------
# Prediction function
# -------------------------------
def predict_video(video_file):
    pixel_values = preprocess_video(video_file, processor)
    pixel_values = pixel_values.unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits
    pred_index = torch.argmax(logits, dim=1).item()
    return id2class[pred_index]
# -------------------------------
# Gradio Interface
# -------------------------------
iface = gr.Interface(
    fn=predict_video,
    inputs=gr.File(file_types=[".mp4"]),  # accept .mp4 uploads
    outputs="text",
    title="VideoMAE Classification API",
    description="Upload a .mp4 video file to get the predicted class."
)
# Launch Space (public URL)
iface.launch(share=True)
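# -------------------------------
# Example client usage (sketch)
# -------------------------------
# A minimal sketch of querying this app from another process, assuming the
# `gradio_client` package is installed and the placeholder URL is replaced
# with the one printed by launch(). Not part of the app itself; kept as
# comments so it never runs inside this script.
#
#   from gradio_client import Client, handle_file  # handle_file: recent gradio_client versions
#
#   client = Client("https://<your-space-url>")        # hypothetical URL
#   label = client.predict(handle_file("clip.mp4"))    # returns the predicted class name
#   print(label)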