fixed training speed problem

This commit is contained in:
ClF3 2024-11-19 13:33:01 +08:00
parent c2b1e5c0fa
commit b4eb23917e
1 changed files with 30 additions and 12 deletions

View File

@ -142,22 +142,40 @@ def compute_iou(anchors, bboxes):
Outputs:
- iou: IoU matrix of shape (M, N)
"""
iou = torch.zeros((anchors.shape[0], bboxes.shape[0]))
iou = iou.to(anchors.device)
iou = None
##############################################################################
# TODO: Given anchors and gt bboxes, #
# compute the iou between each anchor and gt bbox. #
##############################################################################
for i in range(anchors.shape[0]):
for j in range(bboxes.shape[0]):
x1 = max(anchors[i][0], bboxes[j][0])
y1 = max(anchors[i][1], bboxes[j][1])
x2 = min(anchors[i][2], bboxes[j][2])
y2 = min(anchors[i][3], bboxes[j][3])
inter = max(0, x2 - x1) * max(0, y2 - y1)
area1 = (anchors[i][2] - anchors[i][0]) * (anchors[i][3] - anchors[i][1])
area2 = (bboxes[j][2] - bboxes[j][0]) * (bboxes[j][3] - bboxes[j][1])
iou[i][j] = inter / (area1 + area2 - inter)
M = anchors.shape[0]
N = bboxes.shape[0]
# Extract the coordinates of the anchors and bboxes
# Expand dimensions to compute pairwise IoU
anchors = anchors.reshape(M, 1, 4)
bboxes = bboxes.reshape(1, N, 4)
#extract (x,y) of left_down and right_up points
x1_a, y1_a, x2_a, y2_a = anchors[:,:, 0], anchors[:,:, 1], anchors[:,:, 2], anchors[:,:, 3]
x1_b, y1_b, x2_b, y2_b = bboxes[:,:, 0], bboxes[:,:, 1], bboxes[:,:, 2], bboxes[:,:, 3]
# Compute the intersection coordinates
inter_x1 = torch.max(x1_a, x1_b)
inter_y1 = torch.max(y1_a, y1_b)
inter_x2 = torch.min(x2_a, x2_b)
inter_y2 = torch.min(y2_a, y2_b)
# Compute the intersection area
inter_area = torch.clamp(inter_x2 - inter_x1,min=0) * torch.clamp(inter_y2 - inter_y1,min=0)
# Compute the area of anchors and bboxes
anchor_area = (x2_a - x1_a) * (y2_a - y1_a) # Shape (M, 1)
bbox_area = (x2_b - x1_b) * (y2_b - y1_b) # Shape (1, N)
# Compute the union area
union_area = anchor_area + bbox_area - inter_area
# Compute IoU
iou = inter_area / union_area
##############################################################################
# END OF YOUR CODE #