Skip to content
Snippets Groups Projects
Commit 6a3104a8 authored by Jeppe Gade's avatar Jeppe Gade
Browse files

Stuff done for training

parent d0873606
No related branches found
No related tags found
No related merge requests found
......@@ -37,7 +37,6 @@ def train_one_epoch(yolo_network, train_loader):
# get the inputs; data is a list of [inputs, labels]
for i, data in enumerate(train_loader):
if (i == 5): break
inputs, labels = data[0].to(device), data[1].to(device)
......@@ -68,7 +67,7 @@ if __name__ == '__main__':
validation_losses = []
print("creating train dataset and loader")
train_dataset = CocoDataSet("./data/train2014", "./data/labels/train2014", transform, 1)
train_dataset = CocoDataSet("./data/train2014", "./data/labels/train2014", transform, 0)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
shuffle=True, num_workers=0)
......@@ -101,12 +100,11 @@ if __name__ == '__main__':
yolo_network.train(True)
train_one_epoch(yolo_network, train_loader)
yolo_network.train(False)
yolo_network.train(False)
print("Start Validation")
running_val_loss = 0
for i, val_data in enumerate(val_loader):
if (i == 2): break
inputs, labels = val_data[0].to(device), val_data[1].to(device)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment