# VideoMAE video classification — Gradio inference app (Hugging Face Space)
import torch
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
from decord import VideoReader, cpu
import gradio as gr

# -------------------------------
# Device
# -------------------------------
# Prefer GPU when available; the model and all input tensors go here.
device = "cuda" if torch.cuda.is_available() else "cpu"

# -------------------------------
# Load processor and model
# -------------------------------
# The processor handles frame resizing/normalization. The classification head
# is re-initialized for 14 labels; ignore_mismatched_sizes discards the
# pretrained SSv2 head, whose label count differs.
processor = VideoMAEImageProcessor.from_pretrained(
    "MCG-NJU/videomae-small-finetuned-ssv2"
)
model = VideoMAEForVideoClassification.from_pretrained(
    "MCG-NJU/videomae-small-finetuned-ssv2",
    num_labels=14,
    ignore_mismatched_sizes=True,
)

# Load the fine-tuned weights. weights_only=True restricts unpickling to
# tensors and plain containers, preventing arbitrary-code execution from a
# tampered checkpoint (and is the torch>=2.6 default; avoids the FutureWarning
# on torch>=2.4). A state-dict checkpoint like this one loads fine with it.
checkpoint = torch.load("videomae_best.pth", map_location=device, weights_only=True)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()
# -------------------------------
# Class mapping
# -------------------------------
# Human-readable label for each model output index, in index order.
_CLASS_NAMES = (
    "AFGHANISTAN",
    "AFRICA",
    "ANDHRA_PRADESH",
    "ARGENTINA",
    "DELHI",
    "DENMARK",
    "ENGLAND",
    "GANGTOK",
    "GOA",
    "GUJARAT",
    "HARYANA",
    "HIMACHAL_PRADESH",
    "JAIPUR",
    "JAMMU_AND_KASHMIR",
)
id2class = dict(enumerate(_CLASS_NAMES))
| # ------------------------------- | |
| # Video preprocessing | |
| # ------------------------------- | |
def preprocess_video(video_file, processor, num_frames=16):
    """Decode a video and preprocess it for VideoMAE.

    Samples ``num_frames`` frames spread across the clip and runs them
    through the image processor.

    Args:
        video_file: Path to the video as a string, or a file-like/gradio
            upload object that exposes the path via its ``name`` attribute.
        processor: A ``VideoMAEImageProcessor`` (or compatible) instance.
        num_frames: Number of frames the model expects (default 16).

    Returns:
        The ``pixel_values`` tensor for one clip (batch dimension removed).

    Raises:
        ValueError: If the video contains no decodable frames.
    """
    # gradio v4+ passes the upload as a plain path string; older versions
    # pass a tempfile wrapper with a .name attribute. Support both.
    video_path = video_file if isinstance(video_file, str) else video_file.name
    vr = VideoReader(video_path, ctx=cpu(0))
    total_frames = len(vr)
    if total_frames == 0:
        # Original code would hit ZeroDivisionError on `i % 0` here.
        raise ValueError(f"No decodable frames in video: {video_path}")
    if total_frames < num_frames:
        # Short clip: cycle through the available frames to pad up.
        indices = [i % total_frames for i in range(num_frames)]
    else:
        # Evenly spaced sampling across the whole clip.
        indices = torch.linspace(0, total_frames - 1, num_frames).long().tolist()
    video = vr.get_batch(indices).asnumpy()
    inputs = processor(list(video), return_tensors="pt")
    # Drop the batch dimension the processor adds; the caller re-batches.
    return inputs["pixel_values"][0]
| # ------------------------------- | |
| # Prediction function | |
| # ------------------------------- | |
def predict_video(video_file):
    """Classify an uploaded video and return the predicted class name."""
    # Build a batch of one clip and move it onto the inference device.
    clip = preprocess_video(video_file, processor).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(pixel_values=clip).logits
    pred = logits.argmax(dim=1).item()
    return id2class[pred]
# -------------------------------
# Gradio Interface
# -------------------------------
# Wire the prediction function into a minimal upload-and-classify UI.
_upload = gr.File(file_types=[".mp4"])  # Accept any MP4 file
iface = gr.Interface(
    title="VideoMAE Classification API",
    description="Upload a .mp4 video file to get the predicted class.",
    fn=predict_video,
    inputs=_upload,
    outputs="text",
)

# Launch Space (public URL)
iface.launch(share=True)