added report
This commit is contained in:
parent
a484bfa9ad
commit
2e92990035
4
model.py
4
model.py
|
@ -16,7 +16,7 @@ class FeatureExtractor(nn.Module):
|
||||||
def __init__(self, reshape_size=224, pooling=False, verbose=False):
|
def __init__(self, reshape_size=224, pooling=False, verbose=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.mobilenet = models.mobilenet_v2(pretrained=True)
|
self.mobilenet = models.mobilenet_v2(weights=models.MobileNet_V2_Weights)
|
||||||
self.mobilenet = nn.Sequential(*list(self.mobilenet.children())[:-1]) # Remove the last classifier
|
self.mobilenet = nn.Sequential(*list(self.mobilenet.children())[:-1]) # Remove the last classifier
|
||||||
|
|
||||||
# average pooling
|
# average pooling
|
||||||
|
@ -131,7 +131,7 @@ class FastRCNN(nn.Module):
|
||||||
# print(feat.shape)
|
# print(feat.shape)
|
||||||
|
|
||||||
# perform RoI Pool & mean pool
|
# perform RoI Pool & mean pool
|
||||||
feat=torchvision.ops.roi_pool(feat, torch.cat((proposal_batch_ids.unsqueeze(1), proposals),dim=1), output_size=(self.roi_output_w, self.roi_output_h))
|
feat=torchvision.ops.roi_pool(feat, torch.cat((proposal_batch_ids.unsqueeze(1), proposals),dim=1), output_size=(self.roi_output_h, self.roi_output_w))
|
||||||
# print(feat.shape)
|
# print(feat.shape)
|
||||||
feat=feat.mean(dim=[2,3])
|
feat=feat.mean(dim=[2,3])
|
||||||
# print(feat.shape)
|
# print(feat.shape)
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue