fixed loss compute
This commit is contained in:
parent
89c3326a59
commit
c2b1e5c0fa
27
model.py
27
model.py
|
@ -154,26 +154,15 @@ class FastRCNN(nn.Module):
|
|||
neg_masks.append(neg_mask)
|
||||
GT_labels.append(GT_label)
|
||||
GT_bboxes.append(GT_bbox)
|
||||
|
||||
GT_labels = torch.cat(GT_labels)
|
||||
GT_bboxes = torch.cat(GT_bboxes)
|
||||
pos_masks = torch.cat(pos_masks)
|
||||
neg_masks = torch.cat(neg_masks)
|
||||
# compute loss
|
||||
cls_loss = 0
|
||||
img_idx = 0
|
||||
for GT_label in GT_labels:
|
||||
# print(cls_scores.shape, GT_label.shape)
|
||||
cls_loss += ClsScoreRegression(cls_scores[proposal_batch_ids==img_idx,:], GT_label, B)
|
||||
img_idx += 1
|
||||
bbox_loss = 0
|
||||
img_idx=0
|
||||
|
||||
for GT_bbox in GT_bboxes:
|
||||
bbox_offsets_cur=bbox_offsets[proposal_batch_ids==img_idx,:]
|
||||
pos_box_offsets = bbox_offsets_cur[pos_masks[img_idx],:]
|
||||
proposals_cur = proposals[proposal_batch_ids==img_idx,:]
|
||||
pos_proposals = proposals_cur[pos_masks[img_idx],:]
|
||||
# print(pos_box_offsets.shape, GT_bbox.shape)
|
||||
bbox_loss += BboxRegression(pos_box_offsets, compute_offsets(pos_proposals, GT_bbox), B)
|
||||
img_idx += 1
|
||||
total_loss=cls_loss+bbox_loss
|
||||
cls_loss = ClsScoreRegression(cls_scores[pos_masks|neg_masks], GT_labels[pos_masks|neg_masks], B)
|
||||
|
||||
bbox_loss = BboxRegression(bbox_offsets[pos_masks], compute_offsets(proposals[pos_masks], GT_bboxes), B)
|
||||
total_loss=w_cls*cls_loss+w_bbox*bbox_loss
|
||||
|
||||
##############################################################################
|
||||
# END OF YOUR CODE #
|
||||
|
|
Loading…
Reference in New Issue