Skip to content
Snippets Groups Projects
Commit 0550a244 authored by Jakob's avatar Jakob
Browse files

fix bug with running loss

parent b660c8ff
No related branches found
No related tags found
No related merge requests found
import math, os , sys
import os , sys
import torch
from random import random
from datetime import datetime
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
......@@ -35,10 +34,9 @@ transform = transforms.Compose([
def train_one_epoch(yolo_network, train_loader):
running_loss = []
# get the inputs; data is a list of [inputs, labels]
for i, data in enumerate(train_loader):
# get the inputs; data is a list of [inputs, labels]
# if (i == 50): break
if (i == 50): break
inputs, labels = data[0].to(device), data[1].to(device)
......@@ -55,14 +53,9 @@ def train_one_epoch(yolo_network, train_loader):
# print statistics
running_loss.append(loss.item())
# if epoch % 3 == 0:
if i % 25 == 0 and not i == 0:
print(f'[{epoch + 1}, {i + 1:5d}] loss: {sum(running_loss) / (i+1)}')
# utils.save_model(yolo_network, optimizer)
if i % 2000 == 1999: # print every 2000 mini-batches
if i % 500 == 499: # print every 500 mini-batches
print(f'[{epoch + 1}, {i + 1:5d}] loss: {sum(running_loss) / 2000:.3f}')
running_loss = 0.0
running_loss = []
if __name__ == '__main__':
......@@ -110,7 +103,7 @@ if __name__ == '__main__':
running_val_loss = 0
for i, val_data in enumerate(val_loader):
# if (i == 25): break
if (i == 25): break
inputs, labels = val_data[0].to(device), val_data[1].to(device)
......@@ -131,32 +124,4 @@ if __name__ == '__main__':
print('Finished Training')
random_image_index = math.ceil(random() * len(val_dataset.x))
image = Image.open(val_dataset.x[random_image_index])
labels = val_dataset.read_annotation_file(val_dataset.y[random_image_index])
preprocessed_image, _ = val_dataset[random_image_index]
preprocessed_image = preprocessed_image.unsqueeze(0).to(device)
nn_output = encoder(preprocessed_image)
nn_output = yolo_network(nn_output)
fig, ax = plt.subplots()
predicted_bounding_boxes = utils.filter_predictions(nn_output[0], threshhold=0.5)
predicted_bounding_boxes = utils.generalize_prediction(predicted_bounding_boxes)
for bounding_box in predicted_bounding_boxes:
confidence = bounding_box[1]
bounding_box = utils.predicted_bb_to_draw(bounding_box[1:], image.size)
utils.draw_bb(ax, bounding_box, linewidth=confidence*2)
for bounding_box in labels:
confidence = bounding_box[1]
bounding_box = utils.predicted_bb_to_draw(bounding_box[1:], image.size)
utils.draw_bb(ax, bounding_box, "g", linewidth=1.5)
print(f"{len(predicted_bounding_boxes)} good bounding boxes")
print(predicted_bounding_boxes)
ax.imshow(image)
plt.show()
\ No newline at end of file
utils.draw_random_image(yolo_network, encoder, val_dataset, device)
\ No newline at end of file
......@@ -5,6 +5,40 @@ import requests
from io import BytesIO
from matplotlib import patches
import numpy as np
import math
from random import random
import matplotlib.pyplot as plt
def draw_random_image(yolo_network, encoder, dataset, device, threshhold=0.5):
random_image_index = math.ceil(random() * len(dataset.x))
image = Image.open(dataset.x[random_image_index])
labels = dataset.read_annotation_file(dataset.y[random_image_index])
preprocessed_image, _ = dataset[random_image_index]
preprocessed_image = preprocessed_image.unsqueeze(0).to(device)
nn_output = encoder(preprocessed_image)
nn_output = yolo_network(nn_output)
fig, ax = plt.subplots()
predicted_bounding_boxes = filter_predictions(nn_output[0], threshhold=threshhold)
predicted_bounding_boxes = generalize_prediction(predicted_bounding_boxes)
for bounding_box in predicted_bounding_boxes:
confidence = bounding_box[1]
bounding_box = predicted_bb_to_draw(bounding_box[1:], image.size)
draw_bb(ax, bounding_box, linewidth=confidence*2)
for bounding_box in labels:
confidence = bounding_box[1]
bounding_box = predicted_bb_to_draw(bounding_box[1:], image.size)
draw_bb(ax, bounding_box, "g", linewidth=1.5)
print(f"{len(predicted_bounding_boxes)} good bounding boxes")
print(predicted_bounding_boxes)
ax.imshow(image)
plt.show()
def get_image(url):
res = requests.get(url)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment