Skip to content
Snippets Groups Projects
Commit 33f638c4 authored by Jakob's avatar Jakob
Browse files

changed mobilenet to decoder

parent fcf4a86d
No related branches found
No related tags found
No related merge requests found
......@@ -18,10 +18,10 @@ from Yolo_v1_fcs import Yolo_v1_fcs
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
decoder = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
#Enable following to only get CNN part:
model = torch.nn.Sequential(*(list(model.children())[:-1]))
model.eval()
decoder = torch.nn.Sequential(*(list(decoder.children())[:-1]))
decoder.eval()
preprocess = transforms.Compose([
transforms.Resize(256),
......@@ -40,9 +40,9 @@ def runCNN(image):
input_batch = preprocessImage(image)
if torch.cuda.is_available():
input_batch = input_batch.to('cuda')
model.to('cuda')
decoder.to('cuda')
with torch.no_grad():
return model(input_batch)
return decoder(input_batch)
if __name__=="__main__":
......@@ -66,9 +66,9 @@ if __name__=="__main__":
print(output.shape)
size = reduce(lambda x,y: x*y, output.shape)
model = Yolo_v1_fcs(size)
decoder = Yolo_v1_fcs(size)
output = model(output)
output = decoder(output)
fig, ax = plt.subplots()
ax.imshow(input_image)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment