Source code for tracklab.main

import os
from pathlib import Path

import rich.logging
import torch
import hydra
import warnings
import logging

from tracklab.datastruct import TrackerState
from tracklab.pipeline import Pipeline
from tracklab.utils import monkeypatch_hydra, progress, wandb

from hydra.utils import instantiate
from omegaconf import OmegaConf


os.environ["HYDRA_FULL_ERROR"] = "1"
log = logging.getLogger(__name__)

warnings.filterwarnings("ignore")


[docs] @hydra.main(version_base=None, config_path="pkg://tracklab.configs", config_name="config") def main(cfg): device = init_environment(cfg) # Instantiate all modules tracking_dataset = instantiate(cfg.dataset) evaluator = instantiate(cfg.eval, tracking_dataset=tracking_dataset) log.info(f"{tracking_dataset}") modules = [] if cfg.pipeline is not None: for name in cfg.pipeline: module = cfg.modules[name] inst_module = instantiate(module, device=device, tracking_dataset=tracking_dataset) modules.append(inst_module) pipeline = Pipeline(models=modules) # Train tracking modules for module in modules: if module.training_enabled: module.train(tracking_dataset, pipeline, evaluator, OmegaConf.to_container(cfg.dataset, resolve=True)) # Test tracking if cfg.test_tracking: log.info(f"Starting tracking operation on {cfg.dataset.eval_set} set.") # Init tracker state and tracking engine tracking_set = tracking_dataset.sets[cfg.dataset.eval_set] tracker_state = TrackerState(tracking_set, pipeline=pipeline, **cfg.state) tracking_engine = instantiate( cfg.engine, modules=pipeline, tracker_state=tracker_state, ) # Run tracking and visualization tracking_engine.track_dataset() # Evaluation evaluate(cfg, evaluator, tracker_state) # Save tracker state if tracker_state.save_file is not None: log.info(f"Saved state at : {tracker_state.save_file.resolve()}") close_environment() return 0
[docs] def set_sharing_strategy(): torch.multiprocessing.set_sharing_strategy( "file_system" )
[docs] def init_environment(cfg): # For Hydra and Slurm compatibility progress.use_rich = cfg.use_rich set_sharing_strategy() # Do not touch if torch.backends.mps.is_available(): device = "mps" elif torch.cuda.is_available(): device = "cuda" else: device = "cpu" if cfg.use_rich: for handler in log.root.handlers: if type(handler) is logging.StreamHandler: handler.setLevel(logging.ERROR) log.root.addHandler(rich.logging.RichHandler(level=logging.INFO)) else: # TODO : Fix for mmcv fix. This should be done in a nicer way for handler in log.root.handlers: if type(handler) is logging.StreamHandler: handler.setLevel(logging.INFO) wandb.init(cfg) log.info(f"Run directory: {Path().absolute()}") log.info(f"Using device: '{device}'.") if cfg.print_config: log.info(OmegaConf.to_yaml(cfg)) return device
[docs] def close_environment(): wandb.finish()
[docs] def evaluate(cfg, evaluator, tracker_state): if cfg.get("eval_tracking", True): # and cfg.dataset.nframes == -1: log.info("Starting evaluation.") evaluator.run(tracker_state) elif not cfg.get("eval_tracking", True): log.warning("Skipping evaluation because 'eval_tracking' was set to False.") else: log.warning( "Skipping evaluation because only part of video was tracked (i.e. 'cfg.dataset.nframes' was not set " "to -1)" )
if __name__ == "__main__": main()