diff --git a/run/create_plots.py b/run/create_plots.py
index 6d863b25d233beed8b8f0418cf750ee2693d2fa9..cf7eb81d1e142d443d2ff3d74bd782d490dc6d52 100644
--- a/run/create_plots.py
+++ b/run/create_plots.py
@@ -30,7 +30,8 @@ import re
 files = glob("run/models/*/eval_*.txt")
 # files = glob("run/models/tanet_*/eval_*.txt") + glob("run/models/vnn_tanet_*/eval_*.txt")
 # files = glob("run/models/*pointpillars*/eval_*.txt")
-files = glob("run/models/tanet_*/eval_*.txt") + glob("run/models/vnnclass*/eval_*.txt")
+# files = glob("run/models/tanet_*/eval_*.txt") + glob("run/models/vnnclass*/eval_*.txt")
+files = glob("run/models/*lens*/eval_*.txt")
 
 
 def read_data(file, frame):
diff --git a/run/train_3d.py b/run/train_3d.py
index 4eb63549f1c4703958440b15de46b9417d328893..b570a05e6eedeff58aa24c14e282649af4c8f944 100644
--- a/run/train_3d.py
+++ b/run/train_3d.py
@@ -184,7 +184,6 @@ def train_vnna_pointpillars(
     )
 
 
-
 def train_lens_pointpillars(
     device="cuda:0",
     load=0,
@@ -197,7 +196,6 @@ def train_lens_pointpillars(
     )
 
 
-
 def train_blens_pointpillars(
     device="cuda:0",
     load=0,
@@ -243,7 +241,6 @@ def train_ens_tanet(
     )
 
 
-
 def train_lens_tanet(
     device="cuda:0",
     load=0,
@@ -260,6 +257,23 @@ def train_lens_tanet(
         config=config,
     )
 
+
+def train_blens_tanet(
+    device="cuda:0",
+    load=0,
+    name="blens_tanet_car",
+    config="blens_xyres_16.proto",
+    samples_list=[3],
+):
+    return train_model(
+        "tanet",
+        device=device,
+        load=load,
+        name=name,
+        samples_list=samples_list,
+        config=config,
+    )
+
 def train_vnn_tapp(
     device="cuda:0",
     load=0,
@@ -284,13 +298,15 @@ def test_model(
     samples_list=[1],
     config="xyres_16.proto",
     eval_suffix="classic",
+    num_ensembles=None,
 ):
 
     config = os.path.join(config_roots[model_type], config,)
     model_path = os.path.join(temp_dir, name)
 
     learner = VoxelObjectDetection3DLearner(
-        model_config_path=config, device=device, checkpoint_after_iter=1000, return_uncertainty=True
+        model_config_path=config, device=device, checkpoint_after_iter=1000, return_uncertainty=True,
+        num_ensembles=num_ensembles
     )
     learner.load(model_path)
 
@@ -304,7 +320,8 @@ def test_model(
                 f.write("samples = " + str(samples) + "\n")
                 f.write("3d AP: " + str(result[2][0, :, 0]))
                 f.write("\n\n")
-            except Exception as e:
+            # except Exception as e:
+            except NotImplementedError as e:
                 f.write("samples = " + str(samples) + "\n")
                 f.write("3d AP: [-1, -1, -1]")
                 f.write(str(e))
@@ -345,6 +362,7 @@ def test_pointpillars(
     samples_list=[1],
     config="xyres_16.proto",
     eval_suffix="classic",
+    num_ensembles=None,
 ):
     return test_model(
         "pointpillars",
@@ -353,6 +371,7 @@ def test_pointpillars(
         samples_list=samples_list,
         config=config,
         eval_suffix=eval_suffix,
+        num_ensembles=num_ensembles,
     )
 
 
@@ -362,6 +381,7 @@ def test_tanet(
     samples_list=[1],
     config="xyres_16.proto",
     eval_suffix="classic",
+    num_ensembles=None,
 ):
     return test_model(
         "tanet",
@@ -370,6 +390,7 @@ def test_tanet(
         samples_list=samples_list,
         config=config,
         eval_suffix=eval_suffix,
+        num_ensembles=num_ensembles,
     )
 
 
@@ -421,6 +442,40 @@ def test_vnna_pointpillars(
         eval_suffix=eval_suffix,
     )
 
+def test_lens_pointpillars(
+    device="cuda:0",
+    name="lens_pointpillars_car",
+    samples_list=[2],
+    num_ensembles=2,
+    config="lens_xyres_16.proto",
+    eval_suffix="lens",
+):
+    return test_pointpillars(
+        device=device,
+        name=name,
+        samples_list=samples_list,
+        config=config,
+        eval_suffix=eval_suffix,
+        num_ensembles=num_ensembles,
+    )
+
+def test_blens_pointpillars(
+    device="cuda:0",
+    name="blens_pointpillars_car",
+    num_ensembles=2,
+    samples_list=[2],
+    config="blens_xyres_16.proto",
+    eval_suffix="lens",
+):
+    return test_pointpillars(
+        device=device,
+        name=name,
+        samples_list=samples_list,
+        config=config,
+        eval_suffix=eval_suffix,
+        num_ensembles=num_ensembles,
+    )
+
 
 def test_vnn_tanet(
     device="cuda:0",
@@ -439,6 +494,43 @@ def test_vnn_tanet(
     )
 
 
+def test_lens_tanet(
+    device="cuda:0",
+    name="lens_tanet_car",
+    num_ensembles=2,
+    samples_list=[2],
+    config="lens_xyres_16.proto",
+    eval_suffix="lens",
+):
+    return test_model(
+        "tanet",
+        device=device,
+        name=name,
+        samples_list=samples_list,
+        config=config,
+        eval_suffix=eval_suffix,
+        num_ensembles=num_ensembles,
+    )
+
+
+def test_blens_tanet(
+    device="cuda:0",
+    name="blens_tanet_car",
+    num_ensembles=2,
+    samples_list=[2],
+    config="blens_xyres_16.proto",
+    eval_suffix="lens",
+):
+    return test_model(
+        "tanet",
+        device=device,
+        name=name,
+        samples_list=samples_list,
+        config=config,
+        eval_suffix=eval_suffix,
+        num_ensembles=num_ensembles,
+    )
+
 def train_vnn_from_classical_model(
     model_type,
     device="cuda:0",
diff --git a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/pointpillars/car/blens_xyres_16.proto b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/pointpillars/car/blens_xyres_16.proto
index 0d0125f66faa5cd5411761f093df1177c94bf928..a6417b2e6932d0dfb7dfd3c7b8c3089c9d321ee4 100644
--- a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/pointpillars/car/blens_xyres_16.proto
+++ b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/pointpillars/car/blens_xyres_16.proto
@@ -148,8 +148,8 @@ train_config: {
 
   inter_op_parallelism_threads: 4
   intra_op_parallelism_threads: 4
-  steps: 296960 # 1856 steps per epoch * 160 epochs
-  steps_per_eval: 9280 # 1856 steps per epoch * 5 epochs
+  steps: 890880 # 1856 steps per epoch * 480 epochs
+  steps_per_eval: 18560 # 1856 steps per epoch * 10 epochs
   save_checkpoints_secs : 1800 # half hour
   save_summary_steps : 10
   enable_mixed_precision: false
diff --git a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/pointpillars/car/lens_xyres_16.proto b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/pointpillars/car/lens_xyres_16.proto
index 6e3d6d71d9c890fe50706e67c30e94b8375ca41d..c99ccf5e042eb57f045aa5c0aaa5878a32c4f115 100644
--- a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/pointpillars/car/lens_xyres_16.proto
+++ b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/pointpillars/car/lens_xyres_16.proto
@@ -148,8 +148,8 @@ train_config: {
 
   inter_op_parallelism_threads: 4
   intra_op_parallelism_threads: 4
-  steps: 296960 # 1856 steps per epoch * 160 epochs
-  steps_per_eval: 9280 # 1856 steps per epoch * 5 epochs
+  steps: 890880 # 1856 steps per epoch * 480 epochs
+  steps_per_eval: 18560 # 1856 steps per epoch * 10 epochs
   save_checkpoints_secs : 1800 # half hour
   save_summary_steps : 10
   enable_mixed_precision: false
diff --git a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/tanet/car/blens_xyres_16.proto b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/tanet/car/blens_xyres_16.proto
new file mode 100644
index 0000000000000000000000000000000000000000..c2c3840c8482688fde5685c89c074b17f9ac6910
--- /dev/null
+++ b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/tanet/car/blens_xyres_16.proto
@@ -0,0 +1,173 @@
+model: {
+  second: {
+    voxel_generator {
+      point_cloud_range : [0, -39.68, -3, 69.12, 39.68, 1]
+      voxel_size : [0.16, 0.16, 4]
+      max_number_of_points_per_voxel : 100
+    }
+    num_class: 1
+    voxel_feature_extractor: {
+      module_class_name: "PillarFeature_TANet"
+      num_filters: [64]
+      with_distance: false
+    }
+    middle_feature_extractor: {
+      module_class_name: "PointPillarsScatter"
+    }
+    rpn: {
+      module_class_name: "BLEnsPSA"
+      layer_nums: [3, 5, 5]
+      layer_strides: [2, 2, 2]
+      num_filters: [64, 128, 256]
+      upsample_strides: [1, 2, 4]
+      num_upsample_filters: [128, 128, 128]
+      use_groupnorm: false
+      num_groups: 32
+    }
+    loss: {
+      classification_loss: {
+        weighted_sigmoid_focal: {
+          alpha: 0.25
+          gamma: 2.0
+          anchorwise_output: true
+        }
+      }
+      localization_loss: {
+        weighted_smooth_l1: {
+          sigma: 3.0
+          code_weight: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+        }
+      }
+      classification_weight: 1.0
+      localization_weight: 2.0
+    }
+    # Outputs
+    use_sigmoid_score: true
+    encode_background_as_zeros: true
+    encode_rad_error_by_sin: true
+
+    use_direction_classifier: true
+    direction_loss_weight: 0.2
+    use_aux_classifier: false
+    # Loss
+    pos_class_weight: 1.0
+    neg_class_weight: 1.0
+
+    loss_norm_type: NormByNumPositives
+    # Postprocess
+    post_center_limit_range: [0, -39.68, -5, 69.12, 39.68, 5]
+    use_rotate_nms: false
+    use_multi_class_nms: false
+    nms_pre_max_size: 1000
+    nms_post_max_size: 300
+    nms_score_threshold: 0.3 #0.05
+    nms_iou_threshold: 0.1 #0.5
+
+    use_bev: false
+    num_point_features: 4
+    without_reflectivity: false
+    box_coder: {
+      ground_box3d_coder: {
+        linear_dim: false
+        encode_angle_vector: false
+      }
+    }
+    target_assigner: {
+      anchor_generators: {
+         anchor_generator_stride: {
+           sizes: [1.6, 3.9, 1.56] # wlh
+           strides: [0.32, 0.32, 0.0] # if generate only 1 z_center, z_stride will be ignored
+           offsets: [0.16, -39.52, -1.78] # origin_offset + strides / 2
+           rotations: [0, 1.57] # 0, pi/2
+           matched_threshold : 0.6
+           unmatched_threshold : 0.45
+         }
+       }
+
+      sample_positive_fraction : -1
+      sample_size : 512
+      region_similarity_calculator: {
+        nearest_iou_similarity: {
+        }
+      }
+    }
+  }
+}
+
+
+train_input_reader: {
+  record_file_path: "kitti_train.tfrecord"
+  class_names: ["Car"]
+  max_num_epochs : 160
+  batch_size: 1
+  prefetch_size : 25
+  max_number_of_voxels: 12000
+  shuffle_points: true
+  num_workers: 2
+  groundtruth_localization_noise_std: [0.25, 0.25, 0.25]
+  groundtruth_rotation_uniform_noise: [-0.15707963267, 0.15707963267]
+  global_rotation_uniform_noise: [-0.78539816, 0.78539816]
+  global_scaling_uniform_noise: [0.95, 1.05]
+  global_random_rotation_range_per_object: [0, 0]
+  anchor_area_threshold: 1
+  remove_points_after_sample: false
+  groundtruth_points_drop_percentage: 0.0
+  groundtruth_drop_max_keep_points: 15
+  database_sampler {
+    database_info_path: "kitti_dbinfos_train.pkl"
+    sample_groups {
+      name_to_max_num {
+        key: "Car"
+        value: 15
+      }
+    }
+    database_prep_steps {
+      filter_by_min_num_points {
+        min_num_point_pairs {
+          key: "Car"
+          value: 5
+        }
+      }
+    }
+    database_prep_steps {
+      filter_by_difficulty {
+        removed_difficulties: [-1]
+      }
+    }
+    global_random_rotation_range_per_object: [0, 0]
+    rate: 1.0
+  }
+
+  remove_unknown_examples: false
+  remove_environment: false
+  kitti_info_path: "kitti_infos_train.pkl"
+  kitti_root_path: ""
+}
+
+train_config: {
+
+  inter_op_parallelism_threads: 4
+  intra_op_parallelism_threads: 4
+  steps: 890880 # 1856 steps per epoch * 480 epochs
+  steps_per_eval: 18560 # 1856 steps per epoch * 10 epochs
+  save_checkpoints_secs : 1800 # half hour
+  save_summary_steps : 10
+  enable_mixed_precision: false
+  loss_scale_factor : 512.0
+  clear_metrics_every_epoch: false
+}
+
+eval_input_reader: {
+  record_file_path: "kitti_val.tfrecord"
+  class_names: ["Car"]
+  batch_size: 1
+  max_num_epochs : 160
+  prefetch_size : 25
+  max_number_of_voxels: 12000
+  shuffle_points: false
+  num_workers: 1
+  anchor_area_threshold: 1
+  remove_environment: false
+  kitti_info_path: "kitti_infos_val.pkl"
+  kitti_root_path: ""
+}
diff --git a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/tanet/car/lens_xyres_16.proto b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/tanet/car/lens_xyres_16.proto
new file mode 100644
index 0000000000000000000000000000000000000000..602e2a35c24132199579b3d9d371ac286d813ff0
--- /dev/null
+++ b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/tanet/car/lens_xyres_16.proto
@@ -0,0 +1,173 @@
+model: {
+  second: {
+    voxel_generator {
+      point_cloud_range : [0, -39.68, -3, 69.12, 39.68, 1]
+      voxel_size : [0.16, 0.16, 4]
+      max_number_of_points_per_voxel : 100
+    }
+    num_class: 1
+    voxel_feature_extractor: {
+      module_class_name: "PillarFeature_TANet"
+      num_filters: [64]
+      with_distance: false
+    }
+    middle_feature_extractor: {
+      module_class_name: "PointPillarsScatter"
+    }
+    rpn: {
+      module_class_name: "LEnsPSA"
+      layer_nums: [3, 5, 5]
+      layer_strides: [2, 2, 2]
+      num_filters: [64, 128, 256]
+      upsample_strides: [1, 2, 4]
+      num_upsample_filters: [128, 128, 128]
+      use_groupnorm: false
+      num_groups: 32
+    }
+    loss: {
+      classification_loss: {
+        weighted_sigmoid_focal: {
+          alpha: 0.25
+          gamma: 2.0
+          anchorwise_output: true
+        }
+      }
+      localization_loss: {
+        weighted_smooth_l1: {
+          sigma: 3.0
+          code_weight: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+        }
+      }
+      classification_weight: 1.0
+      localization_weight: 2.0
+    }
+    # Outputs
+    use_sigmoid_score: true
+    encode_background_as_zeros: true
+    encode_rad_error_by_sin: true
+
+    use_direction_classifier: true
+    direction_loss_weight: 0.2
+    use_aux_classifier: false
+    # Loss
+    pos_class_weight: 1.0
+    neg_class_weight: 1.0
+
+    loss_norm_type: NormByNumPositives
+    # Postprocess
+    post_center_limit_range: [0, -39.68, -5, 69.12, 39.68, 5]
+    use_rotate_nms: false
+    use_multi_class_nms: false
+    nms_pre_max_size: 1000
+    nms_post_max_size: 300
+    nms_score_threshold: 0.3 #0.05
+    nms_iou_threshold: 0.1 #0.5
+
+    use_bev: false
+    num_point_features: 4
+    without_reflectivity: false
+    box_coder: {
+      ground_box3d_coder: {
+        linear_dim: false
+        encode_angle_vector: false
+      }
+    }
+    target_assigner: {
+      anchor_generators: {
+         anchor_generator_stride: {
+           sizes: [1.6, 3.9, 1.56] # wlh
+           strides: [0.32, 0.32, 0.0] # if generate only 1 z_center, z_stride will be ignored
+           offsets: [0.16, -39.52, -1.78] # origin_offset + strides / 2
+           rotations: [0, 1.57] # 0, pi/2
+           matched_threshold : 0.6
+           unmatched_threshold : 0.45
+         }
+       }
+
+      sample_positive_fraction : -1
+      sample_size : 512
+      region_similarity_calculator: {
+        nearest_iou_similarity: {
+        }
+      }
+    }
+  }
+}
+
+
+train_input_reader: {
+  record_file_path: "kitti_train.tfrecord"
+  class_names: ["Car"]
+  max_num_epochs : 160
+  batch_size: 1
+  prefetch_size : 25
+  max_number_of_voxels: 12000
+  shuffle_points: true
+  num_workers: 2
+  groundtruth_localization_noise_std: [0.25, 0.25, 0.25]
+  groundtruth_rotation_uniform_noise: [-0.15707963267, 0.15707963267]
+  global_rotation_uniform_noise: [-0.78539816, 0.78539816]
+  global_scaling_uniform_noise: [0.95, 1.05]
+  global_random_rotation_range_per_object: [0, 0]
+  anchor_area_threshold: 1
+  remove_points_after_sample: false
+  groundtruth_points_drop_percentage: 0.0
+  groundtruth_drop_max_keep_points: 15
+  database_sampler {
+    database_info_path: "kitti_dbinfos_train.pkl"
+    sample_groups {
+      name_to_max_num {
+        key: "Car"
+        value: 15
+      }
+    }
+    database_prep_steps {
+      filter_by_min_num_points {
+        min_num_point_pairs {
+          key: "Car"
+          value: 5
+        }
+      }
+    }
+    database_prep_steps {
+      filter_by_difficulty {
+        removed_difficulties: [-1]
+      }
+    }
+    global_random_rotation_range_per_object: [0, 0]
+    rate: 1.0
+  }
+
+  remove_unknown_examples: false
+  remove_environment: false
+  kitti_info_path: "kitti_infos_train.pkl"
+  kitti_root_path: ""
+}
+
+train_config: {
+
+  inter_op_parallelism_threads: 4
+  intra_op_parallelism_threads: 4
+  steps: 890880 # 1856 steps per epoch * 480 epochs
+  steps_per_eval: 18560 # 1856 steps per epoch * 10 epochs
+  save_checkpoints_secs : 1800 # half hour
+  save_summary_steps : 10
+  enable_mixed_precision: false
+  loss_scale_factor : 512.0
+  clear_metrics_every_epoch: false
+}
+
+eval_input_reader: {
+  record_file_path: "kitti_val.tfrecord"
+  class_names: ["Car"]
+  batch_size: 1
+  max_num_epochs : 160
+  prefetch_size : 25
+  max_number_of_voxels: 12000
+  shuffle_points: false
+  num_workers: 1
+  anchor_area_threshold: 1
+  remove_environment: false
+  kitti_info_path: "kitti_infos_val.pkl"
+  kitti_root_path: ""
+}
diff --git a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/tanet.py b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/tanet.py
index e1b62f4d14ae849ee1f04155291fa2d1b12ba01c..ef0d1f73ff33e54e8ade53152bebfed64a280db0 100644
--- a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/tanet.py
+++ b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/tanet.py
@@ -24,6 +24,9 @@ from opendr.perception.object_detection_3d.voxel_object_detection_3d.second_dete
     VariationalLinear,
 )
 from opendr.perception.object_detection_3d.voxel_object_detection_3d.second_detector.pytorch.models.ensemble import create_ensemble
+from opendr.perception.object_detection_3d.voxel_object_detection_3d.second_detector.pytorch.models.lens import (
+    LayerEnsembleConvolution, LayerEnsembleConvolutionTranspose, LayerEnsembleBase, create_layer_ensemble_network
+)
 
 cfg = edict(
     yaml.safe_load(
@@ -1134,3 +1137,754 @@ class EnsPSA(nn.Module):
                 result[value] = [torch.mean(torch.stack(result[value]), dim=0)]
 
         return result
+
+
+
+def createLEnsPSABase(Convolution, ConvolutionTranspose):
+    class LEnsPSABase(nn.Module):
+        def __init__(
+            self,
+            use_norm=True,
+            num_class=2,
+            layer_nums=[3, 5, 5],
+            layer_strides=[2, 2, 2],
+            num_filters=[128, 128, 256],
+            upsample_strides=[1, 2, 4],
+            num_upsample_filters=[256, 256, 256],
+            num_input_filters=128,
+            num_anchor_per_loc=2,
+            encode_background_as_zeros=True,
+            use_direction_classifier=True,
+            use_groupnorm=False,
+            num_groups=32,
+            use_bev=False,
+            box_code_size=7,
+            name="lenspsa",
+        ):
+            """
+            :param use_norm:
+            :param num_class:
+            :param layer_nums:
+            :param layer_strides:
+            :param num_filters:
+            :param upsample_strides:
+            :param num_upsample_filters:
+            :param num_input_filters:
+            :param num_anchor_per_loc:
+            :param encode_background_as_zeros:
+            :param use_direction_classifier:
+            :param use_groupnorm:
+            :param num_groups:
+            :param use_bev:
+            :param box_code_size:
+            :param name:
+            """
+            super(LEnsPSABase, self).__init__()
+            self._num_anchor_per_loc = num_anchor_per_loc  # 2
+            self._use_direction_classifier = use_direction_classifier  # True
+            self._use_bev = use_bev  # False
+            assert len(layer_nums) == 3
+            assert len(layer_strides) == len(layer_nums)
+            assert len(num_filters) == len(layer_nums)
+            assert len(upsample_strides) == len(layer_nums)
+            assert len(num_upsample_filters) == len(layer_nums)
+            factors = []
+            for i in range(len(layer_nums)):
+                assert int(np.prod(layer_strides[: i + 1])) % upsample_strides[i] == 0
+                factors.append(np.prod(layer_strides[: i + 1]) // upsample_strides[i])
+            assert all([x == factors[0] for x in factors])
+            if use_norm:  # True
+                if use_groupnorm:
+                    BatchNorm2d = change_default_args(num_groups=num_groups, eps=1e-3)(
+                        GroupNorm
+                    )
+                else:
+                    BatchNorm2d = change_default_args(eps=1e-3, momentum=0.01)(
+                        nn.BatchNorm2d
+                    )
+                Conv2d = lambda *args, **kwargs: Convolution(bias=False, *args, **kwargs)
+                ConvTranspose2d = lambda *args, **kwargs: ConvolutionTranspose(bias=False, *args, **kwargs)
+            else:
+                BatchNorm2d = Empty
+                Conv2d = lambda *args, **kwargs: Convolution(bias=True, *args, **kwargs)
+                ConvTranspose2d = lambda *args, **kwargs: ConvolutionTranspose(bias=True, *args, **kwargs)
+
+            # note that when stride > 1, conv2d with same padding isn't
+            # equal to pad-conv2d. we should use pad-conv2d.
+            block2_input_filters = num_filters[0]
+            if use_bev:
+                self.bev_extractor = Sequential(
+                    Conv2d(6, 32, 3, padding=1),
+                    BatchNorm2d(32),
+                    nn.ReLU(),
+                    Conv2d(32, 64, 3, padding=1),
+                    BatchNorm2d(64),
+                    nn.ReLU(),
+                    nn.MaxPool2d(2, 2),
+                )
+                block2_input_filters += 64
+
+            self.block1 = Sequential(
+                nn.ZeroPad2d(1),
+                Conv2d(num_input_filters, num_filters[0], 3, stride=layer_strides[0]),
+                BatchNorm2d(num_filters[0]),
+                nn.ReLU(),
+            )
+            for i in range(layer_nums[0]):
+                self.block1.add(Conv2d(num_filters[0], num_filters[0], 3, padding=1))
+                self.block1.add(BatchNorm2d(num_filters[0]))
+                self.block1.add(nn.ReLU())
+            self.deconv1 = Sequential(
+                ConvTranspose2d(
+                    num_filters[0],
+                    num_upsample_filters[0],
+                    upsample_strides[0],
+                    stride=upsample_strides[0],
+                ),
+                BatchNorm2d(num_upsample_filters[0]),
+                nn.ReLU(),
+            )
+            self.block2 = Sequential(
+                nn.ZeroPad2d(1),
+                Conv2d(block2_input_filters, num_filters[1], 3, stride=layer_strides[1]),
+                BatchNorm2d(num_filters[1]),
+                nn.ReLU(),
+            )
+            for i in range(layer_nums[1]):
+                self.block2.add(Conv2d(num_filters[1], num_filters[1], 3, padding=1))
+                self.block2.add(BatchNorm2d(num_filters[1]))
+                self.block2.add(nn.ReLU())
+            self.deconv2 = Sequential(
+                ConvTranspose2d(
+                    num_filters[1],
+                    num_upsample_filters[1],
+                    upsample_strides[1],
+                    stride=upsample_strides[1],
+                ),
+                BatchNorm2d(num_upsample_filters[1]),
+                nn.ReLU(),
+            )
+            self.block3 = Sequential(
+                nn.ZeroPad2d(1),
+                Conv2d(num_filters[1], num_filters[2], 3, stride=layer_strides[2]),
+                BatchNorm2d(num_filters[2]),
+                nn.ReLU(),
+            )
+            for i in range(layer_nums[2]):
+                self.block3.add(Conv2d(num_filters[2], num_filters[2], 3, padding=1))
+                self.block3.add(BatchNorm2d(num_filters[2]))
+                self.block3.add(nn.ReLU())
+            self.deconv3 = Sequential(
+                ConvTranspose2d(
+                    num_filters[2],
+                    num_upsample_filters[2],
+                    upsample_strides[2],
+                    stride=upsample_strides[2],
+                ),
+                BatchNorm2d(num_upsample_filters[2]),
+                nn.ReLU(),
+            )
+            if encode_background_as_zeros:
+                num_cls = num_anchor_per_loc * num_class
+            else:
+                num_cls = num_anchor_per_loc * (num_class + 1)
+            self.conv_cls = nn.Conv2d(sum(num_upsample_filters), num_cls, 1)
+            self.conv_box = nn.Conv2d(
+                sum(num_upsample_filters), num_anchor_per_loc * box_code_size, 1
+            )
+            if use_direction_classifier:
+                self.conv_dir_cls = nn.Conv2d(
+                    sum(num_upsample_filters), num_anchor_per_loc * 2, 1
+                )
+
+            self.bottle_conv = nn.Conv2d(
+                sum(num_upsample_filters), sum(num_upsample_filters) // 3, 1
+            )
+
+            self.block1_dec2x = nn.MaxPool2d(kernel_size=2)  # C=64
+            self.block1_dec4x = nn.MaxPool2d(kernel_size=4)  # C=64
+
+            self.block2_dec2x = nn.MaxPool2d(kernel_size=2)  # C=128
+            self.block2_inc2x = ConvTranspose2d(
+                num_filters[1],
+                num_filters[0] // 2,
+                upsample_strides[1],
+                stride=upsample_strides[1],
+            )  # C=32
+
+            self.block3_inc2x = ConvTranspose2d(
+                num_filters[2],
+                num_filters[1] // 2,
+                upsample_strides[1],
+                stride=upsample_strides[1],
+            )  # C=64
+            self.block3_inc4x = ConvTranspose2d(
+                num_filters[2],
+                num_filters[0] // 2,
+                upsample_strides[2],
+                stride=upsample_strides[2],
+            )  # C=32
+
+            self.fusion_block1 = nn.Conv2d(
+                num_filters[0] + num_filters[0] // 2 + num_filters[0] // 2,
+                num_filters[0],
+                1,
+            )
+            self.fusion_block2 = nn.Conv2d(
+                num_filters[0] + num_filters[1] + num_filters[1] // 2, num_filters[1], 1
+            )
+            self.fusion_block3 = nn.Conv2d(
+                num_filters[0] + num_filters[1] + num_filters[2], num_filters[2], 1
+            )
+
+            self.refine_up1 = Sequential(
+                ConvTranspose2d(
+                    num_filters[0],
+                    num_upsample_filters[0],
+                    upsample_strides[0],
+                    stride=upsample_strides[0],
+                ),
+                BatchNorm2d(num_upsample_filters[0]),
+                nn.ReLU(),
+            )
+            self.refine_up2 = Sequential(
+                ConvTranspose2d(
+                    num_filters[1],
+                    num_upsample_filters[1],
+                    upsample_strides[1],
+                    stride=upsample_strides[1],
+                ),
+                BatchNorm2d(num_upsample_filters[1]),
+                nn.ReLU(),
+            )
+            self.refine_up3 = Sequential(
+                ConvTranspose2d(
+                    num_filters[2],
+                    num_upsample_filters[2],
+                    upsample_strides[2],
+                    stride=upsample_strides[2],
+                ),
+                BatchNorm2d(num_upsample_filters[2]),
+                nn.ReLU(),
+            )
+
+            C_Bottle = cfg.PSA.C_Bottle
+            C = cfg.PSA.C_Reudce
+
+            self.RF1 = Sequential(  # 3*3
+                Conv2d(C_Bottle * 2, C, kernel_size=1, stride=1),
+                BatchNorm2d(C),
+                nn.ReLU(inplace=True),
+                Conv2d(C, C_Bottle * 2, kernel_size=3, stride=1, padding=1, dilation=1),
+                BatchNorm2d(C_Bottle * 2),
+                nn.ReLU(inplace=True),
+            )
+
+            self.RF2 = Sequential(  # 5*5
+                Conv2d(C_Bottle, C, kernel_size=3, stride=1, padding=1),
+                BatchNorm2d(C),
+                nn.ReLU(inplace=True),
+                Conv2d(C, C_Bottle, kernel_size=3, stride=1, padding=1, dilation=1),
+                BatchNorm2d(C_Bottle),
+                nn.ReLU(inplace=True),
+            )
+
+            self.RF3 = Sequential(  # 7*7
+                Conv2d(C_Bottle // 2, C, kernel_size=3, stride=1, padding=1),
+                BatchNorm2d(C),
+                nn.ReLU(inplace=True),
+                Conv2d(C, C, kernel_size=3, stride=1, padding=1),
+                BatchNorm2d(C),
+                nn.ReLU(inplace=True),
+                Conv2d(C, C_Bottle // 2, kernel_size=3, stride=1, padding=1),
+                BatchNorm2d(C_Bottle // 2),
+                nn.ReLU(inplace=True),
+            )
+
+            self.concat_conv1 = nn.Conv2d(
+                num_filters[1], num_filters[1], kernel_size=3, padding=1
+            )
+            self.concat_conv2 = nn.Conv2d(
+                num_filters[1], num_filters[1], kernel_size=3, padding=1
+            )
+            self.concat_conv3 = nn.Conv2d(
+                num_filters[1], num_filters[1], kernel_size=3, padding=1
+            )
+
+            self.refine_cls = nn.Conv2d(sum(num_upsample_filters), num_cls, 1)
+            self.refine_loc = nn.Conv2d(
+                sum(num_upsample_filters), num_anchor_per_loc * box_code_size, 1
+            )
+            if use_direction_classifier:
+                self.refine_dir = nn.Conv2d(
+                    sum(num_upsample_filters), num_anchor_per_loc * 2, 1
+                )
+
+        def forward(self, x, bev=None, samples=None, combine_predictions=None):
+            """Coarse RPN pass followed by PSA-style refinement.
+
+            :param x: backbone input feature map, [N, C, H, W] with
+                C == num_input_filters
+            :param bev: unused in this implementation (accepted for
+                interface compatibility with RPN variants that consume a
+                BEV map)
+            :param samples: unused in this implementation
+            :param combine_predictions: unused in this implementation
+            :return: dict with the coarse heads ("box_preds", "cls_preds"
+                and, when direction classification is enabled,
+                "dir_cls_preds") plus the refined heads
+                ("Refine_loc_preds", "Refine_cls_preds" and, when enabled,
+                "Refine_dir_preds"); every tensor is permuted to
+                [N, H, W, C] layout
+            """
+            # Coarse path: three strided backbone blocks, each upsampled
+            # back to a common resolution and concatenated channel-wise.
+            x1 = self.block1(x)
+            up1 = self.deconv1(x1)
+
+            x2 = self.block2(x1)
+            up2 = self.deconv2(x2)
+            x3 = self.block3(x2)
+            up3 = self.deconv3(x3)
+            coarse_feat = torch.cat([up1, up2, up3], dim=1)
+            box_preds = self.conv_box(coarse_feat)
+            cls_preds = self.conv_cls(coarse_feat)
+
+            # [N, C, y(H), x(W)]
+            box_preds = box_preds.permute(0, 2, 3, 1).contiguous()
+            cls_preds = cls_preds.permute(0, 2, 3, 1).contiguous()
+            ret_dict = {
+                "box_preds": box_preds,
+                "cls_preds": cls_preds,
+            }
+            if self._use_direction_classifier:
+                dir_cls_preds = self.conv_dir_cls(coarse_feat)
+                dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).contiguous()
+                ret_dict["dir_cls_preds"] = dir_cls_preds
+
+            # 1x1 bottleneck compresses the concatenated coarse features to
+            # one third of the channels so the result can be summed
+            # element-wise with each refinement branch below.
+            blottle_conv = self.bottle_conv(coarse_feat)
+
+            # Rescale every block's features to the other two scales
+            # (dec* = max-pool downsample, inc* = transposed-conv upsample)
+            # so each fusion branch sees all three pyramid levels.
+            x1_dec2x = self.block1_dec2x(x1)
+            x1_dec4x = self.block1_dec4x(x1)
+
+            x2_dec2x = self.block2_dec2x(x2)
+            x2_inc2x = self.block2_inc2x(x2)
+
+            x3_inc2x = self.block3_inc2x(x3)
+            x3_inc4x = self.block3_inc4x(x3)
+
+            # Cross-scale fusion: one 1x1 conv per resolution level.
+            concat_block1 = torch.cat([x1, x2_inc2x, x3_inc4x], dim=1)
+            fusion_block1 = self.fusion_block1(concat_block1)
+
+            concat_block2 = torch.cat([x1_dec2x, x2, x3_inc2x], dim=1)
+            fusion_block2 = self.fusion_block2(concat_block2)
+
+            concat_block3 = torch.cat([x1_dec4x, x2_dec2x, x3], dim=1)
+            fusion_block3 = self.fusion_block3(concat_block3)
+
+            # Receptive-field modules are cross-paired with the fusion
+            # branches: RF3 (largest stack of 3x3 convs) refines the
+            # highest-resolution branch, RF1 the lowest.
+            refine_up1 = self.RF3(fusion_block1)
+            refine_up1 = self.refine_up1(refine_up1)
+            refine_up2 = self.RF2(fusion_block2)
+            refine_up2 = self.refine_up2(refine_up2)
+            refine_up3 = self.RF1(fusion_block3)
+            refine_up3 = self.refine_up3(refine_up3)
+
+            # Inject the shared coarse bottleneck into every branch.
+            branch1_sum_wise = refine_up1 + blottle_conv
+            branch2_sum_wise = refine_up2 + blottle_conv
+            branch3_sum_wise = refine_up3 + blottle_conv
+
+            concat_conv1 = self.concat_conv1(branch1_sum_wise)
+            concat_conv2 = self.concat_conv2(branch2_sum_wise)
+            concat_conv3 = self.concat_conv3(branch3_sum_wise)
+
+            PSA_output = torch.cat([concat_conv1, concat_conv2, concat_conv3], dim=1)
+
+            refine_cls_preds = self.refine_cls(PSA_output)
+            refine_loc_preds = self.refine_loc(PSA_output)
+
+            refine_loc_preds = refine_loc_preds.permute(0, 2, 3, 1).contiguous()
+            refine_cls_preds = refine_cls_preds.permute(0, 2, 3, 1).contiguous()
+            ret_dict["Refine_loc_preds"] = refine_loc_preds
+            ret_dict["Refine_cls_preds"] = refine_cls_preds
+
+            if self._use_direction_classifier:
+                refine_dir_preds = self.refine_dir(PSA_output)
+                refine_dir_preds = refine_dir_preds.permute(0, 2, 3, 1).contiguous()
+                ret_dict["Refine_dir_preds"] = refine_dir_preds
+
+            return ret_dict
+    return LEnsPSABase
+
+
+def createBlockLEnsPSABase(Convolution, ConvolutionTranspose):
+    """Build a PSA RPN class whose layers are grouped into block-level
+    layer-ensemble units.
+
+    :param Convolution: conv-layer factory used in place of nn.Conv2d inside
+        the ensembled blocks (presumably a LayerEnsembleBase-aware class —
+        confirm against the layer-ensemble module)
+    :param ConvolutionTranspose: transposed-conv factory used in place of
+        nn.ConvTranspose2d inside the ensembled blocks
+    :return: the BLEnsPSABase class (an nn.Module subclass)
+    """
+    class BLEnsPSABase(nn.Module):
+        def __init__(
+            self,
+            use_norm=True,
+            num_class=2,
+            layer_nums=[3, 5, 5],
+            layer_strides=[2, 2, 2],
+            num_filters=[128, 128, 256],
+            upsample_strides=[1, 2, 4],
+            num_upsample_filters=[256, 256, 256],
+            num_input_filters=128,
+            num_anchor_per_loc=2,
+            encode_background_as_zeros=True,
+            use_direction_classifier=True,
+            use_groupnorm=False,
+            num_groups=32,
+            use_bev=False,
+            box_code_size=7,
+            name="lenspsa",  # NOTE(review): default kept as "lenspsa" — confirm "blenspsa" was not intended
+        ):
+            """
+            PSA backbone and heads with block-level ensemble grouping.
+
+            :param use_norm: apply BatchNorm/GroupNorm after conv layers
+            :param num_class: number of object classes
+            :param layer_nums: number of extra conv layers in each of the
+                three backbone blocks
+            :param layer_strides: stride of the first conv of each block
+            :param num_filters: output channels of each backbone block
+            :param upsample_strides: stride of each deconv (upsample) head
+            :param num_upsample_filters: output channels of each deconv head
+            :param num_input_filters: channels of the input feature map
+            :param num_anchor_per_loc: anchors predicted per spatial location
+            :param encode_background_as_zeros: if True, no extra background
+                class channel is allocated in the classification heads
+            :param use_direction_classifier: add direction-classifier heads
+            :param use_groupnorm: use GroupNorm instead of BatchNorm
+            :param num_groups: group count for GroupNorm
+            :param use_bev: add a BEV-map feature-extractor branch
+            :param box_code_size: regression targets per box
+            :param name: module name
+            """
+            super(BLEnsPSABase, self).__init__()
+            self._num_anchor_per_loc = num_anchor_per_loc  # 2
+            self._use_direction_classifier = use_direction_classifier  # True
+            self._use_bev = use_bev  # False
+            assert len(layer_nums) == 3
+            assert len(layer_strides) == len(layer_nums)
+            assert len(num_filters) == len(layer_nums)
+            assert len(upsample_strides) == len(layer_nums)
+            assert len(num_upsample_filters) == len(layer_nums)
+            # All three deconv outputs must land on the same spatial scale,
+            # otherwise the channel-wise concat in forward() would fail.
+            factors = []
+            for i in range(len(layer_nums)):
+                assert int(np.prod(layer_strides[: i + 1])) % upsample_strides[i] == 0
+                factors.append(np.prod(layer_strides[: i + 1]) // upsample_strides[i])
+            assert all([x == factors[0] for x in factors])
+            if use_norm:  # True
+                if use_groupnorm:
+                    BatchNorm2d = change_default_args(num_groups=num_groups, eps=1e-3)(
+                        GroupNorm
+                    )
+                else:
+                    BatchNorm2d = change_default_args(eps=1e-3, momentum=0.01)(
+                        nn.BatchNorm2d
+                    )
+                # Bias is disabled when a norm layer follows; the lambdas
+                # forward the flag to the injected ensemble factories.
+                Conv2d = lambda *args, **kwargs: Convolution(bias=False, *args, **kwargs)
+                ConvTranspose2d = lambda *args, **kwargs: ConvolutionTranspose(bias=False, *args, **kwargs)
+            else:
+                BatchNorm2d = Empty
+                Conv2d = lambda *args, **kwargs: Convolution(bias=True, *args, **kwargs)
+                ConvTranspose2d = lambda *args, **kwargs: ConvolutionTranspose(bias=True, *args, **kwargs)
+
+            # note that when stride > 1, conv2d with same padding isn't
+            # equal to pad-conv2d. we should use pad-conv2d.
+            block2_input_filters = num_filters[0]
+            if use_bev:
+                # start_subcollection()/connect(subcollect()) bracket module
+                # construction; presumably every ensemble layer created in
+                # between is registered as one ensemble group — confirm
+                # against LayerEnsembleBase.
+                LayerEnsembleBase.start_subcollection()
+                self.bev_extractor = Sequential(
+                    Conv2d(6, 32, 3, padding=1),
+                    BatchNorm2d(32),
+                    nn.ReLU(),
+                    Conv2d(32, 64, 3, padding=1),
+                    BatchNorm2d(64),
+                    nn.ReLU(),
+                    nn.MaxPool2d(2, 2),
+                )
+                block2_input_filters += 64
+                LayerEnsembleBase.connect(LayerEnsembleBase.subcollect())
+
+            # Ensemble group: backbone block 1 + its upsample head.
+            LayerEnsembleBase.start_subcollection()
+
+            self.block1 = Sequential(
+                nn.ZeroPad2d(1),
+                Conv2d(num_input_filters, num_filters[0], 3, stride=layer_strides[0]),
+                BatchNorm2d(num_filters[0]),
+                nn.ReLU(),
+            )
+            for i in range(layer_nums[0]):
+                self.block1.add(Conv2d(num_filters[0], num_filters[0], 3, padding=1))
+                self.block1.add(BatchNorm2d(num_filters[0]))
+                self.block1.add(nn.ReLU())
+            self.deconv1 = Sequential(
+                ConvTranspose2d(
+                    num_filters[0],
+                    num_upsample_filters[0],
+                    upsample_strides[0],
+                    stride=upsample_strides[0],
+                ),
+                BatchNorm2d(num_upsample_filters[0]),
+                nn.ReLU(),
+            )
+
+            LayerEnsembleBase.connect(LayerEnsembleBase.subcollect())
+            # Ensemble group: backbone block 2 + its upsample head.
+            LayerEnsembleBase.start_subcollection()
+            self.block2 = Sequential(
+                nn.ZeroPad2d(1),
+                Conv2d(block2_input_filters, num_filters[1], 3, stride=layer_strides[1]),
+                BatchNorm2d(num_filters[1]),
+                nn.ReLU(),
+            )
+            for i in range(layer_nums[1]):
+                self.block2.add(Conv2d(num_filters[1], num_filters[1], 3, padding=1))
+                self.block2.add(BatchNorm2d(num_filters[1]))
+                self.block2.add(nn.ReLU())
+            self.deconv2 = Sequential(
+                ConvTranspose2d(
+                    num_filters[1],
+                    num_upsample_filters[1],
+                    upsample_strides[1],
+                    stride=upsample_strides[1],
+                ),
+                BatchNorm2d(num_upsample_filters[1]),
+                nn.ReLU(),
+            )
+
+            LayerEnsembleBase.connect(LayerEnsembleBase.subcollect())
+            # Ensemble group: backbone block 3 + its upsample head.
+            LayerEnsembleBase.start_subcollection()
+            self.block3 = Sequential(
+                nn.ZeroPad2d(1),
+                Conv2d(num_filters[1], num_filters[2], 3, stride=layer_strides[2]),
+                BatchNorm2d(num_filters[2]),
+                nn.ReLU(),
+            )
+            for i in range(layer_nums[2]):
+                self.block3.add(Conv2d(num_filters[2], num_filters[2], 3, padding=1))
+                self.block3.add(BatchNorm2d(num_filters[2]))
+                self.block3.add(nn.ReLU())
+            self.deconv3 = Sequential(
+                ConvTranspose2d(
+                    num_filters[2],
+                    num_upsample_filters[2],
+                    upsample_strides[2],
+                    stride=upsample_strides[2],
+                ),
+                BatchNorm2d(num_upsample_filters[2]),
+                nn.ReLU(),
+            )
+
+            LayerEnsembleBase.connect(LayerEnsembleBase.subcollect())
+            # Coarse prediction heads (plain nn.Conv2d — not ensembled).
+            if encode_background_as_zeros:
+                num_cls = num_anchor_per_loc * num_class
+            else:
+                num_cls = num_anchor_per_loc * (num_class + 1)
+            self.conv_cls = nn.Conv2d(sum(num_upsample_filters), num_cls, 1)
+            self.conv_box = nn.Conv2d(
+                sum(num_upsample_filters), num_anchor_per_loc * box_code_size, 1
+            )
+            if use_direction_classifier:
+                self.conv_dir_cls = nn.Conv2d(
+                    sum(num_upsample_filters), num_anchor_per_loc * 2, 1
+                )
+
+            # 1x1 bottleneck over the coarse concat; 1/3 of the channels so
+            # the output matches each refinement branch for element-wise sum.
+            self.bottle_conv = nn.Conv2d(
+                sum(num_upsample_filters), sum(num_upsample_filters) // 3, 1
+            )
+
+            # Cross-scale resamplers: max-pool to downscale, transposed conv
+            # to upscale (the ensembled ones form their own group below).
+            self.block1_dec2x = nn.MaxPool2d(kernel_size=2)  # C=64
+            self.block1_dec4x = nn.MaxPool2d(kernel_size=4)  # C=64
+
+            self.block2_dec2x = nn.MaxPool2d(kernel_size=2)  # C=128
+            LayerEnsembleBase.start_subcollection()
+            self.block2_inc2x = ConvTranspose2d(
+                num_filters[1],
+                num_filters[0] // 2,
+                upsample_strides[1],
+                stride=upsample_strides[1],
+            )  # C=32
+
+            self.block3_inc2x = ConvTranspose2d(
+                num_filters[2],
+                num_filters[1] // 2,
+                upsample_strides[1],
+                stride=upsample_strides[1],
+            )  # C=64
+            self.block3_inc4x = ConvTranspose2d(
+                num_filters[2],
+                num_filters[0] // 2,
+                upsample_strides[2],
+                stride=upsample_strides[2],
+            )  # C=32
+
+            LayerEnsembleBase.connect(LayerEnsembleBase.subcollect())
+
+            # 1x1 fusion convs: one per pyramid level, each consuming the
+            # concat of all three (resampled) backbone outputs.
+            self.fusion_block1 = nn.Conv2d(
+                num_filters[0] + num_filters[0] // 2 + num_filters[0] // 2,
+                num_filters[0],
+                1,
+            )
+            self.fusion_block2 = nn.Conv2d(
+                num_filters[0] + num_filters[1] + num_filters[1] // 2, num_filters[1], 1
+            )
+            self.fusion_block3 = nn.Conv2d(
+                num_filters[0] + num_filters[1] + num_filters[2], num_filters[2], 1
+            )
+
+
+            # Ensemble group: the three refinement upsample heads.
+            LayerEnsembleBase.start_subcollection()
+            self.refine_up1 = Sequential(
+                ConvTranspose2d(
+                    num_filters[0],
+                    num_upsample_filters[0],
+                    upsample_strides[0],
+                    stride=upsample_strides[0],
+                ),
+                BatchNorm2d(num_upsample_filters[0]),
+                nn.ReLU(),
+            )
+            self.refine_up2 = Sequential(
+                ConvTranspose2d(
+                    num_filters[1],
+                    num_upsample_filters[1],
+                    upsample_strides[1],
+                    stride=upsample_strides[1],
+                ),
+                BatchNorm2d(num_upsample_filters[1]),
+                nn.ReLU(),
+            )
+            self.refine_up3 = Sequential(
+                ConvTranspose2d(
+                    num_filters[2],
+                    num_upsample_filters[2],
+                    upsample_strides[2],
+                    stride=upsample_strides[2],
+                ),
+                BatchNorm2d(num_upsample_filters[2]),
+                nn.ReLU(),
+            )
+
+            LayerEnsembleBase.connect(LayerEnsembleBase.subcollect())
+            # Ensemble groups: one receptive-field (RF) module per branch.
+            LayerEnsembleBase.start_subcollection()
+
+            C_Bottle = cfg.PSA.C_Bottle
+            C = cfg.PSA.C_Reudce  # sic: attribute is spelled "C_Reudce" in the PSA config
+
+            self.RF1 = Sequential(  # 3*3
+                Conv2d(C_Bottle * 2, C, kernel_size=1, stride=1),
+                BatchNorm2d(C),
+                nn.ReLU(inplace=True),
+                Conv2d(C, C_Bottle * 2, kernel_size=3, stride=1, padding=1, dilation=1),
+                BatchNorm2d(C_Bottle * 2),
+                nn.ReLU(inplace=True),
+            )
+
+            LayerEnsembleBase.connect(LayerEnsembleBase.subcollect())
+            LayerEnsembleBase.start_subcollection()
+
+            self.RF2 = Sequential(  # 5*5
+                Conv2d(C_Bottle, C, kernel_size=3, stride=1, padding=1),
+                BatchNorm2d(C),
+                nn.ReLU(inplace=True),
+                Conv2d(C, C_Bottle, kernel_size=3, stride=1, padding=1, dilation=1),
+                BatchNorm2d(C_Bottle),
+                nn.ReLU(inplace=True),
+            )
+
+            LayerEnsembleBase.connect(LayerEnsembleBase.subcollect())
+            LayerEnsembleBase.start_subcollection()
+
+            self.RF3 = Sequential(  # 7*7
+                Conv2d(C_Bottle // 2, C, kernel_size=3, stride=1, padding=1),
+                BatchNorm2d(C),
+                nn.ReLU(inplace=True),
+                Conv2d(C, C, kernel_size=3, stride=1, padding=1),
+                BatchNorm2d(C),
+                nn.ReLU(inplace=True),
+                Conv2d(C, C_Bottle // 2, kernel_size=3, stride=1, padding=1),
+                BatchNorm2d(C_Bottle // 2),
+                nn.ReLU(inplace=True),
+            )
+
+            LayerEnsembleBase.connect(LayerEnsembleBase.subcollect())
+
+            # Per-branch 3x3 convs applied after the bottleneck sum.
+            self.concat_conv1 = nn.Conv2d(
+                num_filters[1], num_filters[1], kernel_size=3, padding=1
+            )
+            self.concat_conv2 = nn.Conv2d(
+                num_filters[1], num_filters[1], kernel_size=3, padding=1
+            )
+            self.concat_conv3 = nn.Conv2d(
+                num_filters[1], num_filters[1], kernel_size=3, padding=1
+            )
+
+            # Refined prediction heads over the concatenated PSA output.
+            self.refine_cls = nn.Conv2d(sum(num_upsample_filters), num_cls, 1)
+            self.refine_loc = nn.Conv2d(
+                sum(num_upsample_filters), num_anchor_per_loc * box_code_size, 1
+            )
+            if use_direction_classifier:
+                self.refine_dir = nn.Conv2d(
+                    sum(num_upsample_filters), num_anchor_per_loc * 2, 1
+                )
+
+        def forward(self, x, bev=None, samples=None, combine_predictions=None):
+            """Coarse RPN pass followed by PSA-style refinement.
+
+            :param x: backbone input feature map, [N, C, H, W] with
+                C == num_input_filters
+            :param bev: unused in this implementation
+            :param samples: unused in this implementation
+            :param combine_predictions: unused in this implementation
+            :return: dict with coarse heads ("box_preds", "cls_preds",
+                optional "dir_cls_preds") and refined heads
+                ("Refine_loc_preds", "Refine_cls_preds", optional
+                "Refine_dir_preds"), all permuted to [N, H, W, C]
+            """
+            # Coarse path: three strided blocks, upsampled and concatenated.
+            x1 = self.block1(x)
+            up1 = self.deconv1(x1)
+
+            x2 = self.block2(x1)
+            up2 = self.deconv2(x2)
+            x3 = self.block3(x2)
+            up3 = self.deconv3(x3)
+            coarse_feat = torch.cat([up1, up2, up3], dim=1)
+            box_preds = self.conv_box(coarse_feat)
+            cls_preds = self.conv_cls(coarse_feat)
+
+            # [N, C, y(H), x(W)]
+            box_preds = box_preds.permute(0, 2, 3, 1).contiguous()
+            cls_preds = cls_preds.permute(0, 2, 3, 1).contiguous()
+            ret_dict = {
+                "box_preds": box_preds,
+                "cls_preds": cls_preds,
+            }
+            if self._use_direction_classifier:
+                dir_cls_preds = self.conv_dir_cls(coarse_feat)
+                dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).contiguous()
+                ret_dict["dir_cls_preds"] = dir_cls_preds
+
+            # Shared bottleneck of the coarse features, summed into each
+            # refinement branch below.
+            blottle_conv = self.bottle_conv(coarse_feat)
+
+            # Rescale each block's output to the other two pyramid levels.
+            x1_dec2x = self.block1_dec2x(x1)
+            x1_dec4x = self.block1_dec4x(x1)
+
+            x2_dec2x = self.block2_dec2x(x2)
+            x2_inc2x = self.block2_inc2x(x2)
+
+            x3_inc2x = self.block3_inc2x(x3)
+            x3_inc4x = self.block3_inc4x(x3)
+
+            concat_block1 = torch.cat([x1, x2_inc2x, x3_inc4x], dim=1)
+            fusion_block1 = self.fusion_block1(concat_block1)
+
+            concat_block2 = torch.cat([x1_dec2x, x2, x3_inc2x], dim=1)
+            fusion_block2 = self.fusion_block2(concat_block2)
+
+            concat_block3 = torch.cat([x1_dec4x, x2_dec2x, x3], dim=1)
+            fusion_block3 = self.fusion_block3(concat_block3)
+
+            # RF modules are cross-paired with the fusion branches (RF3 on
+            # the highest-resolution branch, RF1 on the lowest).
+            refine_up1 = self.RF3(fusion_block1)
+            refine_up1 = self.refine_up1(refine_up1)
+            refine_up2 = self.RF2(fusion_block2)
+            refine_up2 = self.refine_up2(refine_up2)
+            refine_up3 = self.RF1(fusion_block3)
+            refine_up3 = self.refine_up3(refine_up3)
+
+            branch1_sum_wise = refine_up1 + blottle_conv
+            branch2_sum_wise = refine_up2 + blottle_conv
+            branch3_sum_wise = refine_up3 + blottle_conv
+
+            concat_conv1 = self.concat_conv1(branch1_sum_wise)
+            concat_conv2 = self.concat_conv2(branch2_sum_wise)
+            concat_conv3 = self.concat_conv3(branch3_sum_wise)
+
+            PSA_output = torch.cat([concat_conv1, concat_conv2, concat_conv3], dim=1)
+
+            refine_cls_preds = self.refine_cls(PSA_output)
+            refine_loc_preds = self.refine_loc(PSA_output)
+
+            refine_loc_preds = refine_loc_preds.permute(0, 2, 3, 1).contiguous()
+            refine_cls_preds = refine_cls_preds.permute(0, 2, 3, 1).contiguous()
+            ret_dict["Refine_loc_preds"] = refine_loc_preds
+            ret_dict["Refine_cls_preds"] = refine_cls_preds
+
+            if self._use_direction_classifier:
+                refine_dir_preds = self.refine_dir(PSA_output)
+                refine_dir_preds = refine_dir_preds.permute(0, 2, 3, 1).contiguous()
+                ret_dict["Refine_dir_preds"] = refine_dir_preds
+
+            return ret_dict
+    return BLEnsPSABase
+
+def LEnsPSA(*args, **kwargs):
+    return create_layer_ensemble_network(createLEnsPSABase, *args, **kwargs)
+
+def BLEnsPSA(*args, **kwargs):
+    return create_layer_ensemble_network(createBlockLEnsPSABase, *args, **kwargs)
diff --git a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/voxelnet.py b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/voxelnet.py
index 8d5b211a6cee794979e34d90ea080889efa874e6..0d5e1ca82c9987b84938c6b80595caf5306232ec 100644
--- a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/voxelnet.py
+++ b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/voxelnet.py
@@ -29,7 +29,7 @@ from opendr.perception.object_detection_3d.voxel_object_detection_3d.second_dete
     PointPillarsScatter,
 )
 from opendr.perception.object_detection_3d.voxel_object_detection_3d.second_detector.pytorch.models.tanet import (
-    VPSA, PillarFeature_TANet, PSA, EnsPSA
+    VPSA, PillarFeature_TANet, PSA, EnsPSA, LEnsPSA, BLEnsPSA
 )
 from opendr.perception.object_detection_3d.voxel_object_detection_3d.second_detector.pytorch.models.loss_utils import (
     create_refine_loss
@@ -911,7 +911,6 @@ class VARPN(nn.Module):
         return ret_dict
 
 
-
 def createLEnsRPNBase(Convolution, ConvolutionTranspose):
     class LEnsRPNBase(nn.Module):
         def __init__(
@@ -1090,7 +1089,6 @@ def createLEnsRPNBase(Convolution, ConvolutionTranspose):
     return LEnsRPNBase
 
 
-
 def createBlockLEnsRPNBase(Convolution: LayerEnsembleBase, ConvolutionTranspose: LayerEnsembleBase):
     class BLEnsRPNBase(nn.Module):
         def __init__(
@@ -1450,6 +1448,7 @@ class VoxelNet(nn.Module):
         rpn_class_dict = {
             "RPN": RPN, "PSA": PSA, "VPSA": VPSA, "EnsPSA": EnsPSA, "VRPN": VRPN, "VARPN": VARPN,
             "LEnsRPN": LEnsRPN, "BLEnsRPN": BLEnsRPN,
+            "LEnsPSA": LEnsPSA, "BLEnsPSA": BLEnsPSA, 
         }
         self.rpn_class_name = rpn_class_name
         rpn_class = rpn_class_dict[rpn_class_name]
diff --git a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/run.py b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/run.py
index 7bd87f6b93b93f2b88e1aa06aa722c85a420dea2..db8dca4354c1d24cdfdfcf3909f982c2a5d6fc15 100644
--- a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/run.py
+++ b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/run.py
@@ -710,6 +710,8 @@ def predict_kitti_to_anno(
         return annos_coarse, annos_refine
     else:
         predictions_dicts_coarse = net(example, samples=samples, combine_predictions=True)
+        predictions_dicts_coarse = predictions_dicts_coarse[0]
+
         annos_coarse = comput_kitti_output(
             predictions_dicts_coarse,
             batch_image_shape,