fixed training speed problem
This commit is contained in:
parent
c2b1e5c0fa
commit
b4eb23917e
42
utils.py
42
utils.py
|
@ -142,22 +142,40 @@ def compute_iou(anchors, bboxes):
|
||||||
Outputs:
|
Outputs:
|
||||||
- iou: IoU matrix of shape (M, N)
|
- iou: IoU matrix of shape (M, N)
|
||||||
"""
|
"""
|
||||||
iou = torch.zeros((anchors.shape[0], bboxes.shape[0]))
|
iou = None
|
||||||
iou = iou.to(anchors.device)
|
|
||||||
##############################################################################
|
##############################################################################
|
||||||
# TODO: Given anchors and gt bboxes, #
|
# TODO: Given anchors and gt bboxes, #
|
||||||
# compute the iou between each anchor and gt bbox. #
|
# compute the iou between each anchor and gt bbox. #
|
||||||
##############################################################################
|
##############################################################################
|
||||||
for i in range(anchors.shape[0]):
|
|
||||||
for j in range(bboxes.shape[0]):
|
M = anchors.shape[0]
|
||||||
x1 = max(anchors[i][0], bboxes[j][0])
|
N = bboxes.shape[0]
|
||||||
y1 = max(anchors[i][1], bboxes[j][1])
|
# Extract the coordinates of the anchors and bboxes
|
||||||
x2 = min(anchors[i][2], bboxes[j][2])
|
# Expand dimensions to compute pairwise IoU
|
||||||
y2 = min(anchors[i][3], bboxes[j][3])
|
anchors = anchors.reshape(M, 1, 4)
|
||||||
inter = max(0, x2 - x1) * max(0, y2 - y1)
|
bboxes = bboxes.reshape(1, N, 4)
|
||||||
area1 = (anchors[i][2] - anchors[i][0]) * (anchors[i][3] - anchors[i][1])
|
#extract (x,y) of left_down and right_up points
|
||||||
area2 = (bboxes[j][2] - bboxes[j][0]) * (bboxes[j][3] - bboxes[j][1])
|
x1_a, y1_a, x2_a, y2_a = anchors[:,:, 0], anchors[:,:, 1], anchors[:,:, 2], anchors[:,:, 3]
|
||||||
iou[i][j] = inter / (area1 + area2 - inter)
|
x1_b, y1_b, x2_b, y2_b = bboxes[:,:, 0], bboxes[:,:, 1], bboxes[:,:, 2], bboxes[:,:, 3]
|
||||||
|
|
||||||
|
# Compute the intersection coordinates
|
||||||
|
inter_x1 = torch.max(x1_a, x1_b)
|
||||||
|
inter_y1 = torch.max(y1_a, y1_b)
|
||||||
|
inter_x2 = torch.min(x2_a, x2_b)
|
||||||
|
inter_y2 = torch.min(y2_a, y2_b)
|
||||||
|
|
||||||
|
# Compute the intersection area
|
||||||
|
inter_area = torch.clamp(inter_x2 - inter_x1,min=0) * torch.clamp(inter_y2 - inter_y1,min=0)
|
||||||
|
|
||||||
|
# Compute the area of anchors and bboxes
|
||||||
|
anchor_area = (x2_a - x1_a) * (y2_a - y1_a) # Shape (M, 1)
|
||||||
|
bbox_area = (x2_b - x1_b) * (y2_b - y1_b) # Shape (1, N)
|
||||||
|
|
||||||
|
# Compute the union area
|
||||||
|
union_area = anchor_area + bbox_area - inter_area
|
||||||
|
|
||||||
|
# Compute IoU
|
||||||
|
iou = inter_area / union_area
|
||||||
|
|
||||||
##############################################################################
|
##############################################################################
|
||||||
# END OF YOUR CODE #
|
# END OF YOUR CODE #
|
||||||
|
|
Loading…
Reference in New Issue