diff --git a/utils.py b/utils.py index b3d3ca2..ccd8151 100644 --- a/utils.py +++ b/utils.py @@ -142,22 +142,40 @@ def compute_iou(anchors, bboxes): Outputs: - iou: IoU matrix of shape (M, N) """ - iou = torch.zeros((anchors.shape[0], bboxes.shape[0])) - iou = iou.to(anchors.device) + iou = None ############################################################################## # TODO: Given anchors and gt bboxes, # # compute the iou between each anchor and gt bbox. # ############################################################################## - for i in range(anchors.shape[0]): - for j in range(bboxes.shape[0]): - x1 = max(anchors[i][0], bboxes[j][0]) - y1 = max(anchors[i][1], bboxes[j][1]) - x2 = min(anchors[i][2], bboxes[j][2]) - y2 = min(anchors[i][3], bboxes[j][3]) - inter = max(0, x2 - x1) * max(0, y2 - y1) - area1 = (anchors[i][2] - anchors[i][0]) * (anchors[i][3] - anchors[i][1]) - area2 = (bboxes[j][2] - bboxes[j][0]) * (bboxes[j][3] - bboxes[j][1]) - iou[i][j] = inter / (area1 + area2 - inter) + + M = anchors.shape[0] + N = bboxes.shape[0] + # Extract the coordinates of the anchors and bboxes + # Expand dimensions to compute pairwise IoU + anchors = anchors.reshape(M, 1, 4) + bboxes = bboxes.reshape(1, N, 4) + #extract (x,y) of left_down and right_up points + x1_a, y1_a, x2_a, y2_a = anchors[:,:, 0], anchors[:,:, 1], anchors[:,:, 2], anchors[:,:, 3] + x1_b, y1_b, x2_b, y2_b = bboxes[:,:, 0], bboxes[:,:, 1], bboxes[:,:, 2], bboxes[:,:, 3] + + # Compute the intersection coordinates + inter_x1 = torch.max(x1_a, x1_b) + inter_y1 = torch.max(y1_a, y1_b) + inter_x2 = torch.min(x2_a, x2_b) + inter_y2 = torch.min(y2_a, y2_b) + + # Compute the intersection area + inter_area = torch.clamp(inter_x2 - inter_x1,min=0) * torch.clamp(inter_y2 - inter_y1,min=0) + + # Compute the area of anchors and bboxes + anchor_area = (x2_a - x1_a) * (y2_a - y1_a) # Shape (M, 1) + bbox_area = (x2_b - x1_b) * (y2_b - y1_b) # Shape (1, N) + + # Compute the union area + union_area = anchor_area + bbox_area - inter_area + + # Compute IoU + iou = inter_area / union_area ############################################################################## # END OF YOUR CODE #