fixed loss compute

This commit is contained in:
ClF3 2024-11-19 13:18:08 +08:00
parent 89c3326a59
commit c2b1e5c0fa
1 changed files with 8 additions and 19 deletions

View File

@ -154,26 +154,15 @@ class FastRCNN(nn.Module):
neg_masks.append(neg_mask)
GT_labels.append(GT_label)
GT_bboxes.append(GT_bbox)
GT_labels = torch.cat(GT_labels)
GT_bboxes = torch.cat(GT_bboxes)
pos_masks = torch.cat(pos_masks)
neg_masks = torch.cat(neg_masks)
# compute loss
cls_loss = 0
img_idx = 0
for GT_label in GT_labels:
# print(cls_scores.shape, GT_label.shape)
cls_loss += ClsScoreRegression(cls_scores[proposal_batch_ids==img_idx,:], GT_label, B)
img_idx += 1
bbox_loss = 0
img_idx=0
cls_loss = ClsScoreRegression(cls_scores[pos_masks|neg_masks], GT_labels[pos_masks|neg_masks], B)
for GT_bbox in GT_bboxes:
bbox_offsets_cur=bbox_offsets[proposal_batch_ids==img_idx,:]
pos_box_offsets = bbox_offsets_cur[pos_masks[img_idx],:]
proposals_cur = proposals[proposal_batch_ids==img_idx,:]
pos_proposals = proposals_cur[pos_masks[img_idx],:]
# print(pos_box_offsets.shape, GT_bbox.shape)
bbox_loss += BboxRegression(pos_box_offsets, compute_offsets(pos_proposals, GT_bbox), B)
img_idx += 1
total_loss=cls_loss+bbox_loss
bbox_loss = BboxRegression(bbox_offsets[pos_masks], compute_offsets(proposals[pos_masks], GT_bboxes), B)
total_loss=w_cls*cls_loss+w_bbox*bbox_loss
##############################################################################
# END OF YOUR CODE #