complete but low mAP

This commit is contained in:
ClF3 2024-11-19 12:42:06 +08:00
parent 86b50fb075
commit 89c3326a59
1 changed files with 21 additions and 14 deletions

View File

@ -230,9 +230,9 @@ class FastRCNN(nn.Module):
# forward heads, get predicted cls scores & offsets # forward heads, get predicted cls scores & offsets
cls_scores = self.cls_head(feat) cls_scores = self.cls_head(feat)
print(cls_scores.shape) # print(cls_scores.shape)
bbox_offsets = self.bbox_head(feat) bbox_offsets = self.bbox_head(feat)
print(bbox_offsets.shape) # print(bbox_offsets.shape)
# get predicted boxes & class label & confidence probability # get predicted boxes & class label & confidence probability
proposals = generate_proposal(proposals, bbox_offsets) proposals = generate_proposal(proposals, bbox_offsets)
@ -244,21 +244,28 @@ class FastRCNN(nn.Module):
# filter by threshold # filter by threshold
cls_prob = torch.softmax(cls_scores[proposal_batch_ids == img_idx], dim=1) cls_prob = torch.softmax(cls_scores[proposal_batch_ids == img_idx], dim=1)
print(cls_prob.shape) # print(cls_prob)
pos_mask = (torch.max(cls_prob, dim=1)[0] > thresh) and (torch.max(cls_prob, dim=1)[1] != self.num_classes) # print(torch.max(cls_prob, dim=1)[1].shape)
print(pos_mask.shape) # print(torch.max(cls_prob, dim=1)[0])
proposals_img = proposals[proposal_batch_ids == img_idx][pos_mask] # print(torch.max(cls_prob, dim=1)[1])
print(proposals_img.shape)
pos_mask = torch.max(cls_prob, dim=1)[0] > thresh
print(cls_prob.shape) not_bg_mask = torch.max(cls_prob, dim=1)[1] != self.num_classes
final_proposals.append(proposals_img) # print(pos_mask)
final_conf_probs.append(cls_prob[pos_mask, 1].unsqueeze(1)) # print(not_bg_mask)
total_mask = pos_mask & not_bg_mask
# print(final_mask)
# print(pos_mask.shape)
proposals_obj = proposals[proposal_batch_ids == img_idx][total_mask]
conf_probs=torch.max(cls_prob, dim=1)[0][total_mask]
class_idx = torch.max(cls_prob, dim=1)[1][total_mask]
# nms # nms
keep = torchvision.ops.nms(proposals_img, cls_prob[:, 1], nms_thresh) keep = torchvision.ops.nms(proposals_obj, conf_probs, nms_thresh)
proposals_img = proposals_img[keep] final_proposals.append(proposals_obj[keep])
cls_prob = cls_prob[keep] final_conf_probs.append(conf_probs[keep].unsqueeze(1))
final_class.append(class_idx[keep].unsqueeze(1))
############################################################################## ##############################################################################