Skip to content
Snippets Groups Projects
Commit d974a9be authored by Jakob's avatar Jakob
Browse files

added sigmoid to last layer

parent a6656f63
No related branches found
No related tags found
No related merge requests found
import torch
from functools import reduce
import torch.nn as nn
import torch.nn.functional as F
......@@ -11,6 +12,8 @@ import numpy as np
from coco_utils import get_image
from Yolo_v1_fcs import Yolo_v1_fcs
model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
#Enable following to only get CNN part:
......@@ -44,5 +47,11 @@ if __name__=="__main__":
input_image = get_image("http://images.cocodataset.org/val2017/000000148783.jpg")
output = runCNN(input_image)
print(output.shape)
#Shape of output:
#torch.Size([1, 1280, 7, 7])
\ No newline at end of file
size = reduce(lambda x,y: x*y, output.shape)
model = Yolo_v1_fcs(size)
output = model(output)
print(output.shape);
#torch.Size([1, 1280, 7, 7])
import torch
from torch import nn
import numpy as np
class Yolo_v1_fcs(nn.Module):
def __init__(self, input_size):
super(Yolo_v1_fcs, self).__init__()
layers = []
layers.append(nn.Flatten())
layers.append(nn.Linear(input_size, 496))
layers.append(nn.Dropout(0, 0))
layers.append(nn.LeakyReLU(0.1))
layers.append(nn.Linear(496, 7*7*10))
self.fcs = nn.Sequential(*layers)
def forward(self, x):
# Run through decoder
output = self.fcs(x)
# Run sigmoid on p, x, y for ever bb
output = torch.reshape(output, (-1, 5))
output[:, :3] = torch.sigmoid(output[:, :3])
return output
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment