fixed loss compute
This commit is contained in:
parent
89c3326a59
commit
c2b1e5c0fa
25
model.py
25
model.py
|
@ -154,26 +154,15 @@ class FastRCNN(nn.Module):
|
||||||
neg_masks.append(neg_mask)
|
neg_masks.append(neg_mask)
|
||||||
GT_labels.append(GT_label)
|
GT_labels.append(GT_label)
|
||||||
GT_bboxes.append(GT_bbox)
|
GT_bboxes.append(GT_bbox)
|
||||||
|
GT_labels = torch.cat(GT_labels)
|
||||||
|
GT_bboxes = torch.cat(GT_bboxes)
|
||||||
|
pos_masks = torch.cat(pos_masks)
|
||||||
|
neg_masks = torch.cat(neg_masks)
|
||||||
# compute loss
|
# compute loss
|
||||||
cls_loss = 0
|
cls_loss = ClsScoreRegression(cls_scores[pos_masks|neg_masks], GT_labels[pos_masks|neg_masks], B)
|
||||||
img_idx = 0
|
|
||||||
for GT_label in GT_labels:
|
|
||||||
# print(cls_scores.shape, GT_label.shape)
|
|
||||||
cls_loss += ClsScoreRegression(cls_scores[proposal_batch_ids==img_idx,:], GT_label, B)
|
|
||||||
img_idx += 1
|
|
||||||
bbox_loss = 0
|
|
||||||
img_idx=0
|
|
||||||
|
|
||||||
for GT_bbox in GT_bboxes:
|
bbox_loss = BboxRegression(bbox_offsets[pos_masks], compute_offsets(proposals[pos_masks], GT_bboxes), B)
|
||||||
bbox_offsets_cur=bbox_offsets[proposal_batch_ids==img_idx,:]
|
total_loss=w_cls*cls_loss+w_bbox*bbox_loss
|
||||||
pos_box_offsets = bbox_offsets_cur[pos_masks[img_idx],:]
|
|
||||||
proposals_cur = proposals[proposal_batch_ids==img_idx,:]
|
|
||||||
pos_proposals = proposals_cur[pos_masks[img_idx],:]
|
|
||||||
# print(pos_box_offsets.shape, GT_bbox.shape)
|
|
||||||
bbox_loss += BboxRegression(pos_box_offsets, compute_offsets(pos_proposals, GT_bbox), B)
|
|
||||||
img_idx += 1
|
|
||||||
total_loss=cls_loss+bbox_loss
|
|
||||||
|
|
||||||
##############################################################################
|
##############################################################################
|
||||||
# END OF YOUR CODE #
|
# END OF YOUR CODE #
|
||||||
|
|
Loading…
Reference in New Issue