"""Gradio app: classify an uploaded .mp4 with a fine-tuned VideoMAE model."""

import torch
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
from decord import VideoReader, cpu
import gradio as gr

# -------------------------------
# Device
# -------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

# -------------------------------
# Load processor and model
# -------------------------------
processor = VideoMAEImageProcessor.from_pretrained(
    "MCG-NJU/videomae-small-finetuned-ssv2"
)
model = VideoMAEForVideoClassification.from_pretrained(
    "MCG-NJU/videomae-small-finetuned-ssv2",
    num_labels=14,  # classifier head resized to the 14 classes in id2class
    ignore_mismatched_sizes=True,
)

# NOTE(review): torch.load unpickles arbitrary objects — only load a trusted
# checkpoint here, or pass weights_only=True on torch >= 2.0.
checkpoint = torch.load("videomae_best.pth", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()

# -------------------------------
# Class mapping
# -------------------------------
id2class = {
    0: "AFGHANISTAN",
    1: "AFRICA",
    2: "ANDHRA_PRADESH",
    3: "ARGENTINA",
    4: "DELHI",
    5: "DENMARK",
    6: "ENGLAND",
    7: "GANGTOK",
    8: "GOA",
    9: "GUJARAT",
    10: "HARYANA",
    11: "HIMACHAL_PRADESH",
    12: "JAIPUR",
    13: "JAMMU_AND_KASHMIR",
}


# -------------------------------
# Video preprocessing
# -------------------------------
def preprocess_video(video_file, processor, num_frames=16):
    """Decode a video and sample frames for VideoMAE.

    Args:
        video_file: Either a filesystem path (str) or a file-like object
            with a ``.name`` attribute, as handed over by ``gr.File``
            (newer Gradio versions pass a plain path string).
        processor: A ``VideoMAEImageProcessor`` used to normalize frames.
        num_frames: Number of frames the model expects (default 16).

    Returns:
        A ``pixel_values`` tensor for a single clip (no batch dimension).

    Raises:
        ValueError: If the video contains no decodable frames.
    """
    # gr.File may provide a plain path string or a tempfile wrapper — accept both.
    video_path = video_file if isinstance(video_file, str) else video_file.name
    vr = VideoReader(video_path, ctx=cpu(0))
    total_frames = len(vr)
    if total_frames == 0:
        raise ValueError(f"No decodable frames found in video: {video_path}")
    if total_frames < num_frames:
        # Short clip: loop the available frames to pad up to num_frames.
        indices = [i % total_frames for i in range(num_frames)]
    else:
        # Uniformly sample num_frames across the whole clip.
        indices = torch.linspace(0, total_frames - 1, num_frames).long().tolist()
    video = vr.get_batch(indices).asnumpy()
    inputs = processor(list(video), return_tensors="pt")
    return inputs["pixel_values"][0]


# -------------------------------
# Prediction function
# -------------------------------
def predict_video(video_file):
    """Run the model on one uploaded video and return the class label."""
    pixel_values = preprocess_video(video_file, processor)
    pixel_values = pixel_values.unsqueeze(0).to(device)  # add batch dim
    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits
    pred_index = torch.argmax(logits, dim=1).item()
    return id2class[pred_index]


# -------------------------------
# Gradio Interface
# -------------------------------
iface = gr.Interface(
    fn=predict_video,
    inputs=gr.File(file_types=[".mp4"]),  # Accept any MP4 file
    outputs="text",
    title="VideoMAE Classification API",
    description="Upload a .mp4 video file to get the predicted class.",
)

# Launch Space (public URL) — guarded so importing this module
# (e.g. for testing) does not start a server.
if __name__ == "__main__":
    iface.launch(share=True)