From c664aad444375339cf5003de5195903b21a2ef77 Mon Sep 17 00:00:00 2001
From: Illia Oleksiienko <io@ece.au.dk>
Date: Sun, 16 Oct 2022 00:29:45 +0000
Subject: [PATCH] Add tracking 3d demo inside detection demo

---
 .../demos/voxel_object_detection_3d/README.md              | 4 ++--
 .../demos/voxel_object_detection_3d/demo.py                | 7 ++++++-
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/projects/python/perception/object_detection_3d/demos/voxel_object_detection_3d/README.md b/projects/python/perception/object_detection_3d/demos/voxel_object_detection_3d/README.md
index 0a8303a1be..b6b48042c8 100644
--- a/projects/python/perception/object_detection_3d/demos/voxel_object_detection_3d/README.md
+++ b/projects/python/perception/object_detection_3d/demos/voxel_object_detection_3d/README.md
@@ -19,12 +19,12 @@ pip install -e .
 ## Running the example
 Car 3D Object Detection using [TANet](https://arxiv.org/abs/1912.05163) from [KITTI](http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d)-like dataset
 ```bash
-python3 demo.py --ip=0.0.0.0 --port=2605 --algorithm=voxel --model_name=tanet_car_xyres_16 --source=disk --data_path=/data/sets/kitti_second/training/velodyne --model_config=configs/tanet_car_xyres_16.proto
+python3 demo.py --ip=0.0.0.0 --port=2605 --algorithm=voxel --model_name=tanet_car_xyres_16 --source=disk --data_path=/data/sets/kitti_tracking/training/velodyne/0000--model_config=configs/tanet_car_xyres_16.proto
 ```
 
 Car 3D Object Detection using [PointPillars](https://arxiv.org/abs/1812.05784) from [KITTI](http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d)-like dataset
 ```bash
-python3 demo.py --ip=0.0.0.0 --port=2605 --algorithm=voxel --model_name=pointpillars_car_xyres_16 --source=disk --data_path=/data/sets/kitti_second/training/velodyne --model_config=configs/tanet_car_xyres_16.proto
+python3 demo.py --ip=0.0.0.0 --port=2605 --algorithm=voxel --model_name=pointpillars_car_xyres_16 --source=disk --data_path=/data/sets/kitti_tracking/training/velodyne/0000 --model_config=configs/pointpillars_car_xyres_16.proto
 ```
 
 3D Object Detection using a specially trained model X for O3M Lidar
diff --git a/projects/python/perception/object_detection_3d/demos/voxel_object_detection_3d/demo.py b/projects/python/perception/object_detection_3d/demos/voxel_object_detection_3d/demo.py
index d113b26a05..2ce25d7733 100644
--- a/projects/python/perception/object_detection_3d/demos/voxel_object_detection_3d/demo.py
+++ b/projects/python/perception/object_detection_3d/demos/voxel_object_detection_3d/demo.py
@@ -23,6 +23,7 @@ from flask import Flask, Response, render_template, request
 
 # OpenDR imports
 from opendr.perception.object_detection_3d import VoxelObjectDetection3DLearner
+from opendr.perception.object_tracking_3d import ObjectTracking3DAb3dmotLearner
 from data_generators import (
     lidar_point_cloud_generator,
     disk_point_cloud_generator,
@@ -162,6 +163,7 @@ def voxel_object_detection_3d(config_path, model_name=None):
 
         # Init model
         detection_learner = VoxelObjectDetection3DLearner(config_path)
+        tracking_learner = ObjectTracking3DAb3dmotLearner()
 
         if model_name is not None and not os.path.exists(
             "./models/" + model_name
@@ -172,6 +174,7 @@ def voxel_object_detection_3d(config_path, model_name=None):
 
     else:
         detection_learner = None
+        tracking_learner = None
 
     def process_key(key):
 
@@ -284,8 +287,10 @@ def voxel_object_detection_3d(config_path, model_name=None):
 
             if predict:
                 predictions = detection_learner.infer(point_cloud)
+                tracking_predictions = tracking_learner.infer(predictions)
             else:
                 predictions = []
+                tracking_predictions = []
 
             if len(predictions) > 0:
                 print(
@@ -296,7 +301,7 @@ def voxel_object_detection_3d(config_path, model_name=None):
             t = time.time()
 
             frame_bev_2 = draw_point_cloud_bev(
-                point_cloud.data, predictions, scale, xs, ys
+                point_cloud.data, tracking_predictions, scale, xs, ys
             )
             frame_proj_2 = draw_point_cloud_projected_numpy(
                 point_cloud.data,
-- 
GitLab