shared linear

This commit is contained in:
ClF3 2024-11-19 13:46:49 +08:00
parent b4eb23917e
commit 36651b22de
1 changed files with 12 additions and 13 deletions

View File

@ -71,18 +71,16 @@ class FastRCNN(nn.Module):
# hidden_dim -> hidden_dim. # # hidden_dim -> hidden_dim. #
############################################################################## ##############################################################################
# Replace "pass" statement with your code # Replace "pass" statement with your code
self.cls_head = nn.Sequential( self.shared_fc = nn.Sequential(
nn.Linear(in_dim, hidden_dim), nn.Linear(in_dim, hidden_dim),
nn.Dropout(drop_ratio), nn.Dropout(drop_ratio),
nn.ReLU(), nn.ReLU(),
nn.Linear(hidden_dim, num_classes+1) nn.Linear(hidden_dim, hidden_dim)
)
self.bbox_head = nn.Sequential(
nn.Linear(in_dim, hidden_dim),
nn.Dropout(drop_ratio),
nn.ReLU(),
nn.Linear(hidden_dim, 4)
) )
self.cls_head = nn.Linear(hidden_dim, self.num_classes+1) # The cls head is a Linear layer that predicts num_classes + 1 (background).
self.bbox_head = nn.Linear(hidden_dim, 4)# The det head is a Linear layer that predicts offsets(dim=4).
############################################################################## ##############################################################################
# END OF YOUR CODE # # END OF YOUR CODE #
############################################################################## ##############################################################################
@ -139,8 +137,9 @@ class FastRCNN(nn.Module):
# print(feat.shape) # print(feat.shape)
# forward heads, get predicted cls scores & offsets # forward heads, get predicted cls scores & offsets
cls_scores=self.cls_head(feat) shared_feat = self.shared_fc(feat)
bbox_offsets=self.bbox_head(feat) cls_scores=self.cls_head(shared_feat)
bbox_offsets=self.bbox_head(shared_feat)
# print(cls_scores.shape, bbox_offsets.shape) # print(cls_scores.shape, bbox_offsets.shape)
# assign targets with proposals # assign targets with proposals
@ -216,11 +215,11 @@ class FastRCNN(nn.Module):
# perform RoI Pool & mean pool # perform RoI Pool & mean pool
feat=torchvision.ops.roi_pool(feat, torch.cat((proposal_batch_ids.unsqueeze(1), proposals),dim=1), output_size=(self.roi_output_w, self.roi_output_h)) feat=torchvision.ops.roi_pool(feat, torch.cat((proposal_batch_ids.unsqueeze(1), proposals),dim=1), output_size=(self.roi_output_w, self.roi_output_h))
feat = feat.mean(dim=[2, 3]) feat = feat.mean(dim=[2, 3])
shared_feat = self.shared_fc(feat)
# forward heads, get predicted cls scores & offsets # forward heads, get predicted cls scores & offsets
cls_scores = self.cls_head(feat) cls_scores = self.cls_head(shared_feat)
# print(cls_scores.shape) # print(cls_scores.shape)
bbox_offsets = self.bbox_head(feat) bbox_offsets = self.bbox_head(shared_feat)
# print(bbox_offsets.shape) # print(bbox_offsets.shape)
# get predicted boxes & class label & confidence probability # get predicted boxes & class label & confidence probability
proposals = generate_proposal(proposals, bbox_offsets) proposals = generate_proposal(proposals, bbox_offsets)