Skip to content
Snippets Groups Projects
Commit e173e988 authored by Jakob's avatar Jakob
Browse files

image show WIP

parent fe0427fb
No related branches found
No related tags found
No related merge requests found
import math, os , sys
import torch
import torchvision
from random import random
from datetime import datetime
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
import torch.optim as optim
from torch.utils.data import DataLoader
from datetime import datetime
import sys
import os
from MobileNetV1 import runCNN
from MobileNetV1 import encoder
from cocoDataset import CocoDataSet
from Yolo_loss import YoloLoss
......@@ -66,18 +67,18 @@ def train_one_epoch(yolo_network, train_loader):
if __name__ == '__main__':
EPOCHS = 5
EPOCHS = 0
BATCH_SIZE = 32
best_avg_val_loss = float("inf")
print("creating train dataset and loader")
train_set = CocoDataSet("./data/train2014", "./data/labels/train2014", transform, 1)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
train_dataset = CocoDataSet("./data/train2014", "./data/labels/train2014", transform, 1)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
shuffle=True, num_workers=0)
print("creating val dataset and loader")
val_set = CocoDataSet("./data/val2014", "./data/labels/val2014", transform, 1)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE,
val_dataset = CocoDataSet("./data/val2014", "./data/labels/val2014", transform, 1)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
shuffle=True, num_workers=0)
yolo_network = Yolo_v1_fcs(1280*7*7)
......@@ -85,9 +86,9 @@ if __name__ == '__main__':
# CUDA setup guide - https://www.youtube.com/watch?v=GMSjDTU8Zlc
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# convert model to cuda if available
if torch.cuda.is_available():
encoder.to('cuda')
yolo_network.fcs.to('cuda')
# if torch.cuda.is_available():
encoder.to(device)
yolo_network.fcs.to(device)
#Loss stuff:
criterion = YoloLoss()
......@@ -126,8 +127,22 @@ if __name__ == '__main__':
if (avg_val_loss < best_avg_val_loss):
print(f"New best avg. validation loss: {avg_val_loss}")
best_avg_val_loss = avg_val_loss
utils.save_model(yolo_network, optimizer)
utils.save_model(yolo_network, optimizer)
print('Finished Training')
random_image_index = math.ceil(random() * len(val_dataset.x))
image = Image.open(val_dataset.x[random_image_index])
label = val_dataset.read_annotation_file(val_dataset.y[random_image_index])
preprocessed_image, _ = val_dataset[random_image_index]
preprocessed_image = preprocessed_image.unsqueeze(0).to(device)
x = encoder(preprocessed_image)
x = yolo_network(x)
fig, ax = plt.subplots()
print('Finished Training')
\ No newline at end of file
# utils.predicted_bb_to_draw()
ax.imshow(image)
plt.show()
\ No newline at end of file
import torch
import numpy as np
from datetime import datetime
from PIL import Image
import requests
from io import BytesIO
from matplotlib import patches
def predicted_bb_to_draw(b):
def get_image(url):
res = requests.get(url)
return Image.open(BytesIO(res.content))
def draw_bb(ax, bb, color="r"):
rect = patches.Rectangle(tuple(bb[:2]), *bb[2:], linewidth=1, edgecolor=color, facecolor='none')
ax.add_patch(rect)
def predicted_bb_to_draw(b, img_size):
"""
b as tuple of the from (x,y,w,h)
where x,y is centers
where x,y,w,h is percentage relative to image size in the domain [0,1]
x, y is centers
img size is format (w,h)
"""
return (b[0] - b[2]/2,b[1] - b[3]/2, *b[2:] )
return ((b[0] - b[2]/2)*img_size[0], (b[1] - b[3]/2)*img_size[1], b[3]*img_size[0], b[4]*img_size[1])
def save_model(yolo_network, optimizer, file_name=None, folder=fr"saved_models\{datetime.strftime(datetime.now(), '%Y_%m_%d')}"):
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment