import torch
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
from decord import VideoReader, cpu
import gradio as gr

# -------------------------------
# Device
# -------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

# -------------------------------
# Load processor and model
# -------------------------------
processor = VideoMAEImageProcessor.from_pretrained("MCG-NJU/videomae-small-finetuned-ssv2")
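# The SSv2 checkpoint ships with a different number of output classes, so the
# classification head is re-initialized for 14 labels here and then overwritten
# by the fine-tuned weights loaded below.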
model = VideoMAEForVideoClassification.from_pretrained(
    "MCG-NJU/videomae-small-finetuned-ssv2",
    num_labels=14,
    ignore_mismatched_sizes=True
)
checkpoint = torch.load("videomae_best.pth", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()

# -------------------------------
# Class mapping
# -------------------------------
id2class = {
    0: "AFGHANISTAN",
    1: "AFRICA",
    2: "ANDHRA_PRADESH",
    3: "ARGENTINA",
    4: "DELHI",
    5: "DENMARK",
    6: "ENGLAND",
    7: "GANGTOK",
    8: "GOA",
    9: "GUJARAT",
    10: "HARYANA",
    11: "HIMACHAL_PRADESH",
    12: "JAIPUR",
    13: "JAMMU_AND_KASHMIR"
}
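
# Sanity check (added): the label map should cover exactly the model's outputs.
assert len(id2class) == model.config.num_labels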

# -------------------------------
# Video preprocessing
# -------------------------------
def preprocess_video(video_path, processor, num_frames=16):
    """Decode a video and return pixel values for num_frames sampled frames."""
    vr = VideoReader(video_path, ctx=cpu(0))
    total_frames = len(vr)
    if total_frames < num_frames:
        # Short clip: repeat frames cyclically until num_frames are collected.
        indices = [i % total_frames for i in range(num_frames)]
    else:
        # Otherwise sample num_frames indices uniformly across the clip.
        indices = torch.linspace(0, total_frames - 1, num_frames).long().tolist()
    video = vr.get_batch(indices).asnumpy()
    inputs = processor(list(video), return_tensors="pt")
    return inputs["pixel_values"][0]  # drop the batch dimension added by the processor

# -------------------------------
# Prediction function
# -------------------------------
def predict_video(video_file):
    # Depending on the Gradio version, the input arrives either as a filepath
    # string or as a tempfile-like object exposing .name; handle both.
    video_path = video_file if isinstance(video_file, str) else video_file.name
    pixel_values = preprocess_video(video_path, processor)
    pixel_values = pixel_values.unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits
        pred_index = torch.argmax(logits, dim=1).item()
    return id2class[pred_index]

# -------------------------------
# Gradio Interface
# -------------------------------
iface = gr.Interface(
    fn=predict_video,
    inputs=gr.Video(),  # gr.Video takes no type argument; it yields the uploaded file's path
    outputs="text",
    title="VideoMAE Classification API",
    description="Upload a video and get the predicted class."
)

# Expose API
iface.launch(server_name="0.0.0.0", server_port=7860, share=True)
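
# -------------------------------
# Example client call (sketch)
# -------------------------------
# A minimal sketch of how a client could query the exposed endpoint, assuming
# gradio_client is installed and the server above is running; "sample.mp4" is
# a hypothetical local file. gr.Interface exposes its function as "/predict".
#
# from gradio_client import Client, handle_file
# client = Client("http://localhost:7860/")
# result = client.predict(handle_file("sample.mp4"), api_name="/predict")
# print(result)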