diff --git a/run/eval_tracking_3d.py b/run/eval_tracking_3d.py index 795eb282ca27b92ee302d54094ce6c802f291e99..a36f8e6c3d9e4cad6c0074d229d316d9d5f0d242 100644 --- a/run/eval_tracking_3d.py +++ b/run/eval_tracking_3d.py @@ -395,32 +395,39 @@ def test_vnn_tapp( ) -def all_save_detection_inference(start=0): +def all_save_detection_inference(start=0, index=0, devices=1): models = [ # add models - ("tanet_car", test_tanet), - ("tapp_car", test_tapp), - ("pointpillars_car", test_pointpillars), + # ("tanet_car", test_tanet), + # ("tapp_car", test_tapp), + # # ("pointpillars_car", test_pointpillars), *[("vnn_tanet_car_s" + str(i), test_vnn_tanet) for i in range(1, 3 + 1)], + *[("badvnn_tanet_car_s" + str(i), test_vnn_tanet) for i in range(1, 3 + 1)], *[("vnn_tapp_car_s" + str(i), test_vnn_tapp) for i in range(1, 3 + 1)], - *[ - ("vnn_pointpillars_car_s" + str(i), test_vnn_pointpillars) - for i in range(1, 7 + 1) - ], - *[ - ("vnna_pointpillars_car_s" + str(i), test_vnna_pointpillars) - for i in range(1, 4 + 1) - ], + # *[ + # ("vnn_pointpillars_car_s" + str(i), test_vnn_pointpillars) + # for i in range(1, 7 + 1) + # ], + # *[ + # ("vnna_pointpillars_car_s" + str(i), test_vnna_pointpillars) + # for i in range(1, 4 + 1) + # ], ] - for i, (model_name, test_func) in enumerate(models): - print(":::", i + 1, "/", len(models)) + i = index + + while i < len(models): + + print(":::", i, "/", len(models)) if i < start: + i += devices continue + model_name, test_func = models[i] test_func(name=model_name, mode="save_detection_inference") + i += devices if __name__ == "__main__": fire.Fire()