From 89c3326a59dd74192a3f2553e21855ce28652646 Mon Sep 17 00:00:00 2001 From: ClF3 Date: Tue, 19 Nov 2024 12:42:06 +0800 Subject: [PATCH] complete but low mAP --- model.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/model.py b/model.py index 1844ecd..405a21e 100644 --- a/model.py +++ b/model.py @@ -230,9 +230,9 @@ class FastRCNN(nn.Module): # forward heads, get predicted cls scores & offsets cls_scores = self.cls_head(feat) - print(cls_scores.shape) + # print(cls_scores.shape) bbox_offsets = self.bbox_head(feat) - print(bbox_offsets.shape) + # print(bbox_offsets.shape) # get predicted boxes & class label & confidence probability proposals = generate_proposal(proposals, bbox_offsets) @@ -244,21 +244,28 @@ class FastRCNN(nn.Module): # filter by threshold cls_prob = torch.softmax(cls_scores[proposal_batch_ids == img_idx], dim=1) - print(cls_prob.shape) - pos_mask = (torch.max(cls_prob, dim=1)[0] > thresh) and (torch.max(cls_prob, dim=1)[1] != self.num_classes) - print(pos_mask.shape) - proposals_img = proposals[proposal_batch_ids == img_idx][pos_mask] - print(proposals_img.shape) - - print(cls_prob.shape) - final_proposals.append(proposals_img) - final_conf_probs.append(cls_prob[pos_mask, 1].unsqueeze(1)) + # print(cls_prob) + # print(torch.max(cls_prob, dim=1)[1].shape) + # print(torch.max(cls_prob, dim=1)[0]) + # print(torch.max(cls_prob, dim=1)[1]) + + pos_mask = torch.max(cls_prob, dim=1)[0] > thresh + not_bg_mask = torch.max(cls_prob, dim=1)[1] != self.num_classes + # print(pos_mask) + # print(not_bg_mask) + total_mask = pos_mask & not_bg_mask + # print(final_mask) + # print(pos_mask.shape) + proposals_obj = proposals[proposal_batch_ids == img_idx][total_mask] + conf_probs=torch.max(cls_prob, dim=1)[0][total_mask] + class_idx = torch.max(cls_prob, dim=1)[1][total_mask] # nms - keep = torchvision.ops.nms(proposals_img, cls_prob[:, 1], nms_thresh) - proposals_img = proposals_img[keep] - cls_prob = cls_prob[keep] + keep = torchvision.ops.nms(proposals_obj, conf_probs, nms_thresh) + final_proposals.append(proposals_obj[keep]) + final_conf_probs.append(conf_probs[keep].unsqueeze(1)) + final_class.append(class_idx[keep].unsqueeze(1)) ##############################################################################