complete but low mAP
This commit is contained in:
parent
86b50fb075
commit
89c3326a59
35
model.py
35
model.py
|
@ -230,9 +230,9 @@ class FastRCNN(nn.Module):
|
||||||
|
|
||||||
# forward heads, get predicted cls scores & offsets
|
# forward heads, get predicted cls scores & offsets
|
||||||
cls_scores = self.cls_head(feat)
|
cls_scores = self.cls_head(feat)
|
||||||
print(cls_scores.shape)
|
# print(cls_scores.shape)
|
||||||
bbox_offsets = self.bbox_head(feat)
|
bbox_offsets = self.bbox_head(feat)
|
||||||
print(bbox_offsets.shape)
|
# print(bbox_offsets.shape)
|
||||||
# get predicted boxes & class label & confidence probability
|
# get predicted boxes & class label & confidence probability
|
||||||
proposals = generate_proposal(proposals, bbox_offsets)
|
proposals = generate_proposal(proposals, bbox_offsets)
|
||||||
|
|
||||||
|
@ -244,21 +244,28 @@ class FastRCNN(nn.Module):
|
||||||
|
|
||||||
# filter by threshold
|
# filter by threshold
|
||||||
cls_prob = torch.softmax(cls_scores[proposal_batch_ids == img_idx], dim=1)
|
cls_prob = torch.softmax(cls_scores[proposal_batch_ids == img_idx], dim=1)
|
||||||
print(cls_prob.shape)
|
# print(cls_prob)
|
||||||
pos_mask = (torch.max(cls_prob, dim=1)[0] > thresh) and (torch.max(cls_prob, dim=1)[1] != self.num_classes)
|
# print(torch.max(cls_prob, dim=1)[1].shape)
|
||||||
print(pos_mask.shape)
|
# print(torch.max(cls_prob, dim=1)[0])
|
||||||
proposals_img = proposals[proposal_batch_ids == img_idx][pos_mask]
|
# print(torch.max(cls_prob, dim=1)[1])
|
||||||
print(proposals_img.shape)
|
|
||||||
|
pos_mask = torch.max(cls_prob, dim=1)[0] > thresh
|
||||||
print(cls_prob.shape)
|
not_bg_mask = torch.max(cls_prob, dim=1)[1] != self.num_classes
|
||||||
final_proposals.append(proposals_img)
|
# print(pos_mask)
|
||||||
final_conf_probs.append(cls_prob[pos_mask, 1].unsqueeze(1))
|
# print(not_bg_mask)
|
||||||
|
total_mask = pos_mask & not_bg_mask
|
||||||
|
# print(final_mask)
|
||||||
|
# print(pos_mask.shape)
|
||||||
|
proposals_obj = proposals[proposal_batch_ids == img_idx][total_mask]
|
||||||
|
conf_probs=torch.max(cls_prob, dim=1)[0][total_mask]
|
||||||
|
class_idx = torch.max(cls_prob, dim=1)[1][total_mask]
|
||||||
|
|
||||||
|
|
||||||
# nms
|
# nms
|
||||||
keep = torchvision.ops.nms(proposals_img, cls_prob[:, 1], nms_thresh)
|
keep = torchvision.ops.nms(proposals_obj, conf_probs, nms_thresh)
|
||||||
proposals_img = proposals_img[keep]
|
final_proposals.append(proposals_obj[keep])
|
||||||
cls_prob = cls_prob[keep]
|
final_conf_probs.append(conf_probs[keep].unsqueeze(1))
|
||||||
|
final_class.append(class_idx[keep].unsqueeze(1))
|
||||||
|
|
||||||
|
|
||||||
##############################################################################
|
##############################################################################
|
||||||
|
|
Loading…
Reference in New Issue