1+ from .config import HOME
12import os
23import os .path
34import sys
78import cv2
89import numpy as np
910
11+ COCO_ROOT = os .path .join (HOME , 'data/coco/' )
12+ IMAGES = 'images'
13+ ANNOTATIONS = 'annotations'
14+ COCO_API = 'PythonAPI'
15+ INSTANCES_SET = 'instances_{}.json'
16+ COCO_CLASSES = ('person' , 'bicycle' , 'car' , 'motorcycle' , 'airplane' , 'bus' ,
17+ 'train' , 'truck' , 'boat' , 'traffic light' , 'fire' , 'hydrant' ,
18+ 'stop sign' , 'parking meter' , 'bench' , 'bird' , 'cat' , 'dog' ,
19+ 'horse' , 'sheep' , 'cow' , 'elephant' , 'bear' , 'zebra' ,
20+ 'giraffe' , 'backpack' , 'umbrella' , 'handbag' , 'tie' ,
21+ 'suitcase' , 'frisbee' , 'skis' , 'snowboard' , 'sports ball' ,
22+ 'kite' , 'baseball bat' , 'baseball glove' , 'skateboard' ,
23+ 'surfboard' , 'tennis racket' , 'bottle' , 'wine glass' , 'cup' ,
24+ 'fork' , 'knife' , 'spoon' , 'bowl' , 'banana' , 'apple' ,
25+ 'sandwich' , 'orange' , 'broccoli' , 'carrot' , 'hot dog' , 'pizza' ,
26+ 'donut' , 'cake' , 'chair' , 'couch' , 'potted plant' , 'bed' ,
27+ 'dining table' , 'toilet' , 'tv' , 'laptop' , 'mouse' , 'remote' ,
28+ 'keyboard' , 'cell phone' , 'microwave oven' , 'toaster' , 'sink' ,
29+ 'refrigerator' , 'book' , 'clock' , 'vase' , 'scissors' ,
30+ 'teddy bear' , 'hair drier' , 'toothbrush' )
31+
1032
1133class COCOAnnotationTransform (object ):
12- """Transforms a VOC annotation into a Tensor of bbox coords and label index
34+ """Transforms a COCO annotation into a Tensor of bbox coords and label index
1335 Initilized with a dictionary lookup of classnames to indexes
14-
15- Arguments:
16- class_to_ind (dict, optional): dictionary lookup of classnames -> indexes
17- (default: alphabetic indexing of VOC's 20 classes)
18- keep_difficult (bool, optional): keep difficult instances or not
19- (default: False)
20- height (int): height
21- width (int): width
2236 """
2337
24- # def __init__(self)
25-
2638 def __call__ (self , target , width , height ):
2739 """
28- Arguments:
29- target (annotation) : the target annotation to be made usable
30- will be an ET.Element
40+ Args:
41+ target (dict): COCO target json annotation as a python dict
42+ height (int): height
43+ width (int): width
3144 Returns:
32- a list containing lists of bounding boxes [bbox coords, class name ]
45+ a list containing lists of bounding boxes [bbox coords, class idx ]
3346 """
3447 scale = np .array ([width , height , width , height ])
3548 res = []
@@ -41,35 +54,40 @@ def __call__(self, target, width, height):
4154 label_idx = obj ['category_id' ]
4255 final_box = list (np .array (bbox )/ scale )
4356 final_box .append (label_idx )
44- res += [final_box ] # [xmin, ymin, xmax, ymax, label_ind ]
45- return res # [[xmin, ymin, xmax, ymax, label_ind ], ... ]
57+ res += [final_box ] # [xmin, ymin, xmax, ymax, label_idx ]
58+ return res # [[xmin, ymin, xmax, ymax, label_idx ], ... ]
4659
4760
4861class COCODetection (data .Dataset ):
4962 """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.
5063 Args:
5164 root (string): Root directory where images are downloaded to.
52- annFile (string): Path to json annotation file .
53- transform (callable, optional): A function/transform that takes in an PIL image
54- and returns a transformed version. E.g, ``transforms.ToTensor` `
55- target_transform (callable, optional): A function/transform that takes in the
56- target and transforms it.
65+ set_name (string): Name of the specific set of COCO images .
66+ transform (callable, optional): A function/transform that augments the
67+ raw images `
68+ target_transform (callable, optional): A function/transform that takes
69+ in the target (bbox) and transforms it.
5770 """
5871
59- def __init__ (self , root , annFile , transform = None , target_transform = None ):
72+ def __init__ (self , root , image_set , transform = None ,
73+ target_transform = None , dataset_name = 'COCO2014' ):
74+ sys .path .append (os .path .join (root , COCO_API ))
6075 from pycocotools .coco import COCO
61- self .root = root
62- self .coco = COCO (annFile )
76+ self .root = os .path .join (root , IMAGES , image_set )
77+ self .coco = COCO (os .path .join (root , ANNOTATIONS ,
78+ INSTANCES_SET .format (image_set )))
6379 self .ids = list (self .coco .imgs .keys ())
6480 self .transform = transform
6581 self .target_transform = target_transform
82+ self .name = dataset_name
6683
6784 def __getitem__ (self , index ):
6885 """
6986 Args:
7087 index (int): Index
7188 Returns:
72- tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
89+ tuple: Tuple (image, target).
90+ target is the object returned by ``coco.loadAnns``.
7391 """
7492 im , gt , h , w = self .pull_item (index )
7593 return im , gt
@@ -82,26 +100,58 @@ def pull_item(self, index):
82100 Args:
83101 index (int): Index
84102 Returns:
85- tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
103+ tuple: Tuple (image, target, height, width).
104+ target is the object returned by ``coco.loadAnns``.
86105 """
87- coco = self .coco
88106 img_id = self .ids [index ]
89- ann_ids = coco .getAnnIds (imgIds = img_id )
90- target = coco .loadAnns (ann_ids )
91- path = coco .loadImgs (img_id )[0 ]['file_name' ]
107+ ann_ids = self . coco .getAnnIds (imgIds = img_id )
108+ target = self . coco .loadAnns (ann_ids )
109+ path = self . coco .loadImgs (img_id )[0 ]['file_name' ]
92110 img = cv2 .imread (os .path .join (self .root , path ))
93111 height , width , channels = img .shape
94112 if self .target_transform is not None :
95113 target = self .target_transform (target , width , height )
96114 if self .transform is not None :
97115 target = np .array (target )
98- img , boxes , labels = self .transform (img , target [:, :4 ], target [:, 4 ])
116+ img , boxes , labels = self .transform (img , target [:, :4 ],
117+ target [:, 4 ])
99118 # to rgb
100119 img = img [:, :, (2 , 1 , 0 )]
101120 # img = img.transpose(2, 0, 1)
102121 target = np .hstack ((boxes , np .expand_dims (labels , axis = 1 )))
103122 return torch .from_numpy (img ).permute (2 , 0 , 1 ), target , height , width
104123
124+ def pull_image (self , index ):
125+ '''Returns the original image object at index in PIL form
126+
127+ Note: not using self.__getitem__(), as any transformations passed in
128+ could mess up this functionality.
129+
130+ Argument:
131+ index (int): index of img to show
132+ Return:
133+ cv2 img
134+ '''
135+ img_id = self .ids [index ]
136+ path = self .coco .loadImgs (img_id )[0 ]['file_name' ]
137+ return cv2 .imread (os .path .join (self .root , path ), cv2 .IMREAD_COLOR )
138+
139+ def pull_anno (self , index ):
140+ '''Returns the original annotation of image at index
141+
142+ Note: not using self.__getitem__(), as any transformations passed in
143+ could mess up this functionality.
144+
145+ Argument:
146+ index (int): index of img to get annotation of
147+ Return:
148+ list: [img_id, [(label, bbox coords),...]]
149+ eg: ('001718', [('dog', (96, 13, 438, 332))])
150+ '''
151+ img_id = self .ids [index ]
152+ ann_ids = self .coco .getAnnIds (imgIds = img_id )
153+ return self .coco .loadAnns (ann_ids )
154+
105155 def __repr__ (self ):
106156 fmt_str = 'Dataset ' + self .__class__ .__name__ + '\n '
107157 fmt_str += ' Number of datapoints: {}\n ' .format (self .__len__ ())
0 commit comments