"""BoT-SORT tracking module for tracklab (``tracklab.wrappers.track.bot_sort_api``)."""

import torch
import numpy as np
import pandas as pd
from pathlib import Path

from tracklab.pipeline import ImageLevelModule
from tracklab.utils.coordinates import ltrb_to_ltwh
import bot_sort.bot_sort as bot_sort

import logging

from tracklab.utils.cv2 import cv2_load_image

log = logging.getLogger(__name__)

class BotSORT(ImageLevelModule):
    """Image-level tracking module wrapping the ``bot_sort.BoTSORT`` tracker.

    Consumes the per-image detections in ``input_columns`` and produces the
    ``output_columns`` (track id, tracked box and tracked confidence) for the
    detections the tracker returns. Images are processed one at a time
    (``batch_size=1``).
    """

    input_columns = [
        "bbox_ltwh",
        "bbox_conf",
        "category_id",
    ]
    output_columns = ["track_id", "track_bbox_ltwh", "track_bbox_conf"]

    def __init__(self, cfg, device, **kwargs):
        """Store configuration and build a fresh tracker.

        Args:
            cfg: configuration object; this class reads ``model_weights``,
                ``fp16``, ``hyperparams`` and ``min_confidence`` from it.
            device: device forwarded to ``bot_sort.BoTSORT``.
            **kwargs: accepted for interface compatibility; ignored.
        """
        super().__init__(batch_size=1)
        self.cfg = cfg
        self.device = device
        self.reset()

    def reset(self):
        """Reset the tracker state to start tracking in a new video."""
        self.model = bot_sort.BoTSORT(
            Path(self.cfg.model_weights),
            self.device,
            self.cfg.fp16,
            **self.cfg.hyperparams,
        )

    @torch.no_grad()
    def preprocess(self, image, detections: pd.DataFrame, metadata: pd.Series):
        """Pack one image's detections into an ``N x 7`` array for the tracker.

        Each row is ``[l, t, r, b, conf, class, tracklab_id]``, where
        ``tracklab_id`` is the detection's DataFrame index. That index is
        carried through the tracker so ``process`` can map results back to
        their source detections.

        Args:
            image: unused here; part of the module interface.
            detections: detections for this image. Rows are read through the
                ``bbox`` accessor (``ltrb()``, ``conf()``) and ``category_id``.
            metadata: unused here; part of the module interface.

        Returns:
            ``{"input": ndarray}`` with one row per detection, or
            ``{"input": []}`` when there are no detections.
        """
        if len(detections) == 0:
            return {"input": []}
        processed_detections = []
        for det_id, detection in detections.iterrows():
            ltrb = detection.bbox.ltrb()
            conf = detection.bbox.conf()
            cls = detection.category_id
            # The DataFrame index must survive the tracker round-trip:
            # ``process`` checks result column 7 against ``detections.index``.
            tracklab_id = int(det_id)
            processed_detections.append(
                np.array([*ltrb, conf, cls, tracklab_id])
            )
        return {"input": np.stack(processed_detections)}

    @torch.no_grad()
    def process(self, batch, detections: pd.DataFrame, metadatas: pd.DataFrame):
        """Run one tracker update for the current image.

        Args:
            batch: collated output of ``preprocess``; ``batch["input"][0]`` is
                the ``N x 7`` detection array for this image.
            detections: detections for this image (used to validate indexes).
            metadatas: image metadata. The frame is loaded from the first
                row's ``file_path``.

        Returns:
            A DataFrame indexed by the original detection index, with columns
            ``track_bbox_ltwh``, ``track_bbox_conf`` and ``track_id``. Returns
            an empty list when there are no detections or the tracker returns
            nothing.

        Raises:
            AssertionError: if the tracker returns an index that is not in
                ``detections.index``.
        """
        if len(detections) == 0:
            return []
        inputs = batch["input"][0]  # Nx7 [l,t,r,b,conf,class,tracklab_id]
        # Drop detections at or below the configured confidence threshold
        # before they reach the tracker.
        inputs = inputs[inputs[:, 4] > self.cfg.min_confidence]
        # The tracker's update also takes the decoded frame; what it uses the
        # pixels for is internal to bot_sort (not visible here).
        image = cv2_load_image(metadatas["file_path"].values[0])
        results = self.model.update(inputs, image)
        results = np.asarray(results)  # N'x8 [l,t,r,b,track_id,class,conf,idx]
        if not results.size:
            return []

        track_bbox_ltwh = [ltrb_to_ltwh(x) for x in results[:, :4]]
        track_bbox_conf = list(results[:, 6])
        track_ids = list(results[:, 4])
        idxs = list(results[:, 7].astype(int))
        # Explicit raise instead of ``assert`` so the check also runs under
        # ``python -O``; the exception type and message are unchanged.
        if not set(idxs).issubset(detections.index):
            raise AssertionError(
                "Mismatch of indexes during the tracking. The results should match the detections."
            )

        results = pd.DataFrame(
            {
                "track_bbox_ltwh": track_bbox_ltwh,
                "track_bbox_conf": track_bbox_conf,
                "track_id": track_ids,
                "idxs": idxs,
            }
        )
        results.set_index("idxs", inplace=True, drop=True)
        return results