from .config import HOME
import os
import os.path as osp
import sys
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
import cv2
import numpy as np

COCO_ROOT = osp.join(HOME, 'data/coco/')
IMAGES = 'images'
ANNOTATIONS = 'annotations'
COCO_API = 'PythonAPI'
INSTANCES_SET = 'instances_{}.json'
COCO_CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
                'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
                'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
                'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
                'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
                'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
                'kite', 'baseball bat', 'baseball glove', 'skateboard',
                'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
                'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
                'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
                'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
                'keyboard', 'cell phone', 'microwave', 'oven', 'toaster',
                'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
                'teddy bear', 'hair drier', 'toothbrush')
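# NOTE: these are the 80 COCO detection category names. The raw annotations use
# sparse category ids (1-90 with gaps), so coco_labels.txt (read by
# get_label_map below) is expected to remap them onto these contiguous indices.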


class COCOAnnotationTransform(object):
    """Transforms a COCO annotation into a list of normalized bbox coords and
    label indices. Initialized with a dictionary lookup of class names to indexes.
    """
    def __init__(self):
        self.label_map = get_label_map(osp.join(COCO_ROOT, 'coco_labels.txt'))

    def __call__(self, target, width, height):
        """
        Args:
            target (list): COCO annotation dicts for one image (from ``coco.loadAnns``)
            width (int): image width, used to normalize the bbox coords
            height (int): image height, used to normalize the bbox coords
        Returns:
            a list containing lists of bounding boxes [bbox coords, class idx]
        """
        scale = np.array([width, height, width, height])
        res = []
        for obj in target:
            if 'bbox' in obj:
                # COCO boxes are [xmin, ymin, width, height]; convert to
                # [xmin, ymin, xmax, ymax] and normalize by the image size.
                bbox = obj['bbox']
                bbox[2] += bbox[0]
                bbox[3] += bbox[1]
                label_idx = self.label_map[obj['category_id']] - 1
                final_box = list(np.array(bbox) / scale)
                final_box.append(label_idx)
                res += [final_box]  # [xmin, ymin, xmax, ymax, label_idx]
            else:
                print("No bbox found for an annotation; skipping it.")

        return res  # [[xmin, ymin, xmax, ymax, label_idx], ... ]
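
    # Worked example (hypothetical numbers): on a 200x100 image, a raw COCO
    # annotation with bbox [24.0, 38.0, 100.0, 50.0] ([x, y, w, h] in pixels)
    # becomes [0.12, 0.38, 0.62, 0.88, label_idx], where
    # label_idx = label_map[category_id] - 1 depends on coco_labels.txt.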


class COCODetection(data.Dataset):
    """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.
    Args:
        root (string): Root directory where images are downloaded to.
        image_set (string): Name of the specific set of COCO images.
        transform (callable, optional): A function/transform that augments the
            raw images.
        target_transform (callable, optional): A function/transform that takes
            in the target (bbox) and transforms it.
    """

    def __init__(self, root, image_set, transform=None,
                 target_transform=None):
        # Make the COCO API importable: assumes the cocoapi checkout lives at
        # <root>/PythonAPI rather than being installed system-wide.
        sys.path.append(osp.join(root, COCO_API))
        from pycocotools.coco import COCO
        self.root = osp.join(root, IMAGES, image_set)
        self.coco = COCO(osp.join(root, ANNOTATIONS,
                                  INSTANCES_SET.format(image_set)))
        self.ids = list(self.coco.imgToAnns.keys())
        self.transform = transform
        self.target_transform = target_transform
        self.name = 'MS COCO ' + image_set

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: Tuple (image, target), where target is the annotation for
            the image after ``target_transform``/``transform`` (if given)
            have been applied.
        """
        im, gt, h, w = self.pull_item(index)
        return im, gt

    def __len__(self):
        return len(self.ids)

    def pull_item(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: Tuple (image, target, height, width), where target is the
            annotation for the image after any transforms have been applied.
        """
        img_id = self.ids[index]
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        target = self.coco.loadAnns(ann_ids)
        path = osp.join(self.root, self.coco.loadImgs(img_id)[0]['file_name'])
        assert osp.exists(path), 'Image path does not exist: {}'.format(path)
        img = cv2.imread(path)
        height, width, _ = img.shape
        if self.target_transform is not None:
            target = self.target_transform(target, width, height)
        if self.transform is not None:
            target = np.array(target)
            img, boxes, labels = self.transform(img, target[:, :4],
                                                target[:, 4])
            # to rgb
            img = img[:, :, (2, 1, 0)]

            target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
        return torch.from_numpy(img).permute(2, 0, 1), target, height, width
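
    # NOTE: the returned image tensor is CHW; it is RGB only when a transform
    # is given (the BGR->RGB swap above happens inside that branch), otherwise
    # it stays in OpenCV's BGR order. Batching (im, gt) pairs with a DataLoader
    # requires a custom collate_fn, since the per-image targets vary in length.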

    def pull_image(self, index):
        '''Returns the original image at index, as loaded by cv2 (BGR ndarray)

        Note: not using self.__getitem__(), as any transformations passed in
        could mess up this functionality.

        Argument:
            index (int): index of img to show
        Return:
            cv2 img
        '''
        img_id = self.ids[index]
        path = self.coco.loadImgs(img_id)[0]['file_name']
        return cv2.imread(osp.join(self.root, path), cv2.IMREAD_COLOR)

    def pull_anno(self, index):
        '''Returns the original annotation of image at index

        Note: not using self.__getitem__(), as any transformations passed in
        could mess up this functionality.

        Argument:
            index (int): index of img to get annotation of
        Return:
            list: the raw COCO annotation dicts for the image, as returned by
            ``coco.loadAnns`` (no transforms applied)
        '''
        img_id = self.ids[index]
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        return self.coco.loadAnns(ann_ids)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str


def get_label_map(label_file):
    # Build {coco_category_id: contiguous_label} from a comma-separated file,
    # closing the file when done.
    label_map = {}
    with open(label_file, 'r') as labels:
        for line in labels:
            ids = line.split(',')
            label_map[int(ids[0])] = int(ids[1])
    return label_map
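
# The label file read by get_label_map is assumed (inferred from the parsing
# above) to hold one "coco_category_id,contiguous_label" pair per line,
# e.g. a line "1,1" would map COCO category id 1 to label 1.

# Minimal usage sketch, not part of the original module's API: it assumes the
# data root follows the layout implied above (images/<set>/ plus
# annotations/instances_<set>.json) and that the cocoapi checkout is present;
# 'val2014' is purely an example set name. Real training would also pass an
# augmentation pipeline as `transform`.
if __name__ == '__main__':
    dataset = COCODetection(COCO_ROOT, 'val2014',
                            target_transform=COCOAnnotationTransform())
    print(len(dataset), 'annotated images')
    image, gt = dataset[0]
    print(image.shape, len(gt))  # e.g. torch.Size([3, H, W]) and the box count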