import os
import json
from functools import partial

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


# Pascal VOC 2007 class names mapped to contiguous integer indices, plus the
# inverse mapping for decoding predicted indices back to names.
class_to_idx = {
    'aeroplane': 0, 'bicycle': 1, 'bird': 2, 'boat': 3, 'bottle': 4,
    'bus': 5, 'car': 6, 'cat': 7, 'chair': 8, 'cow': 9, 'diningtable': 10,
    'dog': 11, 'horse': 12, 'motorbike': 13, 'person': 14, 'pottedplant': 15,
    'sheep': 16, 'sofa': 17, 'train': 18, 'tvmonitor': 19,
}
idx_to_class = {i: c for c, i in class_to_idx.items()}


def get_pascal_voc2007_data(image_root, split='train'):
    """
    Load the Pascal VOC 2007 detection set for the given split using
    torchvision.datasets.
    https://pytorch.org/docs/stable/torchvision/datasets.html#torchvision.datasets.VOCDetection
    """
    dataset = datasets.VOCDetection(image_root, year='2007', image_set=split,
                                    download=False)
    return dataset
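
# Usage sketch (illustrative, not from the original file): assumes the VOC
# 2007 archives have already been extracted under './data'. Each sample is a
# (PIL image, annotation dict) pair.
#
#   voc_val = get_pascal_voc2007_data('./data', split='val')
#   img, ann = voc_val[0]
#   print(ann['annotation']['filename'])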


def pascal_voc2007_loader(dataset, batch_size, num_workers=0, shuffle=False,
                          proposal_path=None):
    """
    Wrap a Pascal VOC 2007 dataset in a DataLoader that batches variable-size
    annotations with a custom collate function.
    https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
    """
    collate_fn = partial(voc_collate_fn, proposal_path=proposal_path)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=shuffle,
                        pin_memory=True,
                        num_workers=num_workers,
                        collate_fn=collate_fn)
    return loader
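
# Usage sketch (illustrative): batch size and worker count are arbitrary.
# Images are stacked into one tensor; boxes from the whole batch are
# concatenated row-wise, with index tensors mapping rows back to images.
#
#   loader = pascal_voc2007_loader(voc_val, batch_size=8, num_workers=2)
#   for img_batch, box_batch, box_batch_ids, *rest in loader:
#       pass  # img_batch: (8, 3, 224, 224); box_batch: (total_boxes, 5)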


def voc_collate_fn(batch_lst, reshape_size=224, proposal_path=None):
    """
    Collate a list of (PIL image, annotation dict) pairs into batched tensors.
    Boxes (and optional precomputed proposals) from all images are
    concatenated, with parallel index tensors recording which image in the
    batch each row belongs to.
    """
    preprocess = transforms.Compose([
        transforms.Resize((reshape_size, reshape_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    batch_size = len(batch_lst)

    img_batch = torch.zeros(batch_size, 3, reshape_size, reshape_size)

    box_list = []
    box_batch_idx = []
    w_list = []
    h_list = []
    img_id_list = []
    proposal_list = []
    proposal_batch_idx = []

    for i in range(batch_size):
        img, ann = batch_lst[i]
        w_list.append(img.size[0])  # original image width
        h_list.append(img.size[1])  # original image height
        img_id_list.append(ann['annotation']['filename'])
        img_batch[i] = preprocess(img)

        all_bbox = ann['annotation']['object']
        if isinstance(all_bbox, dict):  # a single object is not wrapped in a list
            all_bbox = [all_bbox]
        for one_bbox in all_bbox:
            bbox = one_bbox['bndbox']
            obj_cls = one_bbox['name']
            # Each row: [xmin, ymin, xmax, ymax, class index].
            box_list.append(torch.Tensor([float(bbox['xmin']), float(bbox['ymin']),
                                          float(bbox['xmax']), float(bbox['ymax']),
                                          class_to_idx[obj_cls]]))
            box_batch_idx.append(i)

        if proposal_path is not None:
            # Load precomputed region proposals stored as one JSON file per image.
            proposal_fn = ann['annotation']['filename'].replace('.jpg', '.json')
            with open(os.path.join(proposal_path, proposal_fn), 'r') as f:
                proposal = json.load(f)
            for p in proposal:
                proposal_list.append([p['x_min'], p['y_min'], p['x_max'], p['y_max']])
                proposal_batch_idx.append(i)

    h_batch = torch.tensor(h_list)
    w_batch = torch.tensor(w_list)
    box_batch = torch.stack(box_list)
    box_batch_ids = torch.tensor(box_batch_idx, dtype=torch.long)
    proposals = torch.tensor(proposal_list, dtype=box_batch.dtype)
    proposal_batch_ids = torch.tensor(proposal_batch_idx, dtype=torch.long)
    assert len(box_batch) == len(box_batch_ids)
    assert len(proposals) == len(proposal_batch_ids)

    return (img_batch, box_batch, box_batch_ids, proposals, proposal_batch_ids,
            w_batch, h_batch, img_id_list)
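

if __name__ == '__main__':
    # Illustrative smoke test, not part of the original file. Assumes the VOC
    # 2007 archives are already extracted under './data'; the path and batch
    # size are arbitrary, and no proposal files are loaded.
    voc_train = get_pascal_voc2007_data('./data', split='train')
    train_loader = pascal_voc2007_loader(voc_train, batch_size=4, shuffle=True)
    imgs, boxes, box_ids, props, prop_ids, ws, hs, ids = next(iter(train_loader))
    print('images:', imgs.shape)      # (4, 3, 224, 224)
    print('boxes:', boxes.shape)      # (total boxes in batch, 5)
    print('box ids:', box_ids.shape)  # (total boxes in batch,)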