File size: 2,613 Bytes
dab5eb8
 
 
 
 
5e272cd
 
 
dab5eb8
 
5e272cd
dab5eb8
5e272cd
 
 
 
 
dab5eb8
 
 
 
 
5e272cd
dab5eb8
 
 
 
 
5e272cd
dab5eb8
5e272cd
dab5eb8
5e272cd
 
 
 
 
 
 
 
 
 
 
 
 
 
dab5eb8
 
5e272cd
dab5eb8
5e272cd
 
 
 
 
 
dab5eb8
 
5e272cd
dab5eb8
 
 
 
5e272cd
dab5eb8
 
 
 
5e272cd
dab5eb8
5e272cd
dab5eb8
5e272cd
dab5eb8
5e272cd
dab5eb8
 
 
5e272cd
dab5eb8
 
5e272cd
 
 
dab5eb8
 
5e272cd
dab5eb8
 
5e272cd
dab5eb8
 
5e272cd
a38e917
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
from decord import VideoReader, cpu
import gradio as gr

# -------------------------------
# Device
# -------------------------------
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# -------------------------------
# Load processor and fine-tuned model
# -------------------------------
# Base checkpoint whose preprocessing config and architecture are reused.
_BASE_MODEL_ID = "MCG-NJU/videomae-small-finetuned-ssv2"

processor = VideoMAEImageProcessor.from_pretrained(_BASE_MODEL_ID)

# num_labels=14 does not match the base checkpoint's classifier size, so
# ignore_mismatched_sizes lets the classification head be re-initialized;
# all weights are then replaced by the fine-tuned state dict below.
model = VideoMAEForVideoClassification.from_pretrained(
    _BASE_MODEL_ID,
    ignore_mismatched_sizes=True,
    num_labels=14,
)

# Fine-tuned weights, with tensors mapped straight onto the chosen device.
# NOTE: torch.load unpickles the file, so only ever load a trusted checkpoint.
checkpoint = torch.load("videomae_best.pth", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()  # inference mode for layers such as dropout

# -------------------------------
# Class mapping
# -------------------------------
# Class names in classifier-output order: position i is the label for logit i.
_CLASS_NAMES = (
    "AFGHANISTAN",
    "AFRICA",
    "ANDHRA_PRADESH",
    "ARGENTINA",
    "DELHI",
    "DENMARK",
    "ENGLAND",
    "GANGTOK",
    "GOA",
    "GUJARAT",
    "HARYANA",
    "HIMACHAL_PRADESH",
    "JAIPUR",
    "JAMMU_AND_KASHMIR",
)

# Lookup from predicted class index to its human-readable name.
id2class = dict(enumerate(_CLASS_NAMES))

# -------------------------------
# Video preprocessing
# -------------------------------
def preprocess_video(video_file, processor, num_frames=16):
    """
    Sample frames from an uploaded video and preprocess them for VideoMAE.

    Args:
        video_file: The uploaded video. Either a filesystem path (``str`` or
            ``os.PathLike``, which is what newer Gradio ``gr.File`` inputs
            pass by default) or a file-like object exposing the path as
            ``.name`` (the temp-file wrapper older Gradio versions pass).
        processor: Image processor applied to the list of sampled frames.
        num_frames: Number of frames to sample from the video.

    Returns:
        The first element of the processor's batched ``pixel_values`` tensor,
        i.e. this single clip without the leading batch dimension.

    Raises:
        ValueError: If the video contains no frames.
    """
    import os

    # Accept both a plain path and a file-like wrapper. Check for a path
    # first: pathlib.Path also has a ``.name`` attribute, but it holds only
    # the basename, not the full path.
    if isinstance(video_file, (str, os.PathLike)):
        video_path = os.fspath(video_file)
    else:
        video_path = video_file.name

    vr = VideoReader(video_path, ctx=cpu(0))
    total_frames = len(vr)
    if total_frames == 0:
        # Without this guard the modulo below raises ZeroDivisionError.
        raise ValueError(f"Video has no frames: {video_path}")

    if total_frames < num_frames:
        # Short clip: cycle through the available frames until num_frames
        # indices have been collected.
        indices = [i % total_frames for i in range(num_frames)]
    else:
        # Evenly spaced indices from the first to the last frame, truncated
        # to integers.
        indices = torch.linspace(0, total_frames - 1, num_frames).long().tolist()

    video = vr.get_batch(indices).asnumpy()
    inputs = processor(list(video), return_tensors="pt")
    return inputs["pixel_values"][0]

# -------------------------------
# Prediction function
# -------------------------------
def predict_video(video_file):
    """
    Classify an uploaded video and return the predicted class name.

    Args:
        video_file: The value from the ``gr.File`` input. This is ``None``
            when the user submits without uploading a file.

    Returns:
        The ``id2class`` name of the highest-scoring logit.

    Raises:
        gr.Error: If no file was uploaded. Gradio shows this message to the
            user instead of a generic error from deep inside preprocessing.
    """
    if video_file is None:
        raise gr.Error("Please upload an .mp4 video file.")

    pixel_values = preprocess_video(video_file, processor)
    # Restore the batch dimension of size 1 and move the input onto the
    # model's device.
    pixel_values = pixel_values.unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits
        pred_index = torch.argmax(logits, dim=1).item()

    return id2class[pred_index]

# -------------------------------
# Gradio Interface
# -------------------------------
# Web UI: one .mp4 upload in, the predicted class name out as plain text.
iface = gr.Interface(
    predict_video,                  # fn: called with the uploaded file
    gr.File(file_types=[".mp4"]),   # inputs: uploads limited to .mp4 files
    "text",                         # outputs: class name rendered as text
    description="Upload a .mp4 video file to get the predicted class.",
    title="VideoMAE Classification API",
)

# Start serving; share=True additionally asks Gradio for a public share link.
iface.launch(share=True)