From 86b50fb0756928549787ecdab9face9d9e3d16be Mon Sep 17 00:00:00 2001 From: ClF3 Date: Mon, 18 Nov 2024 23:53:38 +0800 Subject: [PATCH] modify --- model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model.py b/model.py index 10fd812..1844ecd 100644 --- a/model.py +++ b/model.py @@ -245,7 +245,7 @@ class FastRCNN(nn.Module): # filter by threshold cls_prob = torch.softmax(cls_scores[proposal_batch_ids == img_idx], dim=1) print(cls_prob.shape) - pos_mask = cls_prob[:, 1] > thresh + pos_mask = (torch.max(cls_prob, dim=1)[0] > thresh) and (torch.max(cls_prob, dim=1)[1] != self.num_classes) print(pos_mask.shape) proposals_img = proposals[proposal_batch_ids == img_idx][pos_mask] print(proposals_img.shape)