Source code for tracklab.wrappers.bbox_detector.yolo_ultralytics_api

import logging
from typing import Any

import torch
import pandas as pd

from ultralytics import YOLO

from tracklab.pipeline.imagelevel_module import ImageLevelModule
from tracklab.utils.coordinates import ltrb_to_ltwh

log = logging.getLogger(__name__)


def collate_fn(batch):
    """Collate ``(index, sample)`` pairs into ``(indices, (images, shapes))``.

    Each ``sample`` must be a mapping with ``"image"`` and ``"shape"`` keys.
    Images are gathered into a plain list and are not stacked into a tensor.
    """
    indices = [pair[0] for pair in batch]
    samples = [sample for _, sample in batch]
    images = [sample["image"] for sample in samples]
    shapes = [sample["shape"] for sample in samples]
    return indices, (images, shapes)
class YOLOUltralytics(ImageLevelModule):
    """Image-level detector backed by an Ultralytics YOLO model.

    Every image of a batch is run through the model. Boxes with class id ``0``
    whose confidence is at least ``cfg.min_confidence`` are emitted as
    detections with ``category_id=1``.
    """

    collate_fn = collate_fn
    input_columns = []
    output_columns = [
        "image_id",
        "video_id",
        "category_id",
        "bbox_ltwh",
        "bbox_conf",
    ]

    def __init__(self, cfg, device, batch_size, **kwargs):
        """Load the YOLO checkpoint and move it to ``device``.

        Args:
            cfg: Config object; this class reads ``cfg.path_to_checkpoint``
                and ``cfg.min_confidence``.
            device: Device passed to ``YOLO.to``.
            batch_size: Forwarded to ``ImageLevelModule.__init__``.
            **kwargs: Accepted and ignored.
        """
        super().__init__(batch_size)
        self.cfg = cfg
        self.device = device
        self.model = YOLO(cfg.path_to_checkpoint)
        self.model.to(device)
        # Running counter used as the index (``name``) of each emitted
        # detection Series; it is never reset, so it keeps growing across
        # calls to ``process``.
        self.id = 0

    @torch.no_grad()
    def preprocess(self, image, detections, metadata: pd.Series):
        """Wrap a single image together with its ``(shape[1], shape[0])`` size.

        ``detections`` and ``metadata`` are not used. Assumes ``image`` is laid
        out as ``(height, width, ...)``, which makes ``shape`` equal to
        ``(width, height)`` (TODO confirm against the dataset loader).
        """
        return {
            "image": image,
            "shape": (image.shape[1], image.shape[0]),
        }

    @torch.no_grad()
    def process(self, batch: Any, detections: pd.DataFrame, metadatas: pd.DataFrame):
        """Run YOLO on a batch and return the kept boxes as detection rows.

        Args:
            batch: ``(images, shapes)`` pair, presumably the second element
                returned by ``collate_fn``.
            detections: Not used by this module.
            metadatas: One row per image, in the same order as ``images``.
                The row index is used as ``image_id``, and the ``video_id``
                column is copied into each detection.

        Returns:
            A list of ``pd.Series`` detections, each named with a fresh
            ``self.id``.
        """
        images, shapes = batch
        min_confidence = self.cfg.min_confidence
        # Ultralytics applies its own confidence threshold during prediction
        # (``conf``, default 0.25). Forward ours so that a ``min_confidence``
        # below that default is not silently overridden by the predictor.
        results_by_image = self.model(images, conf=min_confidence, verbose=False)
        # Use a separate name so the ``detections`` argument is not shadowed.
        new_detections = []
        for results, shape, (_, metadata) in zip(
            results_by_image, shapes, metadatas.iterrows()
        ):
            for bbox in results.boxes.cpu().numpy():
                # check for `person` class
                if bbox.cls == 0 and bbox.conf >= min_confidence:
                    new_detections.append(
                        pd.Series(
                            dict(
                                image_id=metadata.name,
                                bbox_ltwh=ltrb_to_ltwh(bbox.xyxy[0], shape),
                                bbox_conf=bbox.conf[0],
                                video_id=metadata.video_id,
                                category_id=1,  # `person` class in posetrack
                            ),
                            name=self.id,
                        )
                    )
                    self.id += 1
        return new_detections