From 49af004d8dadfd000a32f4c2ce39d97a64e394f4 Mon Sep 17 00:00:00 2001
From: Illia Oleksiienko <io@ece.au.dk>
Date: Mon, 22 Aug 2022 12:15:52 +0000
Subject: [PATCH] Add VNN PointPillars

---
 run/.gitignore                                |   1 +
 run/train_3d.py                               | 172 ++++++++++
 .../datasets/create_data_kitti.py             |   3 +-
 .../pointpillars/car/vnn_xyres_16.proto       | 173 ++++++++++
 .../second_detector/core/geometry.py          |   1 +
 .../pytorch/models/variational.py             | 323 ++++++++++++++++++
 .../pytorch/models/voxelnet.py                | 197 ++++++++++-
 7 files changed, 868 insertions(+), 2 deletions(-)
 create mode 100644 run/.gitignore
 create mode 100644 run/train_3d.py
 create mode 100644 src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/pointpillars/car/vnn_xyres_16.proto
 create mode 100644 src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/variational.py

diff --git a/run/.gitignore b/run/.gitignore
new file mode 100644
index 0000000000..d6d95cf828
--- /dev/null
+++ b/run/.gitignore
@@ -0,0 +1 @@
+models
\ No newline at end of file
diff --git a/run/train_3d.py b/run/train_3d.py
new file mode 100644
index 0000000000..ad5ef2677c
--- /dev/null
+++ b/run/train_3d.py
@@ -0,0 +1,172 @@
+import sys
+import os
+import torch
+import fire
+from opendr.engine.datasets import PointCloudsDatasetIterator
+from opendr.perception.object_detection_3d import VoxelObjectDetection3DLearner
+from opendr.perception.object_detection_3d import (
+    KittiDataset,
+    LabeledPointCloudsDatasetIterator,
+)
+
+
+def train_pointpillars(device="cuda:0", load=0):
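+    # Train the baseline PointPillars car detector on KITTI, checkpointing
+    # every 1000 iterations; pass load to resume from a saved checkpoint.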
+    name = "pointpillars_car"
+    config = os.path.join(
+        ".",
+        "src",
+        "opendr",
+        "perception",
+        "object_detection_3d",
+        "voxel_object_detection_3d",
+        "second_detector",
+        "configs",
+        "pointpillars",
+        "car",
+        "xyres_16.proto",
+    )
+    temp_dir = "./run/models"
+    model_path = os.path.join(temp_dir, name)
+
+    subsets_path = os.path.join(
+        ".",
+        "src",
+        "opendr",
+        "perception",
+        "object_detection_3d",
+        "datasets",
+        "kitti_subsets",
+    )
+
+    dataset = KittiDataset("/data/sets/opendr_kitti/", subsets_path)
+
+    learner = VoxelObjectDetection3DLearner(
+        model_config_path=config, device=device, checkpoint_after_iter=1000, checkpoint_load_iter=load
+    )
+
+    learner.fit(
+        dataset, model_dir=model_path, verbose=True, evaluate=True,
+    )
+    learner.save(model_path)
+
+
+def train_vnn_pointpillars(device="cuda:0", load=0):
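+    # Train the variational (VNN) PointPillars car detector; the only change
+    # from the baseline is the vnn_xyres_16.proto config, which selects VRPN.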
+    name = "vnn_pointpillars_car"
+    config = os.path.join(
+        ".",
+        "src",
+        "opendr",
+        "perception",
+        "object_detection_3d",
+        "voxel_object_detection_3d",
+        "second_detector",
+        "configs",
+        "pointpillars",
+        "car",
+        "vnn_xyres_16.proto",
+    )
+    temp_dir = "./run/models"
+    model_path = os.path.join(temp_dir, name)
+
+    subsets_path = os.path.join(
+        ".",
+        "src",
+        "opendr",
+        "perception",
+        "object_detection_3d",
+        "datasets",
+        "kitti_subsets",
+    )
+
+    dataset = KittiDataset("/data/sets/opendr_kitti/", subsets_path)
+
+    learner = VoxelObjectDetection3DLearner(
+        model_config_path=config, device=device, checkpoint_after_iter=1000, checkpoint_load_iter=load
+    )
+
+    learner.fit(
+        dataset, model_dir=model_path, verbose=True, evaluate=True,
+    )
+    learner.save(model_path)
+
+
+def test_vnn_pointpillars(device="cuda:0", load=0):
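+    # Load a trained VNN PointPillars model and evaluate it on the KITTI
+    # validation subset.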
+    name = "vnn_pointpillars_car"
+    config = os.path.join(
+        ".",
+        "src",
+        "opendr",
+        "perception",
+        "object_detection_3d",
+        "voxel_object_detection_3d",
+        "second_detector",
+        "configs",
+        "pointpillars",
+        "car",
+        "vnn_xyres_16.proto",
+    )
+    temp_dir = "./run/models"
+    model_path = os.path.join(temp_dir, name)
+
+    subsets_path = os.path.join(
+        ".",
+        "src",
+        "opendr",
+        "perception",
+        "object_detection_3d",
+        "datasets",
+        "kitti_subsets",
+    )
+
+    dataset = KittiDataset("/data/sets/opendr_kitti/", subsets_path)
+
+    learner = VoxelObjectDetection3DLearner(
+        model_config_path=config, device=device, checkpoint_after_iter=1000, checkpoint_load_iter=load
+    )
+    learner.load(model_path)
+    result = learner.eval(dataset, verbose=True)
+    return result
+
+
+def train_tanet(device="cuda:0", load=0):
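+    # Train the TANet car detector with the same KITTI setup as above.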
+    name = "tanet_car"
+    config = os.path.join(
+        ".",
+        "src",
+        "opendr",
+        "perception",
+        "object_detection_3d",
+        "voxel_object_detection_3d",
+        "second_detector",
+        "configs",
+        "tanet",
+        "car",
+        "xyres_16.proto",
+    )
+    temp_dir = "./run/models"
+    model_path = os.path.join(temp_dir, name)
+
+    subsets_path = os.path.join(
+        ".",
+        "src",
+        "opendr",
+        "perception",
+        "object_detection_3d",
+        "datasets",
+        "kitti_subsets",
+    )
+
+    dataset = KittiDataset("/data/sets/opendr_kitti/", subsets_path)
+
+    learner = VoxelObjectDetection3DLearner(
+        model_config_path=config, device=device, checkpoint_after_iter=1000, checkpoint_load_iter=load,
+    )
+
+    learner.fit(
+        dataset, model_dir=model_path, verbose=True, evaluate=True,
+    )
+    learner.save(model_path)
+
+
+if __name__ == "__main__":
+    fire.Fire()
diff --git a/src/opendr/perception/object_detection_3d/datasets/create_data_kitti.py b/src/opendr/perception/object_detection_3d/datasets/create_data_kitti.py
index 40b5318577..1adac1e498 100644
--- a/src/opendr/perception/object_detection_3d/datasets/create_data_kitti.py
+++ b/src/opendr/perception/object_detection_3d/datasets/create_data_kitti.py
@@ -40,7 +40,8 @@ def _read_imageset_file(path):
 def _calculate_num_points_in_gt(
     data_path, infos, relative_path, remove_outside=True, num_features=4
 ):
-    for info in infos:
+    for idx, info in enumerate(infos):
+        print(idx, "/", len(infos))
         if relative_path:
             v_path = str(pathlib.Path(data_path) / info["velodyne_path"])
         else:
diff --git a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/pointpillars/car/vnn_xyres_16.proto b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/pointpillars/car/vnn_xyres_16.proto
new file mode 100644
index 0000000000..2d3ce2bad1
--- /dev/null
+++ b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/pointpillars/car/vnn_xyres_16.proto
@@ -0,0 +1,173 @@
+model: {
+  second: {
+    voxel_generator {
+      point_cloud_range : [0, -39.68, -3, 69.12, 39.68, 1]
+      voxel_size : [0.16, 0.16, 4]
+      max_number_of_points_per_voxel : 100
+    }
+    num_class: 1
+    voxel_feature_extractor: {
+      module_class_name: "PillarFeatureNet"
+      num_filters: [64]
+      with_distance: false
+    }
+    middle_feature_extractor: {
+      module_class_name: "PointPillarsScatter"
+    }
+    rpn: {
+      module_class_name: "VRPN"
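+      # VRPN is the variational RPN defined in pytorch/models/voxelnet.py.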
+      layer_nums: [3, 5, 5]
+      layer_strides: [2, 2, 2]
+      num_filters: [64, 128, 256]
+      upsample_strides: [1, 2, 4]
+      num_upsample_filters: [128, 128, 128]
+      use_groupnorm: false
+      num_groups: 32
+    }
+    loss: {
+      classification_loss: {
+        weighted_sigmoid_focal: {
+          alpha: 0.25
+          gamma: 2.0
+          anchorwise_output: true
+        }
+      }
+      localization_loss: {
+        weighted_smooth_l1: {
+          sigma: 3.0
+          code_weight: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+        }
+      }
+      classification_weight: 1.0
+      localization_weight: 2.0
+    }
+    # Outputs
+    use_sigmoid_score: true
+    encode_background_as_zeros: true
+    encode_rad_error_by_sin: true
+
+    use_direction_classifier: true
+    direction_loss_weight: 0.2
+    use_aux_classifier: false
+    # Loss
+    pos_class_weight: 1.0
+    neg_class_weight: 1.0
+
+    loss_norm_type: NormByNumPositives
+    # Postprocess
+    post_center_limit_range: [0, -39.68, -5, 69.12, 39.68, 5]
+    use_rotate_nms: false
+    use_multi_class_nms: false
+    nms_pre_max_size: 1000
+    nms_post_max_size: 300
+    nms_score_threshold: 0.05
+    nms_iou_threshold: 0.5
+
+    use_bev: false
+    num_point_features: 4
+    without_reflectivity: false
+    box_coder: {
+      ground_box3d_coder: {
+        linear_dim: false
+        encode_angle_vector: false
+      }
+    }
+    target_assigner: {
+      anchor_generators: {
+         anchor_generator_stride: {
+           sizes: [1.6, 3.9, 1.56] # wlh
+           strides: [0.32, 0.32, 0.0] # if only one z_center is generated, the z stride is ignored
+           offsets: [0.16, -39.52, -1.78] # origin_offset + strides / 2
+           rotations: [0, 1.57] # 0, pi/2
+           matched_threshold : 0.6
+           unmatched_threshold : 0.45
+         }
+       }
+
+      sample_positive_fraction : -1
+      sample_size : 512
+      region_similarity_calculator: {
+        nearest_iou_similarity: {
+        }
+      }
+    }
+  }
+}
+
+
+train_input_reader: {
+  record_file_path: "kitti_train.tfrecord"
+  class_names: ["Car"]
+  max_num_epochs : 160
+  batch_size: 2
+  prefetch_size : 25
+  max_number_of_voxels: 12000
+  shuffle_points: true
+  num_workers: 2
+  groundtruth_localization_noise_std: [0.25, 0.25, 0.25]
+  groundtruth_rotation_uniform_noise: [-0.15707963267, 0.15707963267]
+  global_rotation_uniform_noise: [-0.78539816, 0.78539816]
+  global_scaling_uniform_noise: [0.95, 1.05]
+  global_random_rotation_range_per_object: [0, 0]
+  anchor_area_threshold: 1
+  remove_points_after_sample: false
+  groundtruth_points_drop_percentage: 0.0
+  groundtruth_drop_max_keep_points: 15
+  database_sampler {
+    database_info_path: "kitti_dbinfos_train.pkl"
+    sample_groups {
+      name_to_max_num {
+        key: "Car"
+        value: 15
+      }
+    }
+    database_prep_steps {
+      filter_by_min_num_points {
+        min_num_point_pairs {
+          key: "Car"
+          value: 5
+        }
+      }
+    }
+    database_prep_steps {
+      filter_by_difficulty {
+        removed_difficulties: [-1]
+      }
+    }
+    global_random_rotation_range_per_object: [0, 0]
+    rate: 1.0
+  }
+
+  remove_unknown_examples: false
+  remove_environment: false
+  kitti_info_path: "kitti_infos_train.pkl"
+  kitti_root_path: ""
+}
+
+train_config: {
+
+  inter_op_parallelism_threads: 4
+  intra_op_parallelism_threads: 4
+  steps: 296960 # 1856 steps per epoch * 160 epochs
+  steps_per_eval: 9280 # 1856 steps per epoch * 5 epochs
+  save_checkpoints_secs : 1800 # half hour
+  save_summary_steps : 10
+  enable_mixed_precision: false
+  loss_scale_factor : 512.0
+  clear_metrics_every_epoch: false
+}
+
+eval_input_reader: {
+  record_file_path: "kitti_val.tfrecord"
+  class_names: ["Car"]
+  batch_size: 1
+  max_num_epochs : 160
+  prefetch_size : 25
+  max_number_of_voxels: 12000
+  shuffle_points: false
+  num_workers: 3
+  anchor_area_threshold: 1
+  remove_environment: false
+  kitti_info_path: "kitti_infos_val.pkl"
+  kitti_root_path: ""
+}
diff --git a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/core/geometry.py b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/core/geometry.py
index 7fd9e6fa34..63ce25bb8b 100644
--- a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/core/geometry.py
+++ b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/core/geometry.py
@@ -92,6 +92,7 @@ def surface_equ_3d_jit(polygon_surfaces):
     return normal_vec, -d
 
 
+@numba.jit(nopython=False)
 def points_in_convex_polygon_3d_jit(points,
                                     polygon_surfaces,
                                     num_surfaces=None):
diff --git a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/variational.py b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/variational.py
new file mode 100644
index 0000000000..d3294b16d9
--- /dev/null
+++ b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/variational.py
@@ -0,0 +1,323 @@
+# Copyright 2020-2022 OpenDR European Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from torch import nn
+import numpy as np
+
+
+class VariationalBase(nn.Module):
+
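+    # Class-level knobs: GLOBAL_STD can replace or scale the learned stds
+    # (controlled by global_std_mode); LOG_STDS prints std/mean diagnostics.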
+    GLOBAL_STD = 0
+    LOG_STDS = False
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def build(
+        self,
+        means: nn.Module,
+        stds,
+        batch_norm_module,
+        batch_norm_size: int,
+        activation=None,
+        activation_mode="mean",
+        use_batch_norm=False,
+        batch_norm_mode="mean",
+        batch_norm_eps=1e-3,
+        batch_norm_momentum=0.01,
+        global_std_mode="none",
+    ) -> None:
+
+        self.end_activation = None
+        self.end_batch_norm = None
+
+        self.means = means
+        self.stds = stds
+
+        self.global_std_mode = global_std_mode
+
+        if use_batch_norm:
+
+            batch_norm_targets = batch_norm_mode.split("+")
+
+            for target in batch_norm_targets:
+
+                if target == "mean":
+                    self.means = nn.Sequential(
+                        self.means,
+                        batch_norm_module(
+                            batch_norm_size,
+                            eps=batch_norm_eps,
+                            momentum=batch_norm_momentum,
+                        ),
+                    )
+                elif target == "std":
+                    if self.stds is not None:
+                        self.stds = nn.Sequential(
+                            self.stds,
+                            batch_norm_module(
+                                batch_norm_size,
+                                eps=batch_norm_eps,
+                                momentum=batch_norm_momentum,
+                            ),
+                        )
+                elif target == "end":
+                    self.end_batch_norm = batch_norm_module(
+                        batch_norm_size,
+                        eps=batch_norm_eps,
+                        momentum=batch_norm_momentum,
+                    )
+                else:
+                    raise ValueError("Unknown batch norm target: " + target)
+
+        if activation is not None:
+
+            activation_targets = activation_mode.split("+")
+
+            for i, target in enumerate(activation_targets):
+
+                if len(activation_targets) == 1:
+                    current_activation = activation  # type: ignore
+                else:
+                    current_activation = activation[i]  # type: ignore
+
+                if target == "mean":
+                    self.means = nn.Sequential(self.means, current_activation,)
+                elif target == "std":
+                    if self.stds is not None:
+                        self.stds = nn.Sequential(self.stds, current_activation,)
+                elif target == "end":
+                    self.end_activation = current_activation
+                else:
+                    raise ValueError("Unknown activation target: " + target)
+
+    def forward(self, input):
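+        # Mean and std branches are evaluated on the same input; the output is
+        # a single draw from the induced Gaussian (see the sampling line below).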
+
+        if isinstance(input, tuple):
+            x = input[0]
+        else:
+            x = input
+
+        means = self.means(x)
+
+        if self.stds is not None:
+            stds = self.stds(x)
+        else:
+            stds = 0
+
+        if self.global_std_mode == "replace":
+            stds = VariationalBase.GLOBAL_STD
+        elif self.global_std_mode == "multiply":
+            stds = VariationalBase.GLOBAL_STD * stds
+
+        if self.LOG_STDS:
+
+            pstds = stds
+
+            if isinstance(stds, (int, float)):
+                pstds = torch.tensor(stds * 1.0)
+
+            print(
+                "std%:",
+                abs(
+                    float(torch.mean(pstds).detach())
+                    / float(torch.mean(means).detach())
+                    * 100
+                ),
+                "std:",
+                float(torch.mean(pstds).detach()),
+                "mean",
+                float(torch.mean(means).detach()),
+            )
+
+        # Sample the layer output with the reparameterization trick; for positive
+        # stds this matches torch.distributions.Normal(means, stds).rsample().
+        result = means + stds * torch.normal(0, torch.ones_like(means))
+
+        if self.end_batch_norm is not None:
+            result = self.end_batch_norm(result)
+
+        if self.end_activation is not None:
+            result = self.end_activation(result)
+
+        return result
+
+
+class VariationalConvolution(VariationalBase):
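+    # 2D convolution with stochastic output: two parallel nn.Conv2d layers
+    # produce the per-element mean and std from which the output is sampled.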
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride=1,
+        activation=None,
+        activation_mode="mean",
+        use_batch_norm=False,
+        batch_norm_mode="end",
+        batch_norm_eps=1e-3,
+        batch_norm_momentum=0.01,
+        global_std_mode="none",
+        bias=True,
+        **kwargs,
+    ) -> None:
+
+        super().__init__()
+
+        if use_batch_norm:
+            bias = False
+
+        means = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            bias=bias,
+            **kwargs,
+        )
+
+        if global_std_mode == "replace":
+            stds = None
+        else:
+            stds = nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=kernel_size,
+                stride=stride,
+                bias=bias,
+                **kwargs,
+            )
+
+        super().build(
+            means,
+            stds,
+            nn.BatchNorm2d,
+            out_channels,
+            activation=activation,
+            activation_mode=activation_mode,
+            use_batch_norm=use_batch_norm,
+            batch_norm_mode=batch_norm_mode,
+            batch_norm_eps=batch_norm_eps,
+            batch_norm_momentum=batch_norm_momentum,
+            global_std_mode=global_std_mode,
+        )
+
+
+class VariationalLinear(VariationalBase):
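+    # Fully connected counterpart of VariationalConvolution, built on nn.Linear
+    # with nn.BatchNorm1d for the optional normalization.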
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        activation=None,
+        activation_mode="mean",
+        use_batch_norm=False,
+        batch_norm_mode="mean",
+        batch_norm_eps=1e-3,
+        batch_norm_momentum=0.01,
+        global_std_mode="none",
+        bias=True,
+        **kwargs,
+    ) -> None:
+
+        super().__init__()
+
+        if use_batch_norm:
+            bias = False
+
+        means = nn.Linear(in_features, out_features, bias=bias, **kwargs)
+
+        if global_std_mode == "replace":
+            stds = None
+        else:
+            stds = nn.Linear(in_features, out_features, bias=bias, **kwargs)
+
+        super().build(
+            means,
+            stds,
+            nn.BatchNorm1d,
+            out_features,
+            activation=activation,
+            activation_mode=activation_mode,
+            use_batch_norm=use_batch_norm,
+            batch_norm_mode=batch_norm_mode,
+            batch_norm_eps=batch_norm_eps,
+            batch_norm_momentum=batch_norm_momentum,
+            global_std_mode=global_std_mode,
+        )
+
+
+class VariationalConvolutionTranspose(VariationalBase):
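+    # Transposed-convolution variant, used for the upsampling blocks of VRPN.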
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride=1,
+        activation=None,
+        activation_mode="mean",
+        use_batch_norm=False,
+        batch_norm_mode="end",
+        batch_norm_eps=1e-3,
+        batch_norm_momentum=0.01,
+        global_std_mode="none",
+        bias=True,
+        **kwargs,
+    ) -> None:
+
+        super().__init__()
+
+        if use_batch_norm:
+            bias = False
+
+        means = nn.ConvTranspose2d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            bias=bias,
+            **kwargs,
+        )
+
+        if global_std_mode == "replace":
+            stds = None
+        else:
+            stds = nn.ConvTranspose2d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=kernel_size,
+                stride=stride,
+                bias=bias,
+                **kwargs,
+            )
+
+        super().build(
+            means,
+            stds,
+            nn.BatchNorm2d,
+            out_channels,
+            activation=activation,
+            activation_mode=activation_mode,
+            use_batch_norm=use_batch_norm,
+            batch_norm_mode=batch_norm_mode,
+            batch_norm_eps=batch_norm_eps,
+            batch_norm_momentum=batch_norm_momentum,
+            global_std_mode=global_std_mode,
+        )
diff --git a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/voxelnet.py b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/voxelnet.py
index c27e491edc..96edef3662 100644
--- a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/voxelnet.py
+++ b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/voxelnet.py
@@ -35,6 +35,9 @@ from opendr.perception.object_detection_3d.voxel_object_detection_3d.second_dete
     create_refine_loss
 )
 from opendr.perception.object_detection_3d.voxel_object_detection_3d.second_detector.pytorch.utils import get_paddings_indicator
+from opendr.perception.object_detection_3d.voxel_object_detection_3d.second_detector.pytorch.models.variational import (
+    VariationalConvolution, VariationalConvolutionTranspose, VariationalLinear
+)
 
 USING_SCN = False  # default: not use SparseConv
 
@@ -509,6 +512,198 @@ class RPN(nn.Module):
             ret_dict["dir_cls_preds"] = dir_cls_preds
         return ret_dict
 
+
+class VRPN(nn.Module):
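+    # Variational RPN for PointPillars: identical topology to RPN, but with
+    # VariationalConvolution / VariationalConvolutionTranspose layers, so
+    # repeated forward passes yield different box and class predictions.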
+    def __init__(
+        self,
+        use_norm=True,
+        num_class=2,
+        layer_nums=[3, 5, 5],
+        layer_strides=[2, 2, 2],
+        num_filters=[128, 128, 256],
+        upsample_strides=[1, 2, 4],
+        num_upsample_filters=[256, 256, 256],
+        num_input_filters=128,
+        num_anchor_per_loc=2,
+        encode_background_as_zeros=True,
+        use_direction_classifier=True,
+        use_groupnorm=False,
+        num_groups=32,
+        use_bev=False,
+        box_code_size=7,
+        name="vrpn",
+    ):
+        super(VRPN, self).__init__()
+        self._num_anchor_per_loc = num_anchor_per_loc
+        self._use_direction_classifier = use_direction_classifier
+        self._use_bev = use_bev
+        assert len(layer_nums) == 3
+        assert len(layer_strides) == len(layer_nums)
+        assert len(num_filters) == len(layer_nums)
+        assert len(upsample_strides) == len(layer_nums)
+        assert len(num_upsample_filters) == len(layer_nums)
+        factors = []
+        for i in range(len(layer_nums)):
+            assert int(np.prod(
+                layer_strides[:i + 1])) % upsample_strides[i] == 0
+            factors.append(
+                np.prod(layer_strides[:i + 1]) // upsample_strides[i])
+        assert all([x == factors[0] for x in factors])
+        if use_norm:
+            if use_groupnorm:
+                BatchNorm2d = change_default_args(num_groups=num_groups,
+                                                  eps=1e-3)(GroupNorm)
+            else:
+                BatchNorm2d = change_default_args(eps=1e-3, momentum=0.01)(
+                    nn.BatchNorm2d)
+            Conv2d = change_default_args(bias=False)(VariationalConvolution)
+            ConvTranspose2d = change_default_args(bias=False)(
+                VariationalConvolutionTranspose)
+        else:
+            BatchNorm2d = Empty
+            Conv2d = change_default_args(bias=True)(VariationalConvolution)
+            ConvTranspose2d = change_default_args(bias=True)(
+                VariationalConvolutionTranspose)
+
+        # Note: when stride > 1, Conv2d with implicit "same" padding is not
+        # equivalent to explicit ZeroPad2d followed by Conv2d, so the explicit
+        # pad-then-conv form is used below.
+        block2_input_filters = num_filters[0]
+        if use_bev:
+            self.bev_extractor = Sequential(
+                Conv2d(6, 32, 3, padding=1),
+                BatchNorm2d(32),
+                nn.ReLU(),
+                Conv2d(32, 64, 3, padding=1),
+                BatchNorm2d(64),
+                nn.ReLU(),
+                nn.MaxPool2d(2, 2),
+            )
+            block2_input_filters += 64
+
+        self.block1 = Sequential(
+            nn.ZeroPad2d(1),
+            Conv2d(num_input_filters,
+                   num_filters[0],
+                   3,
+                   stride=layer_strides[0]),
+            BatchNorm2d(num_filters[0]),
+            nn.ReLU(),
+        )
+        for i in range(layer_nums[0]):
+            self.block1.add(
+                Conv2d(num_filters[0], num_filters[0], 3, padding=1))
+            self.block1.add(BatchNorm2d(num_filters[0]))
+            self.block1.add(nn.ReLU())
+        self.deconv1 = Sequential(
+            ConvTranspose2d(
+                num_filters[0],
+                num_upsample_filters[0],
+                upsample_strides[0],
+                stride=upsample_strides[0],
+            ),
+            BatchNorm2d(num_upsample_filters[0]),
+            nn.ReLU(),
+        )
+        self.block2 = Sequential(
+            nn.ZeroPad2d(1),
+            Conv2d(
+                block2_input_filters,
+                num_filters[1],
+                3,
+                stride=layer_strides[1],
+            ),
+            BatchNorm2d(num_filters[1]),
+            nn.ReLU(),
+        )
+        for i in range(layer_nums[1]):
+            self.block2.add(
+                Conv2d(num_filters[1], num_filters[1], 3, padding=1))
+            self.block2.add(BatchNorm2d(num_filters[1]))
+            self.block2.add(nn.ReLU())
+        self.deconv2 = Sequential(
+            ConvTranspose2d(
+                num_filters[1],
+                num_upsample_filters[1],
+                upsample_strides[1],
+                stride=upsample_strides[1],
+            ),
+            BatchNorm2d(num_upsample_filters[1]),
+            nn.ReLU(),
+        )
+        self.block3 = Sequential(
+            nn.ZeroPad2d(1),
+            Conv2d(num_filters[1], num_filters[2], 3, stride=layer_strides[2]),
+            BatchNorm2d(num_filters[2]),
+            nn.ReLU(),
+        )
+        for i in range(layer_nums[2]):
+            self.block3.add(
+                Conv2d(num_filters[2], num_filters[2], 3, padding=1))
+            self.block3.add(BatchNorm2d(num_filters[2]))
+            self.block3.add(nn.ReLU())
+        self.deconv3 = Sequential(
+            ConvTranspose2d(
+                num_filters[2],
+                num_upsample_filters[2],
+                upsample_strides[2],
+                stride=upsample_strides[2],
+            ),
+            BatchNorm2d(num_upsample_filters[2]),
+            nn.ReLU(),
+        )
+        if encode_background_as_zeros:
+            num_cls = num_anchor_per_loc * num_class
+        else:
+            num_cls = num_anchor_per_loc * (num_class + 1)
+        self.conv_cls = nn.Conv2d(sum(num_upsample_filters), num_cls, 1)
+        self.conv_box = nn.Conv2d(sum(num_upsample_filters),
+                                  num_anchor_per_loc * box_code_size, 1)
+        if use_direction_classifier:
+            self.conv_dir_cls = nn.Conv2d(sum(num_upsample_filters),
+                                          num_anchor_per_loc * 2, 1)
+
+    def forward(self, input, bev=None, samples=5):
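+        # Run `samples` stochastic forward passes; each pass re-draws the noise
+        # in the variational layers, so the spread of the outputs reflects
+        # model uncertainty.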
+
+        all_box_preds = []
+        all_cls_preds = []
+
+        for s in range(samples):
+            x = self.block1(input)
+            up1 = self.deconv1(x)
+            if self._use_bev:
+                bev[:, -1] = torch.clamp(
+                    torch.log(1 + bev[:, -1]) / np.log(16.0), max=1.0)
+                x = torch.cat([x, self.bev_extractor(bev)], dim=1)
+            x = self.block2(x)
+            up2 = self.deconv2(x)
+            x = self.block3(x)
+            up3 = self.deconv3(x)
+            x = torch.cat([up1, up2, up3], dim=1)
+            box_preds_local = self.conv_box(x)
+            cls_preds_local = self.conv_cls(x)
+
+            all_box_preds.append(box_preds_local)
+            all_cls_preds.append(cls_preds_local)
+
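+        # torch.var_mean returns (variance, mean); stacking the samples along a
+        # new leading dimension preserves the batch dimension when reducing.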
+        box_preds_var, box_preds = torch.var_mean(
+            torch.stack(all_box_preds, dim=0), dim=0, unbiased=False)
+        cls_preds_var, cls_preds = torch.var_mean(
+            torch.stack(all_cls_preds, dim=0), dim=0, unbiased=False)
+
+        # [N, C, y(H), x(W)]
+        box_preds = box_preds.permute(0, 2, 3, 1).contiguous()
+        cls_preds = cls_preds.permute(0, 2, 3, 1).contiguous()
+        box_preds_var = box_preds_var.permute(0, 2, 3, 1).contiguous()
+        cls_preds_var = cls_preds_var.permute(0, 2, 3, 1).contiguous()
+        ret_dict = {
+            "box_preds": box_preds,
+            "cls_preds": cls_preds,
+            "box_preds_var": box_preds_var,
+            "cls_preds_var": cls_preds_var,
+        }
+        if self._use_direction_classifier:
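+            # Direction logits reuse the concatenated features from the last sample.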
+            dir_cls_preds = self.conv_dir_cls(x)
+            dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).contiguous()
+            ret_dict["dir_cls_preds"] = dir_cls_preds
+        return ret_dict
+
 
 class LossNormType(Enum):
     NormByNumPositives = "norm_by_num_positives"
@@ -650,7 +845,7 @@ class VoxelNet(nn.Module):
             else:
                 num_rpn_input_filters = int(middle_num_filters_d2[-1] * 2)
 
-        rpn_class_dict = {"RPN": RPN, "PSA": PSA}
+        rpn_class_dict = {"RPN": RPN, "PSA": PSA, "VRPN": VRPN}
         self.rpn_class_name = rpn_class_name
         rpn_class = rpn_class_dict[rpn_class_name]
         self.rpn = rpn_class(
-- 
GitLab