From 5a2ed716a36b98157aad4d4fcc1ed2cd7a2670d7 Mon Sep 17 00:00:00 2001 From: Illia Oleksiienko <io@ece.au.dk> Date: Tue, 31 Jan 2023 11:15:40 +0000 Subject: [PATCH] Add full training steps retraining --- run/modeling.py | 53 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/run/modeling.py b/run/modeling.py index c76246bfe8..2ae0118d38 100644 --- a/run/modeling.py +++ b/run/modeling.py @@ -151,5 +151,58 @@ def run_vnn_from_classical_2(id=0, gpu_capacity=1, total_devices=4, raise_on_inf i += gpu_capacity * total_devices +def run_vnn_from_classical_2_full_training_steps(id=0, gpu_capacity=1, total_devices=4, raise_on_infer_error=True): + + device_id = id % total_devices + i = id + + training_samples = [2, 4] + eval_samples_list = [1, 2, 3, 4] + + def create_models(): + result = [] + for init_vnn_name, init_vnn_weights in [ + ("fill0001", "fill:stds:0.001:0.001"), + ("fill00001", "fill:stds:0.0001:0.0001"), + ("xavier_uniform001", "xavier_uniform:stds:0.01:0.01"), + ("xavier_normal001", "xavier_normal:stds:0.01:0.01"), + ]: + for single_training_samples in training_samples: + name = f"vnnclass1-2{init_vnn_name}_tanet_car" + result.append( + Model( + name, + samples_list=[single_training_samples], + eval_samples_list=eval_samples_list, + init_vnn_weights=init_vnn_weights, + classical_name="tanet_car_2", + config="vnn_xyres_16.proto", + ), + ) + + return result + + models = create_models() + + while i < len(models): + model = models[i] + + try: + result = model.train( + device="cuda:" + str(device_id) + ) + print(result) + + except Exception as e: + if raise_on_infer_error: + raise e + else: + print(e) + + i += gpu_capacity * total_devices + + + + if __name__ == "__main__": fire.Fire() -- GitLab