| import os |
| import cv2 |
| import tqdm |
| import uuid |
| import logging |
|
|
| import torch |
| import trackers |
| import numpy as np |
| import gradio as gr |
| import imageio.v3 as iio |
| import supervision as sv |
|
|
| from pathlib import Path |
| from typing import List, Optional, Tuple |
|
|
| from PIL import Image |
| from pipeline import build_pipeline |
| from utils import cfg, load_config, load_onnx_model |
|
|
|
|
| |
# Maps each selectable detector name to the path of its ONNX weights file.
DETECTORS = {
    "yolo8n-640": "downloads/yolo8n-640.onnx",
    "yolo8n-416": "downloads/yolo8n-416.onnx",
}
# The first registered detector is the default; next(iter(...)) reads it
# without materializing the full key list.
DEFAULT_DETECTOR = next(iter(DETECTORS))
# Confidence threshold used when the caller does not supply one.
DEFAULT_CONFIDENCE_THRESHOLD = 0.6
|
|
|
|
| |
# Bundled sample images offered as one-click examples in the Image tab.
IMAGE_EXAMPLES = [
    {"path": f"./examples/images/{stem}.jpg", "label": "Local Image"}
    for stem in ("forest", "river", "ocean")
]
|
|
| |
# Upper bound on the number of frames read from an input video.
MAX_NUM_FRAMES = 250
# Video file extensions accepted by process_video (compared lower-cased).
ALLOWED_VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov"}
# Directory that receives the annotated output videos; created at import time.
VIDEO_OUTPUT_DIR = Path("static/videos")
VIDEO_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
|
class TrackingAlgorithm:
    """Display names of the tracking algorithms selectable for video mode."""

    BYTETRACK = "ByteTrack (2021)"
    DEEPSORT = "DeepSORT (2017)"
    SORT = "SORT (2016)"
|
|
# Choices for the tracker dropdown; None stands for "no tracking".
TRACKERS = [None, TrackingAlgorithm.BYTETRACK, TrackingAlgorithm.DEEPSORT, TrackingAlgorithm.SORT]
# Bundled sample videos together with the tracker and class filter each one pre-selects.
VIDEO_EXAMPLES = [
    {"path": "./examples/videos/sea.mp4", "label": "Local Video", "tracker": TrackingAlgorithm.BYTETRACK, "classes": "Person, Boat"},
    {"path": "./examples/videos/forest.mp4", "label": "Local Video", "tracker": TrackingAlgorithm.BYTETRACK, "classes": "LightVehicle, Person, Boat"},
]
|
|
|
|
| |
| |
# Colour palette shared by the box and label annotators in process_video.
color = sv.ColorPalette.from_hex([
    "#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
    "#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00"
])
|
|
|
|
# Configure root logging once at import time and expose a module-level logger.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
|
|
|
|
def get_pipeline(config: dict, onnx_path: str):
    """Build a pipeline from ``config`` and load ONNX weights into its detector.

    Args:
        config: Configuration handed to ``build_pipeline``.
        onnx_path: Path of the ONNX file passed to ``load_onnx_model``.

    Returns:
        The pipeline object returned by ``build_pipeline``.
    """
    built = build_pipeline(config)
    detector = built.detector
    load_onnx_model(detector, onnx_path)
    return built
|
|
|
|
def _to_tensor(value):
    """Wrap a NumPy array as a torch tensor; pass any other value through."""
    return torch.from_numpy(value) if isinstance(value, np.ndarray) else value


def detect_objects(
    config: dict,
    onnx_path: str,
    images: List[np.ndarray] | np.ndarray,
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
    target_size: Optional[Tuple[int, int]] = None,
    classes: Optional[List[str]] = None,
):
    """Run the detection pipeline on every image and collect per-image results.

    Args:
        config: Pipeline configuration. It is defrosted, its detector
            confidence threshold is overwritten and it is frozen again, so
            the object passed in is modified.
        onnx_path: Path of the ONNX weights loaded into the detector.
        images: A list of images, or a single 4-D array that is iterated
            along its first axis.
        confidence_threshold: Value written to
            ``config.detector.thresholds.confidence``.
        target_size: Optional ``(height, width)`` that boxes are rescaled to.
            When omitted, boxes are returned as produced by the pipeline.
        classes: Optional class names to keep. Names missing from the
            detector's category mapping raise a ``gr.Warning`` and are
            ignored.

    Returns:
        A ``(results, id2label)`` tuple. ``results`` holds one dict per image
        with the keys ``"scores"``, ``"labels"`` and ``"boxes"``;
        ``id2label`` is the detector's category mapping.
    """
    config.defrost()
    config.detector.thresholds.confidence = float(confidence_threshold)
    config.freeze()

    pipeline = get_pipeline(config, onnx_path)
    # Fetch the mapping once and reuse it for filtering and for the return value.
    id2label = pipeline.detector.get_category_mapping()
    label2id = {name: class_id for class_id, name in id2label.items()}

    keep_ids = None
    if classes is not None:
        wrong_classes = [cls for cls in classes if cls not in label2id]
        if wrong_classes:
            gr.Warning(f"Classes not found in model config: {wrong_classes}")
        # Built once, with an explicit integer dtype so that an empty filter
        # does not become a float tensor.
        keep_ids = torch.tensor(
            [label2id[cls] for cls in classes if cls in label2id],
            dtype=torch.long,
        )

    if isinstance(images, np.ndarray) and images.ndim == 4:
        images = list(images)

    results = []
    for img in tqdm.tqdm(images, desc="Processing frames"):
        output_ = pipeline(img)
        result = {
            "scores": _to_tensor(output_.confidence),
            "labels": _to_tensor(output_.class_id),
            "boxes": _to_tensor(output_.xyxy),
        }

        if target_size:
            # target_size is (height, width): x coordinates scale with the
            # image width (axis 1), y coordinates with the height (axis 0).
            # Assumes the pipeline reports boxes in pixels of `img` -- TODO confirm.
            target_height, target_width = target_size
            scale_x = target_width / img.shape[1]
            scale_y = target_height / img.shape[0]
            result["boxes"][:, [0, 2]] *= scale_x
            result["boxes"][:, [1, 3]] *= scale_y

        if keep_ids is not None:
            keep = torch.isin(result["labels"], keep_ids)
            result = {key: value[keep] for key, value in result.items()}

        results.append(result)

    return results, id2label
|
|
|
|
def process_image(
    model: str = DEFAULT_DETECTOR,
    image: Optional[Image.Image] = None,
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
):
    """Detect objects in one image and format them for ``gr.AnnotatedImage``.

    Args:
        model: Key into ``DETECTORS``; also selects ``configs/<model>.yaml``.
        image: The PIL image to analyse.
        confidence_threshold: Confidence threshold forwarded to the detector.

    Returns:
        A ``(image, annotations)`` tuple where each annotation is
        ``((x_min, y_min, x_max, y_max), "<class name> (<score>)")`` with the
        box clamped to the image bounds.

    Raises:
        ValueError: If no image was supplied.
    """
    # The UI can trigger detection before anything is uploaded; fail with a
    # clear message instead of passing np.array(None) to the pipeline.
    if image is None:
        raise ValueError("No image provided: upload an image before running detection.")

    load_config(cfg, f'configs/{model}.yaml')
    results, id2label = detect_objects(
        config=cfg.pipeline,
        onnx_path=DETECTORS[model],
        images=[np.array(image)],
        confidence_threshold=confidence_threshold,
    )
    result = results[0]

    max_x = image.width - 1
    max_y = image.height - 1

    annotations = []
    for label, score, box in zip(result["labels"], result["scores"], result["boxes"]):
        text_label = id2label[label.item()]
        formatted_label = f"{text_label} ({float(score):.2f})"
        # tolist() yields plain Python ints rather than NumPy scalars.
        x_min, y_min, x_max, y_max = box.cpu().numpy().round().astype(int).tolist()
        clamped_box = (
            max(0, x_min),
            max(0, y_min),
            min(max_x, x_max),
            min(max_y, y_max),
        )
        annotations.append((clamped_box, formatted_label))

    return (image, annotations)
|
|
|
|
def get_target_size(image_height: int, image_width: int, max_size: int) -> Tuple[int, int]:
    """Return an even ``(width, height)`` that fits within ``max_size``.

    When both sides are already smaller than ``max_size`` they are kept;
    otherwise the longer side becomes ``max_size`` and the other side is
    scaled to preserve the aspect ratio. Both values are then rounded down
    to an even number and never drop below 2.

    Args:
        image_height: Source height in pixels.
        image_width: Source width in pixels.
        max_size: Largest allowed side length.

    Returns:
        ``(new_width, new_height)`` -- note that width comes first.
    """
    if image_height < max_size and image_width < max_size:
        new_height, new_width = image_height, image_width
    elif image_height > image_width:
        new_height = max_size
        new_width = int(image_width * max_size / image_height)
    else:
        new_width = max_size
        new_height = int(image_height * max_size / image_width)

    # Round down to even dimensions (presumably required by the H.264 writer
    # used downstream -- confirm). The lower bound of 2 keeps a very thin or
    # 1-pixel side from collapsing to 0, which no resize can produce.
    new_height = max(2, new_height // 2 * 2)
    new_width = max(2, new_width // 2 * 2)

    return new_width, new_height
|
|
|
|
def read_video_k_frames(video_path: str, k: int, read_every_i_frame: int = 1):
    """Read up to ``k`` RGB frames from a video, keeping every i-th frame.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        k: Maximum number of frames to return.
        read_every_i_frame: Keep a frame only when its index is a multiple of
            this value.

    Returns:
        A list of frames converted from BGR to RGB. It is shorter than ``k``
        when the video ends or cannot be read.
    """
    cap = cv2.VideoCapture(video_path)
    progress_bar = tqdm.tqdm(total=k, desc="Reading frames")
    frames = []
    try:
        frame_index = 0
        while cap.isOpened() and len(frames) < k:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_index % read_every_i_frame == 0:
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                progress_bar.update(1)
            frame_index += 1
    finally:
        # Release the capture handle and the progress bar even when decoding
        # or colour conversion raises.
        cap.release()
        progress_bar.close()
    return frames
|
|
|
|
def get_tracker(tracker: str, fps: float):
    """Instantiate the tracker whose display name is ``tracker``.

    Args:
        tracker: One of the ``TrackingAlgorithm`` display names.
        fps: Frame rate handed to the tracker (truncated to int for ByteTrack).

    Returns:
        A newly created tracker object.

    Raises:
        ValueError: If ``tracker`` matches none of the known algorithms.
    """
    if tracker == TrackingAlgorithm.BYTETRACK:
        return sv.ByteTrack(frame_rate=int(fps))

    if tracker == TrackingAlgorithm.SORT:
        return trackers.SORTTracker(frame_rate=fps)

    if tracker != TrackingAlgorithm.DEEPSORT:
        raise ValueError(f"Invalid tracker: {tracker}")

    # DeepSORT additionally needs an appearance feature extractor.
    extractor = trackers.DeepSORTFeatureExtractor.from_timm(
        "mobilenetv4_conv_small.e1200_r224_in1k", device="cpu"
    )
    return trackers.DeepSORTTracker(extractor, frame_rate=fps)
|
|
|
|
def update_tracker(tracker, detections, frame):
    """Feed ``detections`` to ``tracker`` using the call its class expects.

    The tracker is dispatched on its class name: ``ByteTrack`` uses
    ``update_with_detections``, ``DeepSORTTracker`` also receives the frame,
    and ``SORTTracker`` receives the detections only.

    Returns:
        Whatever the tracker's update call returns.

    Raises:
        ValueError: If the tracker's class name is not one of the three above.
    """
    kind = tracker.__class__.__name__

    if kind == "ByteTrack":
        return tracker.update_with_detections(detections)
    if kind == "DeepSORTTracker":
        return tracker.update(detections, frame)
    if kind == "SORTTracker":
        return tracker.update(detections)

    raise ValueError(f"Invalid tracker: {tracker}")
|
|
|
|
def process_video(
    video_path: str,
    checkpoint: str,
    tracker_algorithm: Optional[str] = None,
    classes: str = "all",
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
) -> str:
    """Detect (and optionally track) objects in a video and write the result.

    Args:
        video_path: Path of the input video file.
        checkpoint: Key into ``DETECTORS``; also selects
            ``configs/<checkpoint>.yaml``.
        tracker_algorithm: A ``TrackingAlgorithm`` display name, or a falsy
            value to annotate detections without tracking.
        classes: Comma-separated class names to keep. ``"all"`` (any case) or
            an empty string keeps every class.
        confidence_threshold: Confidence threshold forwarded to the detector.
        progress: Gradio progress tracker.

    Returns:
        Path of the annotated MP4 written below ``VIDEO_OUTPUT_DIR``.

    Raises:
        ValueError: If the path is not a file, the extension is unsupported,
            or no frame could be read.
    """
    if not video_path or not os.path.isfile(video_path):
        raise ValueError(f"Invalid video path: {video_path}")

    ext = os.path.splitext(video_path)[1].lower()
    if ext not in ALLOWED_VIDEO_EXTENSIONS:
        raise ValueError(f"Unsupported video format: {ext}, supported formats: {ALLOWED_VIDEO_EXTENSIONS}")

    video_info = sv.VideoInfo.from_video_path(video_path)
    # Subsample so that the processed stream runs at roughly 25 FPS.
    read_each_i_frame = max(1, video_info.fps // 25)
    target_fps = video_info.fps / read_each_i_frame
    target_width, target_height = get_target_size(video_info.height, video_info.width, 1080)

    n_frames_to_read = min(MAX_NUM_FRAMES, video_info.total_frames // read_each_i_frame)
    frames = read_video_k_frames(video_path, n_frames_to_read, read_each_i_frame)
    if not frames:
        # Without this check an empty frame list reaches the video writer.
        raise ValueError(f"No frames could be read from video: {video_path}")
    frames = [cv2.resize(frame, (target_width, target_height), interpolation=cv2.INTER_CUBIC) for frame in frames]

    # Colour by track id when tracking, otherwise by class.
    color_lookup = sv.ColorLookup.TRACK if tracker_algorithm else sv.ColorLookup.CLASS
    box_annotator = sv.BoxAnnotator(color, color_lookup=color_lookup, thickness=1)
    label_annotator = sv.LabelAnnotator(color, color_lookup=color_lookup, text_scale=0.5)

    # Normalise the class filter: ignore surrounding blanks and empty entries,
    # and treat "all" (any case) or an empty field as "no filter".
    requested = [name.strip() for name in (classes or "").split(",")]
    requested = [name for name in requested if name]
    if not requested or (len(requested) == 1 and requested[0].casefold() == "all"):
        classes_list = None
    else:
        classes_list = requested

    load_config(cfg, f'configs/{checkpoint}.yaml')
    # The frame list is passed as is (no stacked copy of every frame). The
    # frames are already at the output resolution, so no box rescaling is
    # requested -- assumes the pipeline reports boxes in pixels of the frame
    # it receives; TODO confirm.
    results, id2label = detect_objects(
        config=cfg.pipeline,
        onnx_path=DETECTORS[checkpoint],
        images=frames,
        confidence_threshold=confidence_threshold,
        classes=classes_list,
    )

    if tracker_algorithm:
        tracker = get_tracker(tracker_algorithm, target_fps)
        frame_iter = progress.tqdm(zip(frames, results), desc="Tracking objects", total=len(frames))
    else:
        tracker = None
        frame_iter = tqdm.tqdm(zip(frames, results), desc="Annotating frames", total=len(frames))

    annotated_frames = []
    for frame, result in frame_iter:
        detections = sv.Detections.from_transformers(result, id2label=id2label)
        detections = detections.with_nms(threshold=0.95, class_agnostic=True)

        # labels=None lets the label annotator fall back to its default text.
        labels = None
        if tracker is not None:
            detections = update_tracker(tracker, detections, frame)
            labels = [
                f"#{tracker_id} {id2label[class_id]}"
                for class_id, tracker_id in zip(detections.class_id, detections.tracker_id)
            ]

        annotated_frame = box_annotator.annotate(scene=frame, detections=detections)
        annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
        annotated_frames.append(annotated_frame)

    output_filename = os.path.join(VIDEO_OUTPUT_DIR, f"output_{uuid.uuid4()}.mp4")
    iio.imwrite(output_filename, annotated_frames, fps=target_fps, codec="h264")
    return output_filename
|
|
|
|
|
|
def create_image_inputs() -> List[gr.components.Component]:
    """Create the input widgets of the Image tab.

    Returns:
        ``[image upload, model dropdown, confidence slider]`` in that order.
    """
    uploader = gr.Image(
        label="Upload Image",
        type="pil",
        sources=["upload", "webcam"],
        interactive=True,
        elem_classes="input-component",
    )
    model_picker = gr.Dropdown(
        choices=list(DETECTORS.keys()),
        label="Select Model Checkpoint",
        value=DEFAULT_DETECTOR,
        elem_classes="input-component",
    )
    threshold = gr.Slider(
        minimum=0.1,
        maximum=1.0,
        value=DEFAULT_CONFIDENCE_THRESHOLD,
        step=0.1,
        label="Confidence Threshold",
        elem_classes="input-component",
    )
    return [uploader, model_picker, threshold]
|
|
|
|
def create_video_inputs() -> List[gr.components.Component]:
    """Create the input widgets of the Video tab.

    Returns:
        ``[video upload, model dropdown, tracker dropdown, class filter,
        confidence slider]`` in that order.
    """
    uploader = gr.Video(
        label="Upload Video",
        sources=["upload"],
        interactive=True,
        format="mp4",
        elem_classes="input-component",
    )
    model_picker = gr.Dropdown(
        choices=list(DETECTORS.keys()),
        label="Select Model Checkpoint",
        value=DEFAULT_DETECTOR,
        elem_classes="input-component",
    )
    tracker_picker = gr.Dropdown(
        choices=TRACKERS,
        label="Select Tracker (Optional)",
        value=None,
        elem_classes="input-component",
    )
    class_filter = gr.TextArea(
        label="Specify Class Names to Detect (comma separated)",
        value="all",
        lines=1,
        elem_classes="input-component",
    )
    threshold = gr.Slider(
        minimum=0.1,
        maximum=1.0,
        value=DEFAULT_CONFIDENCE_THRESHOLD,
        step=0.1,
        label="Confidence Threshold",
        elem_classes="input-component",
    )
    return [uploader, model_picker, tracker_picker, class_filter, threshold]
|
|
|
|
def create_button_row() -> List[gr.Button]:
    """Create the action buttons shared by the Image and Video tabs.

    Returns:
        ``[detect button, clear button]`` in that order.
    """
    # Plain literals: the labels contain no placeholders, so the f-string
    # prefix was redundant.
    return [
        gr.Button(
            "Detect Objects", variant="primary", elem_classes="action-button"
        ),
        gr.Button("Clear", variant="secondary", elem_classes="action-button"),
    ]
|
|
|
|
| |
# Build the Gradio UI: a header, an Image tab and a Video tab, followed by the
# event wiring for the clear and detect buttons.
# NOTE(review): the nesting of the layout contexts below follows the usual
# Row / Column / Group arrangement; confirm against the rendered UI.
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.Markdown(
        """
# Pipeline for Aerial Search and Rescue Demo
Experience state-of-the-art object detection with Open Source [WALDO30](https://huggingface.co/StephanST/WALDO30) models.
- **Image** and **Video** modes are supported.
- Select a model and adjust the confidence threshold to see detections!
- On video mode, you can enable tracking powered by [Supervision](https://github.com/roboflow/supervision) and [Trackers](https://github.com/roboflow/trackers) from Roboflow.

For more details and source code, visit the [PiSAR](https://github.com/eadali/PiSAR).
""",
        elem_classes="header-text",
    )

    with gr.Tabs():
        # ----- Image tab: inputs on the left, annotated result on the right -----
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column(scale=1, min_width=300):
                    with gr.Group():
                        (
                            image_input,
                            image_model_checkpoint,
                            image_confidence_threshold,
                        ) = create_image_inputs()
                    image_detect_button, image_clear_button = create_button_row()
                with gr.Column(scale=2):
                    image_output = gr.AnnotatedImage(
                        label="Detection Results",
                        show_label=True,
                        color_map=None,
                        elem_classes="output-component",
                    )
            # Example rows follow process_image's parameter order:
            # (model, image, confidence_threshold).
            gr.Examples(
                examples=[
                    [
                        DEFAULT_DETECTOR,
                        example["path"],
                        DEFAULT_CONFIDENCE_THRESHOLD,
                    ]
                    for example in IMAGE_EXAMPLES
                ],
                inputs=[
                    image_model_checkpoint,
                    image_input,
                    image_confidence_threshold,
                ],
                outputs=[image_output],
                fn=process_image,
                label="Select an image example to populate inputs",
                cache_examples=True,
                cache_mode="lazy",
            )

        # ----- Video tab: inputs on the left, rendered video on the right -----
        with gr.Tab("Video"):
            gr.Markdown(
                f"The input video will be processed in ~25 FPS (up to {MAX_NUM_FRAMES} frames in result)."
            )
            with gr.Row():
                with gr.Column(scale=1, min_width=300):
                    with gr.Group():
                        video_input, video_checkpoint, video_tracker, video_classes, video_confidence_threshold = create_video_inputs()
                    video_detect_button, video_clear_button = create_button_row()
                with gr.Column(scale=2):
                    video_output = gr.Video(
                        label="Detection Results",
                        format="mp4",
                        elem_classes="output-component",
                    )

            # Example rows follow process_video's parameter order.
            gr.Examples(
                examples=[
                    [example["path"], DEFAULT_DETECTOR, example["tracker"], example["classes"], DEFAULT_CONFIDENCE_THRESHOLD]
                    for example in VIDEO_EXAMPLES
                ],
                inputs=[video_input, video_checkpoint, video_tracker, video_classes, video_confidence_threshold],
                outputs=[video_output],
                fn=process_video,
                cache_examples=False,
                label="Select a video example to populate inputs",
            )

    # Image clear button: reset every image input to its default and empty the output.
    image_clear_button.click(
        fn=lambda: (
            None,
            DEFAULT_DETECTOR,
            DEFAULT_CONFIDENCE_THRESHOLD,
            None,
        ),
        outputs=[
            image_input,
            image_model_checkpoint,
            image_confidence_threshold,
            image_output,
        ],
    )

    # Video clear button: reset every video input to its default and empty the output.
    video_clear_button.click(
        fn=lambda: (
            None,
            DEFAULT_DETECTOR,
            None,
            "all",
            DEFAULT_CONFIDENCE_THRESHOLD,
            None,
        ),
        outputs=[
            video_input,
            video_checkpoint,
            video_tracker,
            video_classes,
            video_confidence_threshold,
            video_output,
        ],
    )

    # Image detect button: run process_image on the current inputs.
    image_detect_button.click(
        fn=process_image,
        inputs=[
            image_model_checkpoint,
            image_input,
            image_confidence_threshold,
        ],
        outputs=[image_output],
    )

    # Video detect button: run process_video on the current inputs.
    video_detect_button.click(
        fn=process_video,
        inputs=[video_input, video_checkpoint, video_tracker, video_classes, video_confidence_threshold],
        outputs=[video_output],
    )
|
|
# When run as a script, enable a request queue capped at 20 entries and launch the app.
if __name__ == "__main__":
    demo.queue(max_size=20).launch()