This commit is contained in:
ClF3 2024-11-18 23:53:38 +08:00
parent 2bd4f09891
commit 86b50fb075
1 changed files with 1 additions and 1 deletions

View File

@ -245,7 +245,7 @@ class FastRCNN(nn.Module):
# filter by threshold # filter by threshold
cls_prob = torch.softmax(cls_scores[proposal_batch_ids == img_idx], dim=1) cls_prob = torch.softmax(cls_scores[proposal_batch_ids == img_idx], dim=1)
print(cls_prob.shape) print(cls_prob.shape)
pos_mask = cls_prob[:, 1] > thresh pos_mask = (torch.max(cls_prob, dim=1)[0] > thresh) and (torch.max(cls_prob, dim=1)[1] != self.num_classes)
print(pos_mask.shape) print(pos_mask.shape)
proposals_img = proposals[proposal_batch_ids == img_idx][pos_mask] proposals_img = proposals[proposal_batch_ids == img_idx][pos_mask]
print(proposals_img.shape) print(proposals_img.shape)