diff --git a/run/eval_tracking_3d.py b/run/eval_tracking_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..3275444c004de4efc0e80da577ffb4bf93db1a66
--- /dev/null
+++ b/run/eval_tracking_3d.py
@@ -0,0 +1,225 @@
+import os
+import tqdm
+import fire
+from opendr.perception.object_detection_3d import VoxelObjectDetection3DLearner
+from opendr.perception.object_tracking_3d.datasets.kitti_tracking import LabeledTrackingPointCloudsDatasetIterator
+
+config_roots = {
+    "pointpillars": os.path.join(
+        ".",
+        "src",
+        "opendr",
+        "perception",
+        "object_detection_3d",
+        "voxel_object_detection_3d",
+        "second_detector",
+        "configs",
+        "pointpillars",
+        "car",
+    ),
+    "tanet": os.path.join(
+        ".",
+        "src",
+        "opendr",
+        "perception",
+        "object_detection_3d",
+        "voxel_object_detection_3d",
+        "second_detector",
+        "configs",
+        "tanet",
+        "car",
+    ),
+}
+temp_dir = "./run/models"
+subsets_path = os.path.join(
+    ".",
+    "src",
+    "opendr",
+    "perception",
+    "object_detection_3d",
+    "datasets",
+    "kitti_subsets",
+)
+
+dataset_tracking_path = "/data/sets/kitti_tracking"
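+# Expected KITTI tracking layout under dataset_tracking_path (as used by the
+# loaders below): training/velodyne/<track_id>/, training/label_02/<track_id>.txt
+# and training/calib/<track_id>.txt.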
+datasets = {}
+
+all_track_ids = ["0000"]
+
+for track_id in all_track_ids:
+    datasets[track_id] = LabeledTrackingPointCloudsDatasetIterator(
+        dataset_tracking_path + "/training/velodyne/" + track_id,
+        dataset_tracking_path + "/training/label_02/" + track_id + ".txt",
+        dataset_tracking_path + "/training/calib/" + track_id + ".txt",
+    )
+
+
+def save_detection_inference(
+    model_type,
+    device="cuda:0",
+    name="pointpillars_car",
+    samples_list=[1],
+    config="xyres_16.proto",
+    eval_suffix="classic",
+):
+
+    config = os.path.join(config_roots[model_type], config,)
+    model_path = os.path.join(temp_dir, name)
+
+    learner = VoxelObjectDetection3DLearner(
+        model_config_path=config, device=device, checkpoint_after_iter=1000, return_uncertainty=True
+    )
+    learner.load(model_path)
+
+    for samples in samples_list:
+        print("samples =", samples)
+
+        for track_id in all_track_ids:
+
+            dataset = datasets[track_id]
+
+            os.makedirs(os.path.join(
+                model_path, "tracking_inference_detections", eval_suffix, "samples_" + str(samples)
+            ), exist_ok=True)
+
+            with open(
+                os.path.join(
+                    model_path,
+                    "tracking_inference_detections",
+                    eval_suffix,
+                    "samples_" + str(samples),
+                    track_id + ".txt"
+                ), "w"
+            ) as f:
+                for frame, (point_cloud, _) in tqdm.tqdm(enumerate(dataset), total=len(dataset)):
+                    output = learner.infer(point_cloud, samples=samples)
+
+                    result = "\n".join(box.to_kitti_tracking_string(frame) for box in output)
+
+                    if len(output) > 0:
+                        result += "\n"
+
+                    f.write(result)
+
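+# Usage sketch (assuming a trained "pointpillars_car" model under ./run/models and
+# the fire entry point at the bottom of this file):
+#   python run/eval_tracking_3d.py save_detection_inference pointpillars \
+#       --name pointpillars_car --samples_list "[1]"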
+
+
+def test_model(
+    model_type,
+    device="cuda:0",
+    name="pointpillars_car",
+    samples_list=[1],
+    config="xyres_16.proto",
+    eval_suffix="classic",
+):
+
+    return save_detection_inference(
+        model_type=model_type,
+        device=device,
+        name=name,
+        samples_list=samples_list,
+        config=config,
+        eval_suffix=eval_suffix,
+    )
+
+
+def test_pointpillars(
+    device="cuda:0",
+    name="pointpillars_car",
+    samples_list=[1],
+    config="xyres_16.proto",
+    eval_suffix="classic",
+):
+    return test_model(
+        "pointpillars",
+        device=device,
+        name=name,
+        samples_list=samples_list,
+        config=config,
+        eval_suffix=eval_suffix,
+    )
+
+
+def test_tanet(
+    device="cuda:0",
+    name="tanet_car",
+    samples_list=[1],
+    config="xyres_16.proto",
+    eval_suffix="classic",
+):
+    return test_model(
+        "tanet",
+        device=device,
+        name=name,
+        samples_list=samples_list,
+        config=config,
+        eval_suffix=eval_suffix,
+    )
+
+
+def test_vnn_pointpillars(
+    device="cuda:0",
+    name="vnn_pointpillars_car",
+    samples_list=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+    config="vnn_xyres_16.proto",
+    eval_suffix="vnn",
+):
+    return test_pointpillars(
+        device=device,
+        name=name,
+        samples_list=samples_list,
+        config=config,
+        eval_suffix=eval_suffix,
+    )
+
+def test_vnna_pointpillars(
+    device="cuda:0",
+    name="vnna_pointpillars_car",
+    samples_list=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+    config="vnna_xyres_16.proto",
+    eval_suffix="vnn",
+):
+    return test_pointpillars(
+        device=device,
+        name=name,
+        samples_list=samples_list,
+        config=config,
+        eval_suffix=eval_suffix,
+    )
+
+
+def test_vnn_tanet(
+    device="cuda:0",
+    name="vnn_tanet_car",
+    samples_list=[1, 2, 3, 4],
+    config="vnn_xyres_16.proto",
+    eval_suffix="vnn",
+):
+    return test_model(
+        "tanet",
+        device=device,
+        name=name,
+        samples_list=samples_list,
+        config=config,
+        eval_suffix=eval_suffix,
+    )
+
+
+if __name__ == "__main__":
+    fire.Fire()
diff --git a/run/train_3d.py b/run/train_3d.py
index 0d5c4f5b82ab49add42af023bebff516e4aad58c..4be41991e3143865d206025c61e79473cedd817a 100644
--- a/run/train_3d.py
+++ b/run/train_3d.py
@@ -64,6 +64,7 @@ def save_model(
         device=device,
         checkpoint_after_iter=1000,
         checkpoint_load_iter=load,
+        return_uncertainty=True
     )
 
     learner.save(model_path)
@@ -95,6 +96,7 @@ def train_model(
             device=device,
             checkpoint_after_iter=1000,
             checkpoint_load_iter=load,
+            return_uncertainty=True
         )
 
         learner.fit(
@@ -137,6 +139,23 @@ def train_tanet(
     )
 
 
+def train_tapp(
+    device="cuda:0",
+    load=0,
+    name="tapp_car",
+    config="tapp_16.proto",
+    samples_list=None,
+):
+    return train_model(
+        "tanet",
+        device=device,
+        load=load,
+        name=name,
+        samples_list=samples_list,
+        config=config,
+    )
+
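+# Usage sketch (assuming the fire entry point at the bottom of this file):
+#   python run/train_3d.py train_tapp --device cuda:0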
+
 def train_vnn_pointpillars(
     device="cuda:0",
     load=0,
@@ -178,6 +197,23 @@ def train_vnn_tanet(
     )
 
 
+def train_vnn_tapp(
+    device="cuda:0",
+    load=0,
+    name="vnn_tapp_car",
+    config="vnn_tapp_16.proto",
+    samples_list=[2, 1],
+):
+    return train_model(
+        "tanet",
+        device=device,
+        load=load,
+        name=name,
+        samples_list=samples_list,
+        config=config,
+    )
+
+
 def test_model(
     model_type,
     device="cuda:0",
@@ -191,7 +227,7 @@ def test_model(
     model_path = os.path.join(temp_dir, name)
 
     learner = VoxelObjectDetection3DLearner(
-        model_config_path=config, device=device, checkpoint_after_iter=1000
+        model_config_path=config, device=device, checkpoint_after_iter=1000, return_uncertainty=True
     )
     learner.load(model_path)
 
@@ -246,6 +282,23 @@ def test_tanet(
     )
 
 
+def test_tapp(
+    device="cuda:0",
+    name="tapp_car",
+    samples_list=[1],
+    config="tapp_16.proto",
+    eval_suffix="classic",
+):
+    return test_model(
+        "tanet",
+        device=device,
+        name=name,
+        samples_list=samples_list,
+        config=config,
+        eval_suffix=eval_suffix,
+    )
+
+
 def test_vnn_pointpillars(
     device="cuda:0",
     name="vnn_pointpillars_car",
@@ -294,5 +347,22 @@ def test_vnn_tanet(
     )
 
 
+def test_vnn_tapp(
+    device="cuda:0",
+    name="vnn_tapp_car",
+    samples_list=[1, 2, 3, 4],
+    config="vnn_tapp_16.proto",
+    eval_suffix="vnn",
+):
+    return test_model(
+        "tanet",
+        device=device,
+        name=name,
+        samples_list=samples_list,
+        config=config,
+        eval_suffix=eval_suffix,
+    )
+
+
 if __name__ == "__main__":
     fire.Fire()
diff --git a/src/opendr/engine/target.py b/src/opendr/engine/target.py
index 6b26fca636d2929db926550d0b2029233e5f7248..95b1937ccb670f63b2c65e936c3a90b827131d89 100644
--- a/src/opendr/engine/target.py
+++ b/src/opendr/engine/target.py
@@ -675,18 +675,18 @@ class BoundingBox3D(Target):
         return result
 
     def to_kitti_tracking_string(self, frame):
-        result = " ".join([
+        result = " ".join(map(str, [
             frame,
             self.name,
-            self.truncated,
-            self.occluded,
-            self.alpha,
-            *self.bbox2d,  # x y w h
+            float(self.truncated),
+            float(self.occluded),
+            float(self.alpha),
+            *([0, 0, 0, 0] if self.bbox2d is None else self.bbox2d),  # x y w h
+            *self.dimensions,  # w h l
             *self.location,  # x y z
             self.rotation_y,
-            self.score,
-        ])
+            float(self.confidence),
+        ]))
 
         return result
 
@@ -1111,3 +1111,296 @@ class Heatmap(Target):
         :rtype: str
         """
         return str(self.data)
+
+
+
+class UncertaintyBoundingBox3D(BoundingBox3D):
+    """
+    This target is used for 3D Object Detection and Tracking with uncertainty.
+    A bounding box is described by its location (x, y, z), dimensions (w, h, d) and rotation (along vertical y axis).
+    Additional fields are used to describe confidence (score), 2D projection of the box on camera image (bbox2d),
+    truncation (truncated) and occlusion (occluded) levels, the name of an object (name) and
+    observation angle of an object (alpha).
+    UncertaintyBoundingBox3D provides uncertainty values for all regressed fields (this excludes name).
+    """
+
+    def __init__(
+            self,
+            name,
+            truncated,
+            occluded,
+            alpha,
+            bbox2d,
+            dimensions,
+            location,
+            rotation_y,
+            variance_truncated,
+            variance_occluded,
+            variance_alpha,
+            variance_bbox2d,
+            variance_dimensions,
+            variance_location,
+            variance_rotation_y,
+            score=0,
+            variance_score=0,
+    ):
+        super().__init__(
+            name,
+            truncated,
+            occluded,
+            alpha,
+            bbox2d,
+            dimensions,
+            location,
+            rotation_y,
+            score=score,
+        )
+        self.data.update({
+            "variance_truncated": variance_truncated,
+            "variance_occluded": variance_occluded,
+            "variance_alpha": variance_alpha,
+            "variance_bbox2d": variance_bbox2d,
+            "variance_dimensions": variance_dimensions,
+            "variance_location": variance_location,
+            "variance_rotation_y": variance_rotation_y,
+        })
+        self.confidence = score
+        self.variance_confidence = variance_score
+
+    def kitti(self):
+        result = super().kitti()
+
+        result["variance_truncated"] = np.array([self.data["variance_truncated"]])
+        result["variance_occluded"] = np.array([self.data["variance_occluded"]])
+        result["variance_alpha"] = np.array([self.data["variance_alpha"]])
+        result["variance_bbox"] = np.array([self.data["variance_bbox2d"]])
+        result["variance_dimensions"] = np.array([self.data["variance_dimensions"]])
+        result["variance_location"] = np.array([self.data["variance_location"]])
+        result["variance_rotation_y"] = np.array([self.data["variance_rotation_y"]])
+        result["variance_score"] = np.array([self.variance_confidence])
+
+        return result
+
+    def to_kitti_tracking_string(self, frame):
+
+        result = super().to_kitti_tracking_string(frame)
+
+        result += " ".join(map(str, [
+            float(self.variance_truncated),
+            float(self.variance_occluded),
+            float(self.variance_alpha),
+            *([0, 0, 0, 0] if self.variance_bbox2d is None else self.variance_bbox2d),  # x y w h
+            *self.variance_dimensions,  # w h l
+            *self.variance_location,  # x y z
+            self.variance_rotation_y,
+            float(self.variance_confidence),
+        ]))
+
+        return result
+
+    @property
+    def name(self):
+        return self.data["name"]
+
+    @property
+    def truncated(self):
+        return self.data["truncated"]
+
+    @property
+    def occluded(self):
+        return self.data["occluded"]
+
+    @property
+    def alpha(self):
+        return self.data["alpha"]
+
+    @property
+    def bbox2d(self):
+        return self.data["bbox2d"]
+
+    @property
+    def dimensions(self):
+        return self.data["dimensions"]
+
+    @property
+    def location(self):
+        return self.data["location"]
+
+    @property
+    def rotation_y(self):
+        return self.data["rotation_y"]
+
+    @property
+    def variance_truncated(self):
+        return self.data["variance_truncated"]
+
+    @property
+    def variance_occluded(self):
+        return self.data["variance_occluded"]
+
+    @property
+    def variance_alpha(self):
+        return self.data["variance_alpha"]
+
+    @property
+    def variance_bbox2d(self):
+        return self.data["variance_bbox2d"]
+
+    @property
+    def variance_dimensions(self):
+        return self.data["variance_dimensions"]
+
+    @property
+    def variance_location(self):
+        return self.data["variance_location"]
+
+    @property
+    def variance_rotation_y(self):
+        return self.data["variance_rotation_y"]
+
+    def __repr__(self):
+        return "UncertaintyBoundingBox3D " + str(self)
+
+    def __str__(self):
+        return str(self.kitti())
+
+
+class UncertaintyBoundingBox3DList(Target):
+    """
+    This target is used for 3D Object Detection with uncertainty. It contains a list of UncertaintyBoundingBox3D targets.
+    A bounding box is described by its location (x, y, z), dimensions (l, h, w) and rotation (along vertical (y) axis).
+    Additional fields are used to describe confidence (score), 2D projection of the box on camera image (bbox2d),
+    truncation (truncated) and occlusion (occluded) levels, the name of an object (name) and
+    observation angle of an object (alpha).
+    UncertaintyBoundingBox3D provides uncertainty values for all regressed fields (this excludes name).
+    """
+
+    def __init__(
+            self,
+            bounding_boxes_3d
+    ):
+        super().__init__()
+        self.data = bounding_boxes_3d
+        self.confidence = None if len(self.data) == 0 else np.mean([box.confidence for box in self.data])
+        self.variance_confidence = None if len(self.data) == 0 else np.mean([box.variance_confidence for box in self.data])
+
+    @staticmethod
+    def from_kitti(boxes_kitti):
+
+        count = len(boxes_kitti["name"])
+
+        boxes3d = []
+
+        for i in range(count):
+            box3d = UncertaintyBoundingBox3D(
+                boxes_kitti["name"][i],
+                boxes_kitti["truncated"][i],
+                boxes_kitti["occluded"][i],
+                boxes_kitti["alpha"][i],
+                boxes_kitti["bbox"][i],
+                boxes_kitti["dimensions"][i],
+                boxes_kitti["location"][i],
+                boxes_kitti["rotation_y"][i],
+                boxes_kitti["variance_truncated"][i],
+                boxes_kitti["variance_occluded"][i],
+                boxes_kitti["variance_alpha"][i],
+                boxes_kitti["variance_bbox"][i],
+                boxes_kitti["variance_dimensions"][i],
+                boxes_kitti["variance_location"][i],
+                boxes_kitti["variance_rotation_y"][i],
+                boxes_kitti["score"][i],
+                boxes_kitti["variance_score"][i],
+            )
+
+            boxes3d.append(box3d)
+
+        return UncertaintyBoundingBox3DList(boxes3d)
+
+    def kitti(self):
+
+        result = {
+            "name": [],
+            "truncated": [],
+            "occluded": [],
+            "alpha": [],
+            "bbox": [],
+            "dimensions": [],
+            "location": [],
+            "rotation_y": [],
+            "variance_truncated": [],
+            "variance_occluded": [],
+            "variance_alpha": [],
+            "variance_bbox": [],
+            "variance_dimensions": [],
+            "variance_location": [],
+            "variance_rotation_y": [],
+            "score": [],
+            "variance_score": [],
+        }
+
+        if len(self.data) == 0:
+            return result
+        elif len(self.data) == 1:
+            return self.data[0].kitti()
+        else:
+
+            for box in self.data:
+                result["name"].append(box.data["name"])
+                result["truncated"].append(box.data["truncated"])
+                result["occluded"].append(box.data["occluded"])
+                result["alpha"].append(box.data["alpha"])
+                result["bbox"].append(box.data["bbox2d"])
+                result["dimensions"].append(box.data["dimensions"])
+                result["location"].append(box.data["location"])
+                result["rotation_y"].append(box.data["rotation_y"])
+                result["variance_truncated"].append(box.data["variance_truncated"])
+                result["variance_occluded"].append(box.data["variance_occluded"])
+                result["variance_alpha"].append(box.data["variance_alpha"])
+                result["variance_bbox"].append(box.data["variance_bbox2d"])
+                result["variance_dimensions"].append(box.data["variance_dimensions"])
+                result["variance_location"].append(box.data["variance_location"])
+                result["variance_rotation_y"].append(box.data["variance_rotation_y"])
+                result["score"].append(box.confidence)
+                result["variance_score"].append(box.variance_confidence)
+
+            result["name"] = np.array(result["name"])
+            result["truncated"] = np.array(result["truncated"])
+            result["occluded"] = np.array(result["occluded"])
+            result["alpha"] = np.array(result["alpha"])
+            result["bbox"] = np.array(result["bbox"])
+            result["dimensions"] = np.array(result["dimensions"])
+            result["location"] = np.array(result["location"])
+            result["rotation_y"] = np.array(result["rotation_y"])
+            result["variance_truncated"] = np.array(result["truncated"])
+            result["variance_occluded"] = np.array(result["occluded"])
+            result["variance_alpha"] = np.array(result["alpha"])
+            result["variance_bbox"] = np.array(result["bbox"])
+            result["variance_dimensions"] = np.array(result["dimensions"])
+            result["variance_location"] = np.array(result["location"])
+            result["variance_rotation_y"] = np.array(result["rotation_y"])
+            result["score"] = np.array(result["score"])
+
+            num_ground_truths = len(result["name"])
+            num_objects = len([x for x in result["name"] if x != "DontCare"])
+            index = list(range(num_objects)) + [-1] * (num_ground_truths - num_objects)
+            result["index"] = np.array(index, dtype=np.int32)
+            result["group_ids"] = np.arange(num_ground_truths, dtype=np.int32)
+
+        return result
+
+    @property
+    def boxes(self):
+        return self.data
+
+    def __getitem__(self, idx):
+        return self.boxes[idx]
+
+    def __len__(self):
+        return len(self.data)
+
+    def __repr__(self):
+        return "UncertaintyBoundingBox3DList " + str(self)
+
+    def __str__(self):
+        return str(self.kitti())
+
diff --git a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/tanet/car/tapp_16.proto b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/tanet/car/tapp_16.proto
new file mode 100644
index 0000000000000000000000000000000000000000..c947358ccbdf9096733125c9927347c8b36036bc
--- /dev/null
+++ b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/tanet/car/tapp_16.proto
@@ -0,0 +1,173 @@
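+# "TAPP" configuration: TANet pillar feature extractor ("PillarFeature_TANet")
+# paired with a PointPillars-style "RPN" head (the name presumably stands for
+# TAnet + PointPillars).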
+model: {
+  second: {
+    voxel_generator {
+      point_cloud_range : [0, -39.68, -3, 69.12, 39.68, 1]
+      voxel_size : [0.16, 0.16, 4]
+      max_number_of_points_per_voxel : 100
+    }
+    num_class: 1
+    voxel_feature_extractor: {
+      module_class_name: "PillarFeature_TANet"
+      num_filters: [64]
+      with_distance: false
+    }
+    middle_feature_extractor: {
+      module_class_name: "PointPillarsScatter"
+    }
+    rpn: {
+      module_class_name: "RPN"
+      layer_nums: [3, 5, 5]
+      layer_strides: [2, 2, 2]
+      num_filters: [64, 128, 256]
+      upsample_strides: [1, 2, 4]
+      num_upsample_filters: [128, 128, 128]
+      use_groupnorm: false
+      num_groups: 32
+    }
+    loss: {
+      classification_loss: {
+        weighted_sigmoid_focal: {
+          alpha: 0.25
+          gamma: 2.0
+          anchorwise_output: true
+        }
+      }
+      localization_loss: {
+        weighted_smooth_l1: {
+          sigma: 3.0
+          code_weight: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+        }
+      }
+      classification_weight: 1.0
+      localization_weight: 2.0
+    }
+    # Outputs
+    use_sigmoid_score: true
+    encode_background_as_zeros: true
+    encode_rad_error_by_sin: true
+
+    use_direction_classifier: true
+    direction_loss_weight: 0.2
+    use_aux_classifier: false
+    # Loss
+    pos_class_weight: 1.0
+    neg_class_weight: 1.0
+
+    loss_norm_type: NormByNumPositives
+    # Postprocess
+    post_center_limit_range: [0, -39.68, -5, 69.12, 39.68, 5]
+    use_rotate_nms: false
+    use_multi_class_nms: false
+    nms_pre_max_size: 1000
+    nms_post_max_size: 300
+    nms_score_threshold: 0.3 #0.05
+    nms_iou_threshold: 0.1 #0.5
+
+    use_bev: false
+    num_point_features: 4
+    without_reflectivity: false
+    box_coder: {
+      ground_box3d_coder: {
+        linear_dim: false
+        encode_angle_vector: false
+      }
+    }
+    target_assigner: {
+      anchor_generators: {
+         anchor_generator_stride: {
+           sizes: [1.6, 3.9, 1.56] # wlh
+           strides: [0.32, 0.32, 0.0] # if generate only 1 z_center, z_stride will be ignored
+           offsets: [0.16, -39.52, -1.78] # origin_offset + strides / 2
+           rotations: [0, 1.57] # 0, pi/2
+           matched_threshold : 0.6
+           unmatched_threshold : 0.45
+         }
+       }
+
+      sample_positive_fraction : -1
+      sample_size : 512
+      region_similarity_calculator: {
+        nearest_iou_similarity: {
+        }
+      }
+    }
+  }
+}
+
+
+train_input_reader: {
+  record_file_path: "kitti_train.tfrecord"
+  class_names: ["Car"]
+  max_num_epochs : 160
+  batch_size: 1
+  prefetch_size : 25
+  max_number_of_voxels: 12000
+  shuffle_points: true
+  num_workers: 2
+  groundtruth_localization_noise_std: [0.25, 0.25, 0.25]
+  groundtruth_rotation_uniform_noise: [-0.15707963267, 0.15707963267]
+  global_rotation_uniform_noise: [-0.78539816, 0.78539816]
+  global_scaling_uniform_noise: [0.95, 1.05]
+  global_random_rotation_range_per_object: [0, 0]
+  anchor_area_threshold: 1
+  remove_points_after_sample: false
+  groundtruth_points_drop_percentage: 0.0
+  groundtruth_drop_max_keep_points: 15
+  database_sampler {
+    database_info_path: "kitti_dbinfos_train.pkl"
+    sample_groups {
+      name_to_max_num {
+        key: "Car"
+        value: 15
+      }
+    }
+    database_prep_steps {
+      filter_by_min_num_points {
+        min_num_point_pairs {
+          key: "Car"
+          value: 5
+        }
+      }
+    }
+    database_prep_steps {
+      filter_by_difficulty {
+        removed_difficulties: [-1]
+      }
+    }
+    global_random_rotation_range_per_object: [0, 0]
+    rate: 1.0
+  }
+
+  remove_unknown_examples: false
+  remove_environment: false
+  kitti_info_path: "kitti_infos_train.pkl"
+  kitti_root_path: ""
+}
+
+train_config: {
+
+  inter_op_parallelism_threads: 4
+  intra_op_parallelism_threads: 4
+  steps: 296960 # 1856 steps per epoch * 160 epochs
+  steps_per_eval: 9280 # 1856 steps per epoch * 5 epochs
+  save_checkpoints_secs : 1800 # half hour
+  save_summary_steps : 10
+  enable_mixed_precision: false
+  loss_scale_factor : 512.0
+  clear_metrics_every_epoch: false
+}
+
+eval_input_reader: {
+  record_file_path: "kitti_val.tfrecord"
+  class_names: ["Car"]
+  batch_size: 1
+  max_num_epochs : 160
+  prefetch_size : 25
+  max_number_of_voxels: 12000
+  shuffle_points: false
+  num_workers: 1
+  anchor_area_threshold: 1
+  remove_environment: false
+  kitti_info_path: "kitti_infos_val.pkl"
+  kitti_root_path: ""
+}
diff --git a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/tanet/car/vnn_tapp_16.proto b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/tanet/car/vnn_tapp_16.proto
new file mode 100644
index 0000000000000000000000000000000000000000..23c5baae5b1c78c3522a26daf08702b0996a97d3
--- /dev/null
+++ b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/tanet/car/vnn_tapp_16.proto
@@ -0,0 +1,173 @@
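+# Variational counterpart of tapp_16.proto: identical settings except that the
+# RPN module is "VRPN", assumed to be the uncertainty-aware head.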
+model: {
+  second: {
+    voxel_generator {
+      point_cloud_range : [0, -39.68, -3, 69.12, 39.68, 1]
+      voxel_size : [0.16, 0.16, 4]
+      max_number_of_points_per_voxel : 100
+    }
+    num_class: 1
+    voxel_feature_extractor: {
+      module_class_name: "PillarFeature_TANet"
+      num_filters: [64]
+      with_distance: false
+    }
+    middle_feature_extractor: {
+      module_class_name: "PointPillarsScatter"
+    }
+    rpn: {
+      module_class_name: "VRPN"
+      layer_nums: [3, 5, 5]
+      layer_strides: [2, 2, 2]
+      num_filters: [64, 128, 256]
+      upsample_strides: [1, 2, 4]
+      num_upsample_filters: [128, 128, 128]
+      use_groupnorm: false
+      num_groups: 32
+    }
+    loss: {
+      classification_loss: {
+        weighted_sigmoid_focal: {
+          alpha: 0.25
+          gamma: 2.0
+          anchorwise_output: true
+        }
+      }
+      localization_loss: {
+        weighted_smooth_l1: {
+          sigma: 3.0
+          code_weight: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+        }
+      }
+      classification_weight: 1.0
+      localization_weight: 2.0
+    }
+    # Outputs
+    use_sigmoid_score: true
+    encode_background_as_zeros: true
+    encode_rad_error_by_sin: true
+
+    use_direction_classifier: true
+    direction_loss_weight: 0.2
+    use_aux_classifier: false
+    # Loss
+    pos_class_weight: 1.0
+    neg_class_weight: 1.0
+
+    loss_norm_type: NormByNumPositives
+    # Postprocess
+    post_center_limit_range: [0, -39.68, -5, 69.12, 39.68, 5]
+    use_rotate_nms: false
+    use_multi_class_nms: false
+    nms_pre_max_size: 1000
+    nms_post_max_size: 300
+    nms_score_threshold: 0.3 #0.05
+    nms_iou_threshold: 0.1 #0.5
+
+    use_bev: false
+    num_point_features: 4
+    without_reflectivity: false
+    box_coder: {
+      ground_box3d_coder: {
+        linear_dim: false
+        encode_angle_vector: false
+      }
+    }
+    target_assigner: {
+      anchor_generators: {
+         anchor_generator_stride: {
+           sizes: [1.6, 3.9, 1.56] # wlh
+           strides: [0.32, 0.32, 0.0] # if generate only 1 z_center, z_stride will be ignored
+           offsets: [0.16, -39.52, -1.78] # origin_offset + strides / 2
+           rotations: [0, 1.57] # 0, pi/2
+           matched_threshold : 0.6
+           unmatched_threshold : 0.45
+         }
+       }
+
+      sample_positive_fraction : -1
+      sample_size : 512
+      region_similarity_calculator: {
+        nearest_iou_similarity: {
+        }
+      }
+    }
+  }
+}
+
+
+train_input_reader: {
+  record_file_path: "kitti_train.tfrecord"
+  class_names: ["Car"]
+  max_num_epochs : 160
+  batch_size: 1
+  prefetch_size : 25
+  max_number_of_voxels: 12000
+  shuffle_points: true
+  num_workers: 2
+  groundtruth_localization_noise_std: [0.25, 0.25, 0.25]
+  groundtruth_rotation_uniform_noise: [-0.15707963267, 0.15707963267]
+  global_rotation_uniform_noise: [-0.78539816, 0.78539816]
+  global_scaling_uniform_noise: [0.95, 1.05]
+  global_random_rotation_range_per_object: [0, 0]
+  anchor_area_threshold: 1
+  remove_points_after_sample: false
+  groundtruth_points_drop_percentage: 0.0
+  groundtruth_drop_max_keep_points: 15
+  database_sampler {
+    database_info_path: "kitti_dbinfos_train.pkl"
+    sample_groups {
+      name_to_max_num {
+        key: "Car"
+        value: 15
+      }
+    }
+    database_prep_steps {
+      filter_by_min_num_points {
+        min_num_point_pairs {
+          key: "Car"
+          value: 5
+        }
+      }
+    }
+    database_prep_steps {
+      filter_by_difficulty {
+        removed_difficulties: [-1]
+      }
+    }
+    global_random_rotation_range_per_object: [0, 0]
+    rate: 1.0
+  }
+
+  remove_unknown_examples: false
+  remove_environment: false
+  kitti_info_path: "kitti_infos_train.pkl"
+  kitti_root_path: ""
+}
+
+train_config: {
+
+  inter_op_parallelism_threads: 4
+  intra_op_parallelism_threads: 4
+  steps: 296960 # 1856 steps per epoch * 160 epochs
+  steps_per_eval: 9280 # 1856 steps per epoch * 5 epochs
+  save_checkpoints_secs : 1800 # half hour
+  save_summary_steps : 10
+  enable_mixed_precision: false
+  loss_scale_factor : 512.0
+  clear_metrics_every_epoch: false
+}
+
+eval_input_reader: {
+  record_file_path: "kitti_val.tfrecord"
+  class_names: ["Car"]
+  batch_size: 1
+  max_num_epochs : 160
+  prefetch_size : 25
+  max_number_of_voxels: 12000
+  shuffle_points: false
+  num_workers: 1
+  anchor_area_threshold: 1
+  remove_environment: false
+  kitti_info_path: "kitti_infos_val.pkl"
+  kitti_root_path: ""
+}
diff --git a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/data/kitti_common.py b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/data/kitti_common.py
index c074dcd7365fe8af7a02dd410eb5f6e3247da133..9d15a92a9a6d972646a9ee1e6826babeeb321f4c 100644
--- a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/data/kitti_common.py
+++ b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/data/kitti_common.py
@@ -648,7 +648,7 @@ def get_pseudo_label_anno():
     return annotations
 
 
-def get_start_result_anno():
+def get_start_result_anno(return_uncertainty=False):
     annotations = {}
     annotations.update({
         "name": [],
@@ -661,6 +661,18 @@ def get_start_result_anno():
         "rotation_y": [],
         "score": [],
     })
+
+    if return_uncertainty:
+        annotations.update({
+            "variance_truncated": [],
+            "variance_occluded": [],
+            "variance_alpha": [],
+            "variance_bbox": [],
+            "variance_dimensions": [],
+            "variance_location": [],
+            "variance_rotation_y": [],
+            "variance_score": [],
+        })
     return annotations
 
 
diff --git a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/core/box_coders.py b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/core/box_coders.py
index 0c3da0d1c2101f98e08b79905e5ca8f032d91225..d4d413394dcae80f0f5e73bc1942058a519cb16e 100644
--- a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/core/box_coders.py
+++ b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/core/box_coders.py
@@ -18,6 +18,10 @@ class GroundBox3dCoderTorch(GroundBox3dCoder):
         return box_torch_ops.second_box_decode(
             boxes, anchors, self.vec_encode, self.linear_dim
         )
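+
+    # Variance-aware decode: takes per-anchor variance and mean encodings and
+    # defers to second_box_decode_uncertainty for the propagation.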
+    def decode_torch_uncertainty(self, boxes_var, boxes_mean, anchors):
+        return box_torch_ops.second_box_decode_uncertainty(
+            boxes_var, boxes_mean, anchors, self.vec_encode, self.linear_dim
+        )
 
 
 class BevBoxCoderTorch(BevBoxCoder):
diff --git a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/core/box_torch_ops.py b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/core/box_torch_ops.py
index 050ea5ae69ee547a0a8702d42f81b99983311184..c36bea8317ec07eb6e7d3ce7512304ac52113759 100644
--- a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/core/box_torch_ops.py
+++ b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/core/box_torch_ops.py
@@ -102,6 +102,61 @@ def second_box_decode(box_encodings,
     return torch.cat([xg, yg, zg, wg, lg, hg, rg], dim=-1)
 
 
+
+def second_box_decode_uncertainty(
+    box_var_encodings,
+    box_mean_encodings,
+    anchors,
+    encode_angle_to_vector=False,
+    smooth_dim=False):
+    """box decode for VoxelNet in lidar
+    Args:
+        boxes ([N, 7] Tensor): normal boxes: x, y, z, w, l, h, r
+        anchors ([N, 7] Tensor): anchors
+    """
+    xa, ya, za, wa, la, ha, ra = torch.split(anchors, 1, dim=-1)
+    if encode_angle_to_vector:
+        xt, yt, zt, wt, lt, ht, rtx, rty = torch.split(
+            box_var_encodings,
+            1,
+            dim=-1
+        )
+        mxt, myt, mzt, mwt, mlt, mht, mrtx, mrty = torch.split(
+            box_mean_encodings,
+            1,
+            dim=-1
+        )
+
+    else:
+        xt, yt, zt, wt, lt, ht, rt = torch.split(box_var_encodings, 1, dim=-1)
+        mxt, myt, mzt, mwt, mlt, mht, mrt = torch.split(box_mean_encodings, 1, dim=-1)
+
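+    # Log-normal variance identity (assuming s is treated as a standard
+    # deviation): if log X ~ N(m, s^2) then
+    # Var(X) = (exp(s^2) - 1) * exp(2*m + s^2).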
+    def exp_var(m, s):
+        return (torch.exp(s ** 2) - 1) * (torch.exp(2 * m + s ** 2))
+
+    za = za + ha / 2
+    diagonal = torch.sqrt(la**2 + wa**2)
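+    # The encoded offsets are normalized by the anchor diagonal/height, so their
+    # variances are rescaled by the squared factor: Var(a * X) = a^2 * Var(X).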
+    xg = xt * diagonal ** 2
+    yg = yt * diagonal ** 2
+    zg = zt * ha ** 2
+    if smooth_dim:
+        lg = lt * la ** 2
+        wg = wt * wa ** 2
+        hg = ht * ha ** 2
+    else:
+        lg = exp_var(mlt, lt) * la ** 2
+        wg = exp_var(mwt, wt) * wa ** 2
+        hg = exp_var(mht, ht) * ha ** 2
+    if encode_angle_to_vector:
+        rgx = rtx
+        rgy = rty
+        rg = torch.atan2(rgy, rgx)
+    else:
+        rg = rt
+    return torch.cat([xg, yg, zg, wg, lg, hg, rg], dim=-1)
+
+
+
 def bev_box_encode(boxes,
                    anchors,
                    encode_angle_to_vector=False,
@@ -372,6 +427,16 @@ def project_to_image(points_3d, proj_mat):
     return point_2d_res
 
 
+def uncertainty_project_to_image(points_3d, proj_mat):
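+    # Propagates per-coordinate variances through the projection: for a linear map
+    # y = M x with independent coordinates, Var(y_i) = sum_j M_ij^2 Var(x_j), hence
+    # the element-wise squared projection matrix below (a first-order sketch that
+    # ignores cross-covariances and the perspective divide).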
+    points_num = list(points_3d.shape)[:-1]
+    points_shape = np.concatenate([points_num, [1]], axis=0).tolist()
+    points_4 = torch.cat(
+        [points_3d, torch.zeros(*points_shape).type_as(points_3d)], dim=-1)
+    point_2d = points_4 @ (proj_mat.t() ** 2)
+    point_2d_res = point_2d[..., :2] / point_2d[..., 2:3]
+    return point_2d_res
+
+
 def camera_to_lidar(points, r_rect, velo2cam):
     num_points = points.shape[0]
     points = torch.cat(
@@ -380,6 +445,17 @@ def camera_to_lidar(points, r_rect, velo2cam):
     return lidar_points[..., :3]
 
 
+def uncertainty_camera_to_lidar(points, r_rect, velo2cam):
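+    # Variance counterpart of camera_to_lidar: applies the squared entries of the
+    # inverse (r_rect @ velo2cam) transform so per-coordinate variances are
+    # propagated through the linear map (independence assumed).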
+    num_points = points.shape[0]
+    points = torch.cat(
+        [points, torch.ones(num_points, 1).type_as(points)], dim=-1)
+
+    M = torch.inverse((r_rect @ velo2cam).t())
+
+    lidar_points = points @ (M ** 2)
+    return lidar_points[..., :3]
+
+
 def lidar_to_camera(points, r_rect, velo2cam):
     num_points = points.shape[0]
     points = torch.cat(
@@ -396,6 +472,14 @@ def box_camera_to_lidar(data, r_rect, velo2cam):
     return torch.cat([xyz_lidar, w, l, h, r], dim=-1)
 
 
+def uncertainty_box_camera_to_lidar(data, r_rect, velo2cam):
+    xyz = data[..., 0:3]
+    l, h, w = data[..., 3:4], data[..., 4:5], data[..., 5:6]
+    r = data[..., 6:7]
+    xyz_lidar = uncertainty_camera_to_lidar(xyz, r_rect, velo2cam)
+    return torch.cat([xyz_lidar, w, l, h, r], dim=-1)
+
+
 def box_lidar_to_camera(data, r_rect, velo2cam):
     xyz_lidar = data[..., 0:3]
     w, l, h = data[..., 3:4], data[..., 4:5], data[..., 5:6]
diff --git a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/tanet.py b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/tanet.py
index b4d2bf2c8bb9ff690a754b88a4f70435058a5b01..1131da76dd16d27ee025b83c1fb62c1b00389b5c 100644
--- a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/tanet.py
+++ b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/tanet.py
@@ -930,7 +930,7 @@ class VPSA(nn.Module):
         all_box_preds = []
         all_cls_preds = []
         all_dir_cls_preds = []
-        all_refine_box_preds = []
+        all_refine_loc_preds = []
         all_refine_cls_preds = []
         all_refine_dir_cls_preds = []
 
@@ -994,8 +994,8 @@ class VPSA(nn.Module):
             refine_cls_preds_local = self.refine_cls(PSA_output)
             refine_loc_preds_local = self.refine_loc(PSA_output)
             
-            all_refine_box_preds.append(refine_cls_preds_local)
-            all_refine_cls_preds.append(refine_loc_preds_local)
+            all_refine_loc_preds.append(refine_loc_preds_local)
+            all_refine_cls_preds.append(refine_cls_preds_local)
 
             if self._use_direction_classifier:
                 refine_dir_preds_local = self.refine_dir(PSA_output)
@@ -1021,8 +1021,8 @@ class VPSA(nn.Module):
             torch.stack(all_cls_preds, dim=0), dim=0, unbiased=False
         )
 
-        refine_box_preds_var, refine_box_preds = torch.var_mean(
-            torch.stack(all_refine_box_preds, dim=0), dim=0, unbiased=False
+        refine_loc_preds_var, refine_loc_preds = torch.var_mean(
+            torch.stack(all_refine_loc_preds, dim=0), dim=0, unbiased=False
         )
         refine_cls_preds_var, refine_cls_preds = torch.var_mean(
             torch.stack(all_refine_cls_preds, dim=0), dim=0, unbiased=False
@@ -1033,9 +1033,9 @@ class VPSA(nn.Module):
         cls_preds = cls_preds.permute(0, 2, 3, 1).contiguous()
         box_preds_var = box_preds_var.permute(0, 2, 3, 1).contiguous()
         cls_preds_var = cls_preds_var.permute(0, 2, 3, 1).contiguous()
-        refine_box_preds = refine_box_preds.permute(0, 2, 3, 1).contiguous()
+        refine_loc_preds = refine_loc_preds.permute(0, 2, 3, 1).contiguous()
         refine_cls_preds = refine_cls_preds.permute(0, 2, 3, 1).contiguous()
-        refine_box_preds_var = refine_box_preds_var.permute(0, 2, 3, 1).contiguous()
+        refine_loc_preds_var = refine_loc_preds_var.permute(0, 2, 3, 1).contiguous()
         refine_cls_preds_var = refine_cls_preds_var.permute(0, 2, 3, 1).contiguous()
 
         ret_dict = {
@@ -1043,9 +1043,9 @@ class VPSA(nn.Module):
             "cls_preds": cls_preds,
             "box_preds_var": box_preds_var,
             "cls_preds_var": cls_preds_var,
-            "Refine_box_preds": refine_box_preds,
+            "Refine_loc_preds": refine_loc_preds,
             "Refine_cls_preds": refine_cls_preds,
-            "Refine_box_preds_var": refine_box_preds_var,
+            "Refine_loc_preds_var": refine_loc_preds_var,
             "Refine_cls_preds_var": refine_cls_preds_var,
         }
 
diff --git a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/voxelnet.py b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/voxelnet.py
index bbee7d580be0ee80cb55d4b0340da1bbd359c0cf..cda9f04d66dc512ade6daca4fc1e8bb338c2749d 100644
--- a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/voxelnet.py
+++ b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/voxelnet.py
@@ -1100,7 +1100,7 @@ class VoxelNet(nn.Module):
     def get_global_step(self):
         return int(self.global_step.cpu().numpy()[0])
 
-    def forward(self, example, refine_weight=2, samples=1):
+    def forward(self, example, refine_weight=2, samples=1, return_uncertainty=False):
         """module's forward should always accept dict and return loss.
         """
         voxels = example["voxels"]
@@ -1303,18 +1303,21 @@ class VoxelNet(nn.Module):
                     "cared": cared,
                 }
         else:
-            if self.rpn_class_name == "PSA" or self.rpn_class_name == "RefineDet":
-                coarse_output = self.predict_coarse(example, preds_dict, self.device)
-                refine_output = self.predict_refine(example, preds_dict, self.device)
+            if self.rpn_class_name in ["PSA", "VPSA", "RefineDet"]:
+                coarse_output = self.predict_coarse(example, preds_dict, self.device, return_uncertainty=return_uncertainty)
+                refine_output = self.predict_refine(example, preds_dict, self.device, return_uncertainty=return_uncertainty)
                 return coarse_output, refine_output
             else:
-                return self.predict_coarse(example, preds_dict, self.device)
+                return self.predict_coarse(example, preds_dict, self.device, return_uncertainty=return_uncertainty)
 
     def compute_predict(
         self,
         batch_box_preds,
         batch_cls_preds,
         batch_dir_preds,
+        batch_box_preds_var,
+        batch_cls_preds_var,
+        batch_dir_preds_var,
         batch_rect,
         batch_Trv2c,
         batch_P2,
@@ -1322,12 +1325,16 @@ class VoxelNet(nn.Module):
         batch_anchors_mask,
         num_class_with_bg,
         device,
+        return_uncertainty,
     ):
         predictions_dicts = []
         for (
                 box_preds,
                 cls_preds,
                 dir_preds,
+                box_preds_var,
+                cls_preds_var,
+                dir_preds_var,
                 rect,
                 Trv2c,
                 P2,
@@ -1337,6 +1344,9 @@ class VoxelNet(nn.Module):
                 batch_box_preds,
                 batch_cls_preds,
                 batch_dir_preds,
+                batch_box_preds_var,
+                batch_cls_preds_var,
+                batch_dir_preds_var,
                 batch_rect,
                 batch_Trv2c,
                 batch_P2,
@@ -1350,9 +1360,16 @@ class VoxelNet(nn.Module):
                 a_mask = a_mask.bool()
                 box_preds = box_preds[a_mask]
                 cls_preds = cls_preds[a_mask]
+
+                if return_uncertainty:
+                    box_preds_var = box_preds_var[a_mask]
+                    cls_preds_var = cls_preds_var[a_mask]
             if self._use_direction_classifier:
                 if a_mask is not None:
                     dir_preds = dir_preds[a_mask]
+
+                    if return_uncertainty:
+                        dir_preds_var = dir_preds_var[a_mask]
                 dir_labels = torch.max(dir_preds, dim=-1)[1]
             if self._encode_background_as_zeros:
                 # this don't support softmax
@@ -1398,27 +1415,45 @@ class VoxelNet(nn.Module):
                 )
                 selected_boxes, selected_labels, selected_scores = [], [], []
                 selected_dir_labels = []
+                selected_boxes_var = []
+                selected_cls_preds_var = []
+                selected_dir_preds_var = []
                 for i, selected in enumerate(selected_per_class):
                     if selected is not None:
                         num_dets = selected.shape[0]
                         selected_boxes.append(box_preds[selected])
                         selected_labels.append(
                             torch.full([num_dets], i, dtype=torch.int64))
                         if self._use_direction_classifier:
                             selected_dir_labels.append(dir_labels[selected])
                         selected_scores.append(total_scores[selected, i])
+
+                        if return_uncertainty:
+                            selected_boxes_var.append(box_preds_var[selected])
+                            selected_dir_preds_var.append(dir_preds_var[selected])
+                            selected_cls_preds_var.append(cls_preds_var[selected])
                 if len(selected_boxes) > 0:
                     selected_boxes = torch.cat(selected_boxes, dim=0)
                     selected_labels = torch.cat(selected_labels, dim=0)
                     selected_scores = torch.cat(selected_scores, dim=0)
+                    if return_uncertainty:
+                        selected_boxes_var = torch.cat(selected_boxes_var, dim=0)
+                        selected_cls_preds_var = torch.cat(selected_cls_preds_var, dim=0)
                     if self._use_direction_classifier:
-                        selected_dir_labels = torch.cat(selected_dir_labels,
-                                                        dim=0)
+                        selected_dir_labels = torch.cat(
+                            selected_dir_labels,
+                            dim=0
+                        )
+                        if return_uncertainty:
+                            selected_dir_preds_var = torch.cat(selected_dir_preds_var, dim=0)
                 else:
                     selected_boxes = None
                     selected_labels = None
                     selected_scores = None
                     selected_dir_labels = None
+                    selected_boxes_var = None
+                    selected_cls_preds_var = None
+                    selected_dir_preds_var = None
             else:
                 # get highest score per prediction, than apply nms
                 # to remove overlapped box.
@@ -1470,12 +1505,20 @@ class VoxelNet(nn.Module):
                         selected_dir_labels = dir_labels[selected]
                     selected_labels = top_labels[selected]
                     selected_scores = top_scores[selected]
+                    if return_uncertainty:
+                        selected_boxes_var = box_preds_var[selected]
+                        selected_cls_preds_var = cls_preds_var[selected]
+                        selected_dir_preds_var = dir_preds_var[selected]
             # finally generate predictions.
 
             if selected_boxes is not None:
                 box_preds = selected_boxes
                 scores = selected_scores
                 label_preds = selected_labels
+                if return_uncertainty:
+                    box_preds_var = selected_boxes_var
+                    cls_preds_var = selected_cls_preds_var
+                    dir_preds_var = selected_dir_preds_var
                 if self._use_direction_classifier:
                     dir_labels = selected_dir_labels
                     opp_labels = dir_labels.byte() ^ (
@@ -1488,10 +1531,18 @@ class VoxelNet(nn.Module):
                     )
                 final_box_preds = box_preds
                 final_scores = scores
+                if return_uncertainty:
+                    final_box_preds_var_camera = None
+                    box_2d_preds_var = None
 
                 if is_calib:
                     final_box_preds_camera = box_torch_ops.box_lidar_to_camera(
                         final_box_preds, rect, Trv2c)
+
+                    if return_uncertainty:
+                        final_box_preds_var_camera = box_torch_ops.uncertainty_box_camera_to_lidar(
+                            box_preds_var, rect, Trv2c
+                        )
                     locs = final_box_preds_camera[:, :3]
                     dims = final_box_preds_camera[:, 3:6]
                     angles = final_box_preds_camera[:, 6]
@@ -1504,31 +1555,42 @@ class VoxelNet(nn.Module):
                     minxy = torch.min(box_corners_in_image, dim=1)[0]
                     maxxy = torch.max(box_corners_in_image, dim=1)[0]
                     box_2d_preds = torch.cat([minxy, maxxy], dim=1)
+
+                    if return_uncertainty:
+                        box_2d_preds_var = torch.zeros_like(box_2d_preds)
                 else:
                     box_2d_preds = None
                     final_box_preds_camera = None
                 # predictions
                 predictions_dict = {
                     "bbox": box_2d_preds,
+                    "bbox_var": box_2d_preds_var if return_uncertainty else None,
                     "box3d_camera": final_box_preds_camera,
+                    "box3d_camera_var": final_box_preds_var_camera if return_uncertainty else None,
                     "box3d_lidar": final_box_preds,
+                    "box3d_lidar_var": box_preds_var if return_uncertainty else None,
                     "scores": final_scores,
+                    "scores_var": cls_preds_var if return_uncertainty else None,
                     "label_preds": label_preds,
                     "image_idx": img_idx,
                 }
             else:
                 predictions_dict = {
                     "bbox": None,
+                    "bbox_var": None,
                     "box3d_camera": None,
+                    "box3d_camera_var": None,
                     "box3d_lidar": None,
+                    "box3d_lidar_var": None,
                     "scores": None,
+                    "scores_var": None,
                     "label_preds": None,
                     "image_idx": img_idx,
                 }
             predictions_dicts.append(predictions_dict)
         return predictions_dicts
 
-    def predict_coarse(self, example, preds_dict, device):
+    def predict_coarse(self, example, preds_dict, device, return_uncertainty=False):
         t = time.time()
         batch_size = example["anchors"].shape[0]
         batch_anchors = example["anchors"].view(batch_size, -1, 7)
@@ -1556,38 +1618,79 @@ class VoxelNet(nn.Module):
         t = time.time()
         batch_box_preds = preds_dict["box_preds"]
         batch_cls_preds = preds_dict["cls_preds"]
-        batch_box_preds = batch_box_preds.view(batch_size, -1,
-                                               self._box_coder.code_size)
+        batch_box_preds = batch_box_preds.view(
+            batch_size, -1,
+            self._box_coder.code_size
+        )
+
+        batch_box_preds_var = [None] * batch_size
+        batch_cls_preds_var = [None] * batch_size
+
+        if return_uncertainty:
+            batch_box_preds_var = preds_dict["box_preds_var"]
+            batch_cls_preds_var = preds_dict["cls_preds_var"]
+            batch_box_preds_var = batch_box_preds_var.view(
+                batch_size, -1,
+                self._box_coder.code_size
+            )
         num_class_with_bg = self._num_class
         if not self._encode_background_as_zeros:
             num_class_with_bg = self._num_class + 1
 
-        batch_cls_preds = batch_cls_preds.view(batch_size, -1,
-                                               num_class_with_bg)
-        batch_box_preds = self._box_coder.decode_torch(batch_box_preds,
-                                                       batch_anchors)
+        if return_uncertainty:
+            batch_cls_preds_var = batch_cls_preds_var.view(
+                batch_size, -1,
+                num_class_with_bg
+            )
+            batch_box_preds_var = self._box_coder.decode_torch_uncertainty(
+                batch_box_preds_var,
+                batch_box_preds,
+                batch_anchors
+            )
+
+        batch_cls_preds = batch_cls_preds.view(
+            batch_size, -1,
+            num_class_with_bg
+        )
+
+        batch_box_preds = self._box_coder.decode_torch(
+            batch_box_preds,
+            batch_anchors
+        )
+
         if self._use_direction_classifier:
             batch_dir_preds = preds_dict["dir_cls_preds"]
             batch_dir_preds = batch_dir_preds.view(batch_size, -1, 2)
+
+            batch_dir_preds_var = [None] * batch_size
+
+            if return_uncertainty:
+                batch_dir_preds_var = preds_dict["dir_cls_preds_var"]
+                batch_dir_preds_var = batch_dir_preds_var.view(batch_size, -1, 2)
         else:
             batch_dir_preds = [None] * batch_size
+            batch_dir_preds_var = [None] * batch_size
 
         predictions_dicts = self.compute_predict(
             batch_box_preds,
             batch_cls_preds,
             batch_dir_preds,
+            batch_box_preds_var,
+            batch_cls_preds_var,
+            batch_dir_preds_var,
             batch_rect,
             batch_Trv2c,
             batch_P2,
             batch_imgidx,
             batch_anchors_mask,
             num_class_with_bg,
-            device=device
+            device=device,
+            return_uncertainty=return_uncertainty,
         )
         self._total_postprocess_time += time.time() - t
         return predictions_dicts
 
-    def predict_refine(self, example, preds_dict, device):
+    def predict_refine(self, example, preds_dict, device, return_uncertainty=False):
         t = time.time()
         batch_size = example["anchors"].shape[0]
         batch_anchors = example["anchors"].view(batch_size, -1, 7)
@@ -1623,39 +1726,89 @@ class VoxelNet(nn.Module):
         refine_box_preds = preds_dict["Refine_loc_preds"]
         refine_cls_preds = preds_dict["Refine_cls_preds"]
 
-        coarse_box_preds = coarse_box_preds.view(batch_size, -1,
-                                                 self._box_coder.code_size)
+        refine_box_preds_var = [None] * batch_size
+        refine_cls_preds_var = [None] * batch_size
 
-        refine_box_preds = refine_box_preds.view(batch_size, -1,
-                                                 self._box_coder.code_size)
+        if return_uncertainty:
+            refine_box_preds_var = preds_dict["Refine_loc_preds_var"]
+            refine_cls_preds_var = preds_dict["Refine_cls_preds_var"]
 
-        de_coarse_boxes = self._box_coder.decode_torch(coarse_box_preds,
-                                                       batch_anchors)
-        de_refine_boxes = self._box_coder.decode_torch(refine_box_preds,
-                                                       de_coarse_boxes)
+        coarse_box_preds = coarse_box_preds.view(
+            batch_size, -1,
+            self._box_coder.code_size
+        )
+
+        refine_box_preds = refine_box_preds.view(
+            batch_size, -1,
+            self._box_coder.code_size
+        )
+
+        if return_uncertainty:
+            refine_box_preds_var = refine_box_preds_var.view(
+                batch_size, -1,
+                self._box_coder.code_size
+            )
+
+        de_coarse_boxes = self._box_coder.decode_torch(
+            coarse_box_preds,
+            batch_anchors
+        )
+        de_refine_boxes = self._box_coder.decode_torch(
+            refine_box_preds,
+            de_coarse_boxes
+        )
+
+        if return_uncertainty:
+            de_refine_boxes_var = self._box_coder.decode_torch_uncertainty(
+                refine_box_preds_var,
+                refine_box_preds,
+                de_coarse_boxes
+            )
 
         batch_box_preds = de_refine_boxes
         batch_cls_preds = refine_cls_preds
-        batch_cls_preds = batch_cls_preds.view(batch_size, -1,
-                                               num_class_with_bg)
+        batch_cls_preds = batch_cls_preds.view(
+            batch_size, -1,
+            num_class_with_bg)
+
+        batch_cls_preds_var = [None] * batch_size
+        batch_box_preds_var = [None] * batch_size
+
+        if return_uncertainty:
+            batch_cls_preds_var = refine_cls_preds_var
+            batch_box_preds_var = de_refine_boxes_var
+            batch_cls_preds_var = batch_cls_preds_var.view(
+                batch_size, -1,
+                num_class_with_bg)
 
         if self._use_direction_classifier:
             batch_dir_preds = preds_dict["Refine_dir_preds"]
             batch_dir_preds = batch_dir_preds.view(batch_size, -1, 2)
+
+            batch_dir_preds_var = [None] * batch_size
+
+            if return_uncertainty:
+                batch_dir_preds_var = preds_dict["Refine_dir_preds_var"]
+                batch_dir_preds_var = batch_dir_preds_var.view(batch_size, -1, 2)
         else:
             batch_dir_preds = [None] * batch_size
+            batch_dir_preds_var = [None] * batch_size
 
         predictions_dicts = self.compute_predict(
             batch_box_preds,
             batch_cls_preds,
             batch_dir_preds,
+            batch_box_preds_var,
+            batch_cls_preds_var,
+            batch_dir_preds_var,
             batch_rect,
             batch_Trv2c,
             batch_P2,
             batch_imgidx,
             batch_anchors_mask,
             num_class_with_bg,
-            device=device
+            device=device,
+            return_uncertainty=return_uncertainty,
         )
         self._total_postprocess_time += time.time() - t
         return predictions_dicts
diff --git a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/run.py b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/run.py
index fa5c98f5daa1acfe721ce768d651165007344102..6a1403982fd296973d6c732a1f5dd15464abbe15 100644
--- a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/run.py
+++ b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/run.py
@@ -606,6 +606,7 @@ def compute_lidar_kitti_output(
     center_limit_range,
     class_names,
     global_set,
+    return_uncertainty,
 ):
     annos = []
     for i, preds_dict in enumerate(predictions_dicts):
@@ -613,10 +614,18 @@ def compute_lidar_kitti_output(
             scores = preds_dict["scores"].detach().cpu().numpy()
             box_preds_lidar = preds_dict["box3d_lidar"].detach().cpu().numpy()
             label_preds = preds_dict["label_preds"].detach().cpu().numpy()
-            anno = kitti.get_start_result_anno()
+
+            if return_uncertainty:
+                scores_var = preds_dict["scores_var"].detach().cpu().numpy()
+                box_preds_lidar_var = preds_dict["box3d_lidar_var"].detach().cpu().numpy()
+            else:
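+                # Placeholders keep the zip below uniform; their values are never written to the annotation when return_uncertainty is False.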
+                scores_var = np.empty_like(scores)
+                box_preds_lidar_var = np.empty_like(box_preds_lidar)
+
+            anno = kitti.get_start_result_anno(return_uncertainty=return_uncertainty)
             num_example = 0
-            for box_lidar, score, label in zip(
-                box_preds_lidar, scores, label_preds
+            for box_lidar, box_lidar_var, score, score_var, label in zip(
+                box_preds_lidar, box_preds_lidar_var, scores, scores_var, label_preds
             ):
                 if center_limit_range is not None:
                     limit_range = np.array(center_limit_range)
@@ -631,6 +640,18 @@ def compute_lidar_kitti_output(
                 anno["dimensions"].append(box_lidar[3:6])
                 anno["location"].append(box_lidar[:3])
                 anno["rotation_y"].append(box_lidar[6])
+                anno["score"].append(score)
+
+                if return_uncertainty:
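+                    # Camera-frame fields (truncated, occluded, alpha, bbox) carry no lidar-based uncertainty, so neutral values are stored.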
+                    anno["variance_truncated"].append(0.0)
+                    anno["variance_occluded"].append(0)
+                    anno["variance_alpha"].append(0.0)
+                    anno["variance_bbox"].append(None)
+                    anno["variance_dimensions"].append(box_lidar_var[3:6])
+                    anno["variance_location"].append(box_lidar_var[:3])
+                    anno["variance_rotation_y"].append(box_lidar_var[6])
+                    anno["variance_score"].append(np.mean(score_var))
+
                 if global_set is not None:
                     for i in range(100000):
                         if score in global_set:
@@ -638,7 +659,6 @@ def compute_lidar_kitti_output(
                         else:
                             global_set.add(score)
                             break
-                anno["score"].append(score)
 
                 num_example += 1
             if num_example != 0:
diff --git a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/voxel_object_detection_3d_learner.py b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/voxel_object_detection_3d_learner.py
index 85c5a401d3853c920377df0ae28772f425d9cb6a..9bd003712c97539f6abe5bc26f8a2e72ef8ec2af 100644
--- a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/voxel_object_detection_3d_learner.py
+++ b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/voxel_object_detection_3d_learner.py
@@ -25,7 +25,7 @@ from opendr.engine.datasets import (
     ExternalDataset,
     MappedDatasetIterator,
 )
-from opendr.engine.data import PointCloud
+from opendr.engine.data import PointCloud, PointCloudWithCalibration
 from opendr.perception.object_detection_3d.voxel_object_detection_3d.second_detector.load import (
     create_model as second_create_model,
     load_from_checkpoint,
@@ -55,7 +55,7 @@ from opendr.perception.object_detection_3d.voxel_object_detection_3d.second_dete
 from opendr.perception.object_detection_3d.voxel_object_detection_3d.second_detector.data.preprocess import (
     merge_second_batch,
 )
-from opendr.engine.target import BoundingBox3DList
+from opendr.engine.target import BoundingBox3DList, UncertaintyBoundingBox3DList
 from opendr.engine.constants import OPENDR_SERVER_URL
 from urllib.request import urlretrieve
 from urllib.error import URLError
@@ -97,6 +97,7 @@ class VoxelObjectDetection3DLearner(Learner):
             "decay_factor": 0.8,
             "staircase": True,
         },
+        return_uncertainty=False,
     ):
         # Pass the shared parameters on super's constructor so they can get initialized as class attributes
         super(VoxelObjectDetection3DLearner, self).__init__(
@@ -122,6 +123,7 @@ class VoxelObjectDetection3DLearner(Learner):
         self.model_dir = None
         self.eval_checkpoint_dir = None
         self.infer_point_cloud_mapper = None
+        self.calib_infer_point_cloud_mapper = None
 
         if tanet_config_path is not None:
             set_tanet_config(tanet_config_path)
@@ -131,6 +133,7 @@ class VoxelObjectDetection3DLearner(Learner):
         self.model.rpn_ort_session = None  # ONNX runtime inference session
         self.input_config_prepared = False
         self.eval_config_prepared = False
+        self.return_uncertainty = return_uncertainty
 
     def save(self, path, verbose=False):
         """
@@ -431,7 +434,7 @@ class VoxelObjectDetection3DLearner(Learner):
 
         return result
 
-    def infer(self, point_clouds):
+    def infer(self, point_clouds, samples=1):
 
         if self.model is None:
             raise ValueError("No model loaded or created")
@@ -452,33 +455,72 @@ class VoxelObjectDetection3DLearner(Learner):
             self.infer_point_cloud_mapper = infer_point_cloud_mapper
             self.model.eval()
 
+        if self.calib_infer_point_cloud_mapper is None:
+
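+        # Lazily build a mapper that turns a PointCloudWithCalibration into a SECOND-style example dict, running the preprocessing pipeline (_prep_v9) with the sample's calibration.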
+            def create_map_point_cloud_dataset_func():
+
+                prep_func = create_prep_func(
+                    self.input_config,
+                    self.model_config,
+                    False,
+                    self.voxel_generator,
+                    self.target_assigner,
+                    use_sampler=False,
+                )
+
+                def map(point_cloud_with_calibration):
+
+                    point_cloud = point_cloud_with_calibration.data
+                    calib = point_cloud_with_calibration.calib
+
+                    example = _prep_v9(point_cloud, calib, prep_func)
+
+                    if point_cloud_with_calibration.image_shape is not None:
+                        example["image_shape"] = point_cloud_with_calibration.image_shape
+
+                    return example
+
+                return map
+
+            self.calib_infer_point_cloud_mapper = create_map_point_cloud_dataset_func()
+            self.model.eval()
+
         input_data = None
 
-        if isinstance(point_clouds, PointCloud):
-            input_data = merge_second_batch(
-                [self.infer_point_cloud_mapper(point_clouds.data)]
-            )
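+        # Dispatch each input to the matching preprocessing mapper: calibrated point clouds go through the calibration-aware mapper, plain point clouds through the default one.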
+        def map_single(point_cloud):
+            if isinstance(point_cloud, PointCloudWithCalibration):
+                return self.calib_infer_point_cloud_mapper(point_cloud)
+            elif isinstance(point_cloud, PointCloud):
+                return self.infer_point_cloud_mapper(point_cloud.data)
+            else:
+                raise ValueError("PointCloud or PointCloudWithCalibration expected")
+
+        if isinstance(point_clouds, (PointCloud, PointCloudWithCalibration)):
+            input_data = merge_second_batch([map_single(point_clouds)])
         elif isinstance(point_clouds, list):
             input_data = merge_second_batch(
-                [self.infer_point_cloud_mapper(x.data) for x in point_clouds]
+                [map_single(x) for x in point_clouds]
             )
         else:
-            return ValueError(
-                "point_clouds should be a PointCloud or a list of PointCloud"
+            raise ValueError(
+                "point_clouds should be a PointCloud, a PointCloudWithCalibration, or a list of them"
             )
 
         output = self.model(
-            example_convert_to_torch(input_data, self.float_dtype, device=self.device,)
+            example_convert_to_torch(input_data, self.float_dtype, device=self.device),
+            samples=samples,
+            return_uncertainty=self.return_uncertainty,
         )
 
-        if self.model_config.rpn.module_class_name == "PSA" or self.model_config.rpn.module_class_name == "RefineDet":
+        if self.model_config.rpn.module_class_name in ["PSA", "VPSA", "RefineDet"]:
             output = output[-1]
 
         annotations = compute_lidar_kitti_output(
-            output, self.center_limit_range, self.class_names, None
+            output, self.center_limit_range, self.class_names, None, return_uncertainty=self.return_uncertainty
         )
 
-        result = [BoundingBox3DList.from_kitti(anno) for anno in annotations]
+        if self.return_uncertainty:
+            result = [UncertaintyBoundingBox3DList.from_kitti(anno) for anno in annotations]
+        else:
+            result = [BoundingBox3DList.from_kitti(anno) for anno in annotations]
 
         if isinstance(point_clouds, PointCloud):
             return result[0]