modify
This commit is contained in:
parent
2bd4f09891
commit
86b50fb075
2
model.py
2
model.py
|
@ -245,7 +245,7 @@ class FastRCNN(nn.Module):
|
||||||
# filter by threshold
|
# filter by threshold
|
||||||
cls_prob = torch.softmax(cls_scores[proposal_batch_ids == img_idx], dim=1)
|
cls_prob = torch.softmax(cls_scores[proposal_batch_ids == img_idx], dim=1)
|
||||||
print(cls_prob.shape)
|
print(cls_prob.shape)
|
||||||
pos_mask = cls_prob[:, 1] > thresh
|
pos_mask = (torch.max(cls_prob, dim=1)[0] > thresh) and (torch.max(cls_prob, dim=1)[1] != self.num_classes)
|
||||||
print(pos_mask.shape)
|
print(pos_mask.shape)
|
||||||
proposals_img = proposals[proposal_batch_ids == img_idx][pos_mask]
|
proposals_img = proposals[proposal_batch_ids == img_idx][pos_mask]
|
||||||
print(proposals_img.shape)
|
print(proposals_img.shape)
|
||||||
|
|
Loading…
Reference in New Issue