Skip to content
Snippets Groups Projects
Commit 5a2ed716 authored by Illia Oleksiienko's avatar Illia Oleksiienko
Browse files

Add full training steps retraining

parent 3e9e4564
No related branches found
No related tags found
No related merge requests found
......@@ -151,5 +151,58 @@ def run_vnn_from_classical_2(id=0, gpu_capacity=1, total_devices=4, raise_on_inf
i += gpu_capacity * total_devices
def run_vnn_from_classical_2_full_training_steps(id=0, gpu_capacity=1, total_devices=4, raise_on_infer_error=True):
device_id = id % total_devices
i = id
training_samples = [2, 4]
eval_samples_list = [1, 2, 3, 4]
def create_models():
result = []
for init_vnn_name, init_vnn_weights in [
("fill0001", "fill:stds:0.001:0.001"),
("fill00001", "fill:stds:0.0001:0.0001"),
("xavier_uniform001", "xavier_uniform:stds:0.01:0.01"),
("xavier_normal001", "xavier_normal:stds:0.01:0.01"),
]:
for single_training_samples in training_samples:
name = f"vnnclass1-2{init_vnn_name}_tanet_car"
result.append(
Model(
name,
samples_list=[single_training_samples],
eval_samples_list=eval_samples_list,
init_vnn_weights=init_vnn_weights,
classical_name="tanet_car_2",
config="vnn_xyres_16.proto",
),
)
return result
models = create_models()
while i < len(models):
model = models[i]
try:
result = model.train(
device="cuda:" + str(device_id)
)
print(result)
except Exception as e:
if raise_on_infer_error:
raise e
else:
print(e)
i += gpu_capacity * total_devices
if __name__ == "__main__":
fire.Fire()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment