diff --git a/run/train_3d.py b/run/train_3d.py
index 686a1f3a773fe4643a9cefb3cc07c9813707a8bd..4eb63549f1c4703958440b15de46b9417d328893 100644
--- a/run/train_3d.py
+++ b/run/train_3d.py
@@ -196,6 +196,19 @@ def train_lens_pointpillars(
         device=device, load=load, name=name, samples_list=samples_list, config=config
     )
+
+
+def train_blens_pointpillars(
+    device="cuda:0",
+    load=0,
+    name="blens_pointpillars_car",
+    samples_list=[4],
+    config="blens_xyres_16.proto",
+):
+    return train_pointpillars(
+        device=device, load=load, name=name, samples_list=samples_list, config=config
+    )
+
 
 def train_vnn_tanet(
     device="cuda:0",
     load=0,
diff --git a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/pointpillars/car/blens_xyres_16.proto b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/pointpillars/car/blens_xyres_16.proto
new file mode 100644
index 0000000000000000000000000000000000000000..0d0125f66faa5cd5411761f093df1177c94bf928
--- /dev/null
+++ b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/configs/pointpillars/car/blens_xyres_16.proto
@@ -0,0 +1,173 @@
+model: {
+  second: {
+    voxel_generator {
+      point_cloud_range : [0, -39.68, -3, 69.12, 39.68, 1]
+      voxel_size : [0.16, 0.16, 4]
+      max_number_of_points_per_voxel : 100
+    }
+    num_class: 1
+    voxel_feature_extractor: {
+      module_class_name: "PillarFeatureNet"
+      num_filters: [64]
+      with_distance: false
+    }
+    middle_feature_extractor: {
+      module_class_name: "PointPillarsScatter"
+    }
+    rpn: {
+      module_class_name: "BLEnsRPN"
+      layer_nums: [3, 5, 5]
+      layer_strides: [2, 2, 2]
+      num_filters: [64, 128, 256]
+      upsample_strides: [1, 2, 4]
+      num_upsample_filters: [128, 128, 128]
+      use_groupnorm: false
+      num_groups: 32
+    }
+    loss: {
+      classification_loss: {
+        weighted_sigmoid_focal: {
+          alpha: 0.25
+          gamma: 2.0
+          anchorwise_output: true
+        }
+      }
+      localization_loss: {
+        weighted_smooth_l1: {
+          sigma: 3.0
+          code_weight: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+        }
+      }
+      classification_weight: 1.0
+      localization_weight: 2.0
+    }
+    # Outputs
+    use_sigmoid_score: true
+    encode_background_as_zeros: true
+    encode_rad_error_by_sin: true
+
+    use_direction_classifier: true
+    direction_loss_weight: 0.2
+    use_aux_classifier: false
+    # Loss
+    pos_class_weight: 1.0
+    neg_class_weight: 1.0
+
+    loss_norm_type: NormByNumPositives
+    # Postprocess
+    post_center_limit_range: [0, -39.68, -5, 69.12, 39.68, 5]
+    use_rotate_nms: false
+    use_multi_class_nms: false
+    nms_pre_max_size: 1000
+    nms_post_max_size: 300
+    nms_score_threshold: 0.05
+    nms_iou_threshold: 0.5
+
+    use_bev: false
+    num_point_features: 4
+    without_reflectivity: false
+    box_coder: {
+      ground_box3d_coder: {
+        linear_dim: false
+        encode_angle_vector: false
+      }
+    }
+    target_assigner: {
+      anchor_generators: {
+        anchor_generator_stride: {
+          sizes: [1.6, 3.9, 1.56] # wlh
+          strides: [0.32, 0.32, 0.0] # if generate only 1 z_center, z_stride will be ignored
+          offsets: [0.16, -39.52, -1.78] # origin_offset + strides / 2
+          rotations: [0, 1.57] # 0, pi/2
+          matched_threshold : 0.6
+          unmatched_threshold : 0.45
+        }
+      }
+
+      sample_positive_fraction : -1
+      sample_size : 512
+      region_similarity_calculator: {
+        nearest_iou_similarity: {
+        }
+      }
+    }
+  }
+}
+
+
+train_input_reader: {
+  record_file_path: "kitti_train.tfrecord"
+  class_names: ["Car"]
+  max_num_epochs : 160
+  batch_size: 1
+  prefetch_size : 25
+  max_number_of_voxels: 12000
+  shuffle_points: true
+  num_workers: 2
+  groundtruth_localization_noise_std: [0.25, 0.25, 0.25]
+  groundtruth_rotation_uniform_noise: [-0.15707963267, 0.15707963267]
+  global_rotation_uniform_noise: [-0.78539816, 0.78539816]
+  global_scaling_uniform_noise: [0.95, 1.05]
+  global_random_rotation_range_per_object: [0, 0]
+  anchor_area_threshold: 1
+  remove_points_after_sample: false
+  groundtruth_points_drop_percentage: 0.0
+  groundtruth_drop_max_keep_points: 15
+  database_sampler {
+    database_info_path: "kitti_dbinfos_train.pkl"
+    sample_groups {
+      name_to_max_num {
+        key: "Car"
+        value: 15
+      }
+    }
+    database_prep_steps {
+      filter_by_min_num_points {
+        min_num_point_pairs {
+          key: "Car"
+          value: 5
+        }
+      }
+    }
+    database_prep_steps {
+      filter_by_difficulty {
+        removed_difficulties: [-1]
+      }
+    }
+    global_random_rotation_range_per_object: [0, 0]
+    rate: 1.0
+  }
+
+  remove_unknown_examples: false
+  remove_environment: false
+  kitti_info_path: "kitti_infos_train.pkl"
+  kitti_root_path: ""
+}
+
+train_config: {
+
+  inter_op_parallelism_threads: 4
+  intra_op_parallelism_threads: 4
+  steps: 296960 # 1856 steps per epoch * 160 epochs
+  steps_per_eval: 9280 # 1856 steps per epoch * 5 epochs
+  save_checkpoints_secs : 1800 # half hour
+  save_summary_steps : 10
+  enable_mixed_precision: false
+  loss_scale_factor : 512.0
+  clear_metrics_every_epoch: false
+}
+
+eval_input_reader: {
+  record_file_path: "kitti_val.tfrecord"
+  class_names: ["Car"]
+  batch_size: 1
+  max_num_epochs : 160
+  prefetch_size : 25
+  max_number_of_voxels: 12000
+  shuffle_points: false
+  num_workers: 3
+  anchor_area_threshold: 1
+  remove_environment: false
+  kitti_info_path: "kitti_infos_val.pkl"
+  kitti_root_path: ""
+}
diff --git a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/voxelnet.py b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/voxelnet.py
index 6347b223d0d0560beeaca40c684e0a0f61519255..8d5b211a6cee794979e34d90ea080889efa874e6 100644
--- a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/voxelnet.py
+++ b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/pytorch/models/voxelnet.py
@@ -1090,10 +1090,222 @@ def createLEnsRPNBase(Convolution, ConvolutionTranspose):
     return LEnsRPNBase
 
 
+def createBlockLEnsRPNBase(Convolution: LayerEnsembleBase, ConvolutionTranspose: LayerEnsembleBase):
+    class BLEnsRPNBase(nn.Module):
+        def __init__(
+            self,
+            use_norm=True,
+            num_class=2,
+            layer_nums=[3, 5, 5],
+            layer_strides=[2, 2, 2],
+            num_filters=[128, 128, 256],
+            upsample_strides=[1, 2, 4],
+            num_upsample_filters=[256, 256, 256],
+            num_input_filters=128,
+            num_anchor_per_loc=2,
+            encode_background_as_zeros=True,
+            use_direction_classifier=True,
+            use_groupnorm=False,
+            num_groups=32,
+            use_bev=False,
+            box_code_size=7,
+            name="lensrpn",
+        ):
+            super(BLEnsRPNBase, self).__init__()
+            self._num_anchor_per_loc = num_anchor_per_loc
+            self._use_direction_classifier = use_direction_classifier
+            self._use_bev = use_bev
+            assert len(layer_nums) == 3
+            assert len(layer_strides) == len(layer_nums)
+            assert len(num_filters) == len(layer_nums)
+            assert len(upsample_strides) == len(layer_nums)
+            assert len(num_upsample_filters) == len(layer_nums)
+            factors = []
+            for i in range(len(layer_nums)):
+                assert int(np.prod(
+                    layer_strides[:i + 1])) % upsample_strides[i] == 0
+                factors.append(
+                    np.prod(layer_strides[:i + 1]) // upsample_strides[i])
+            assert all([x == factors[0] for x in factors])
+            if use_norm:
+                if use_groupnorm:
+                    BatchNorm2d = change_default_args(num_groups=num_groups,
+                                                      eps=1e-3)(GroupNorm)
+                else:
+                    BatchNorm2d = change_default_args(eps=1e-3, momentum=0.01)(
+                        nn.BatchNorm2d)
+                Conv2d = lambda *args, **kwargs: Convolution(bias=False, *args, **kwargs)
+                ConvTranspose2d = lambda *args, **kwargs: ConvolutionTranspose(bias=False, *args, **kwargs)
+            else:
+                BatchNorm2d = Empty
+                Conv2d = lambda *args, **kwargs: Convolution(bias=True, *args, **kwargs)
+                ConvTranspose2d = lambda *args, **kwargs: ConvolutionTranspose(bias=True, *args, **kwargs)
+
+            # note that when stride > 1, conv2d with same padding isn't
+            # equal to pad-conv2d. we should use pad-conv2d.
+
+            block2_input_filters = num_filters[0]
+            if use_bev:
+                LayerEnsembleBase.start_subcollection()
+                self.bev_extractor = Sequential(
+                    Conv2d(6, 32, 3, padding=1),
+                    BatchNorm2d(32),
+                    nn.ReLU(),
+                    Conv2d(32, 64, 3, padding=1),
+                    BatchNorm2d(64),
+                    nn.ReLU(),
+                    nn.MaxPool2d(2, 2),
+                )
+                block2_input_filters += 64
+                LayerEnsembleBase.connect(LayerEnsembleBase.subcollect())
+
+            LayerEnsembleBase.start_subcollection()
+
+            self.block1 = Sequential(
+                nn.ZeroPad2d(1),
+                Conv2d(num_input_filters,
+                       num_filters[0],
+                       3,
+                       stride=layer_strides[0]),
+                BatchNorm2d(num_filters[0]),
+                nn.ReLU(),
+            )
+            for i in range(layer_nums[0]):
+                self.block1.add(
+                    Conv2d(num_filters[0], num_filters[0], 3, padding=1))
+                self.block1.add(BatchNorm2d(num_filters[0]))
+                self.block1.add(nn.ReLU())
+
+            LayerEnsembleBase.connect(LayerEnsembleBase.subcollect())
+            LayerEnsembleBase.start_subcollection()
+
+            self.deconv1 = Sequential(
+                ConvTranspose2d(
+                    num_filters[0],
+                    num_upsample_filters[0],
+                    upsample_strides[0],
+                    stride=upsample_strides[0],
+                ),
+                BatchNorm2d(num_upsample_filters[0]),
+                nn.ReLU(),
+            )
+
+            LayerEnsembleBase.connect(LayerEnsembleBase.subcollect())
+            LayerEnsembleBase.start_subcollection()
+
+            self.block2 = Sequential(
+                nn.ZeroPad2d(1),
+                Conv2d(
+                    block2_input_filters,
+                    num_filters[1],
+                    3,
+                    stride=layer_strides[1],
+                ),
+                BatchNorm2d(num_filters[1]),
+                nn.ReLU(),
+            )
+
+            for i in range(layer_nums[1]):
+                self.block2.add(
+                    Conv2d(num_filters[1], num_filters[1], 3, padding=1))
+                self.block2.add(BatchNorm2d(num_filters[1]))
+                self.block2.add(nn.ReLU())
+
+            LayerEnsembleBase.connect(LayerEnsembleBase.subcollect())
+            LayerEnsembleBase.start_subcollection()
+
+            self.deconv2 = Sequential(
+                ConvTranspose2d(
+                    num_filters[1],
+                    num_upsample_filters[1],
+                    upsample_strides[1],
+                    stride=upsample_strides[1],
+                ),
+                BatchNorm2d(num_upsample_filters[1]),
+                nn.ReLU(),
+            )
+
+            LayerEnsembleBase.connect(LayerEnsembleBase.subcollect())
+            LayerEnsembleBase.start_subcollection()
+
+            self.block3 = Sequential(
+                nn.ZeroPad2d(1),
+                Conv2d(num_filters[1], num_filters[2], 3, stride=layer_strides[2]),
+                BatchNorm2d(num_filters[2]),
+                nn.ReLU(),
+            )
+            for i in range(layer_nums[2]):
+                self.block3.add(
+                    Conv2d(num_filters[2], num_filters[2], 3, padding=1))
+                self.block3.add(BatchNorm2d(num_filters[2]))
+                self.block3.add(nn.ReLU())
+
+            LayerEnsembleBase.connect(LayerEnsembleBase.subcollect())
+            LayerEnsembleBase.start_subcollection()
+
+            self.deconv3 = Sequential(
+                ConvTranspose2d(
+                    num_filters[2],
+                    num_upsample_filters[2],
+                    upsample_strides[2],
+                    stride=upsample_strides[2],
+                ),
+                BatchNorm2d(num_upsample_filters[2]),
+                nn.ReLU(),
+            )
+
+            LayerEnsembleBase.connect(LayerEnsembleBase.subcollect())
+
+            if encode_background_as_zeros:
+                num_cls = num_anchor_per_loc * num_class
+            else:
+                num_cls = num_anchor_per_loc * (num_class + 1)
+            self.conv_cls = nn.Conv2d(sum(num_upsample_filters), num_cls, 1)
+            self.conv_box = nn.Conv2d(sum(num_upsample_filters),
+                                      num_anchor_per_loc * box_code_size, 1)
+            if use_direction_classifier:
+                self.conv_dir_cls = nn.Conv2d(sum(num_upsample_filters),
+                                              num_anchor_per_loc * 2, 1)
+
+        def forward(self, x, bev=None, sample=None):
+            x = self.block1(x)
+            up1 = self.deconv1(x)
+            if self._use_bev:
+                bev[:, -1] = torch.clamp(torch.log(1 + bev[:, -1]) / np.log(16.0),
+                                         max=1.0)
+                x = torch.cat([x, self.bev_extractor(bev)], dim=1)
+            x = self.block2(x)
+            up2 = self.deconv2(x)
+            x = self.block3(x)
+            up3 = self.deconv3(x)
+            x = torch.cat([up1, up2, up3], dim=1)
+            box_preds = self.conv_box(x)
+            cls_preds = self.conv_cls(x)
+            # [N, C, y(H), x(W)]
+            box_preds = box_preds.permute(0, 2, 3, 1).contiguous()
+            cls_preds = cls_preds.permute(0, 2, 3, 1).contiguous()
+            ret_dict = {
+                "box_preds": box_preds,
+                "cls_preds": cls_preds,
+            }
+            if self._use_direction_classifier:
+                dir_cls_preds = self.conv_dir_cls(x)
+                dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).contiguous()
+                ret_dict["dir_cls_preds"] = dir_cls_preds
+            return ret_dict
+
+    return BLEnsRPNBase
+
+
 def LEnsRPN(*args, **kwargs):
     return create_layer_ensemble_network(createLEnsRPNBase, *args, **kwargs)
 
 
+def BLEnsRPN(*args, **kwargs):
+    return create_layer_ensemble_network(createBlockLEnsRPNBase, *args, **kwargs)
+
+
 class LossNormType(Enum):
     NormByNumPositives = "norm_by_num_positives"
     NormByNumExamples = "norm_by_num_examples"
@@ -1237,7 +1449,7 @@ class VoxelNet(nn.Module):
         rpn_class_dict = {
             "RPN": RPN, "PSA": PSA, "VPSA": VPSA, "EnsPSA": EnsPSA,
             "VRPN": VRPN, "VARPN": VARPN,
-            "LEnsRPN": LEnsRPN,
+            "LEnsRPN": LEnsRPN, "BLEnsRPN": BLEnsRPN,
         }
         self.rpn_class_name = rpn_class_name
         rpn_class = rpn_class_dict[rpn_class_name]
diff --git a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/run.py b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/run.py
index 6302cac412e032090957a0f5277326e61449192b..7bd87f6b93b93f2b88e1aa06aa722c85a420dea2 100644
--- a/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/run.py
+++ b/src/opendr/perception/object_detection_3d/voxel_object_detection_3d/second_detector/run.py
@@ -709,7 +709,7 @@ def predict_kitti_to_anno(
         )
         return annos_coarse, annos_refine
     else:
-        predictions_dicts_coarse = net(example, samples=samples)
+        predictions_dicts_coarse = net(example, samples=samples, combine_predictions=True)
         annos_coarse = comput_kitti_output(
             predictions_dicts_coarse,
             batch_image_shape,
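
Usage note: a minimal sketch of how the new trainer could be invoked (a hypothetical example, not part of the diff; it assumes train_3d is importable, e.g. when running from inside run/, and that samples_list keeps the meaning it has for the existing layer-ensemble trainers; the call simply restates the defaults added above):

    # Hypothetical example; mirrors the defaults of train_blens_pointpillars.
    from train_3d import train_blens_pointpillars

    # Train a Car-class PointPillars model whose RPN is the new block-level
    # layer ensemble ("BLEnsRPN", configured by blens_xyres_16.proto), using
    # 4 ensemble samples and starting from scratch (load=0).
    train_blens_pointpillars(device="cuda:0", load=0, samples_list=[4])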