Skip to content

Commit 66faf9c

Browse files
committed
Updates to support coco training and clean up
1 parent c8c386b commit 66faf9c

File tree

9 files changed

+210
-225
lines changed

9 files changed

+210
-225
lines changed

data/__init__.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,34 @@
1-
from .voc0712 import VOCDetection, AnnotationTransform, detection_collate, VOC_CLASSES
2-
from .coco import COCODetection, COCOAnnotationTransform
1+
from .voc0712 import VOCDetection, VOCAnnotationTransform, VOC_CLASSES, VOC_ROOT
2+
from .coco import COCODetection, COCOAnnotationTransform, COCO_CLASSES, COCO_ROOT
33
from .config import *
4+
import torch
45
import cv2
56
import numpy as np
67

78

9+
def detection_collate(batch):
10+
"""Custom collate fn for dealing with batches of images that have a different
11+
number of associated object annotations (bounding boxes).
12+
13+
Arguments:
14+
batch: (tuple) A tuple of tensor images and lists of annotations
15+
16+
Return:
17+
A tuple containing:
18+
1) (tensor) batch of images stacked on their 0 dim
19+
2) (list of tensors) annotations for a given image are stacked on
20+
0 dim
21+
"""
22+
targets = []
23+
imgs = []
24+
for sample in batch:
25+
imgs.append(sample[0])
26+
targets.append(torch.FloatTensor(sample[1]))
27+
return torch.stack(imgs, 0), targets
28+
29+
830
def base_transform(image, size, mean):
931
x = cv2.resize(image, (size, size)).astype(np.float32)
10-
# x = cv2.resize(np.array(image), (size, size)).astype(np.float32)
1132
x -= mean
1233
x = x.astype(np.float32)
1334
return x

data/coco.py

Lines changed: 82 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .config import HOME
12
import os
23
import os.path
34
import sys
@@ -7,29 +8,41 @@
78
import cv2
89
import numpy as np
910

11+
COCO_ROOT = os.path.join(HOME, 'data/coco/')
12+
IMAGES = 'images'
13+
ANNOTATIONS = 'annotations'
14+
COCO_API = 'PythonAPI'
15+
INSTANCES_SET = 'instances_{}.json'
16+
COCO_CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
17+
'train', 'truck', 'boat', 'traffic light', 'fire', 'hydrant',
18+
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
19+
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
20+
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
21+
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
22+
'kite', 'baseball bat', 'baseball glove', 'skateboard',
23+
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
24+
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
25+
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
26+
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
27+
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
28+
'keyboard', 'cell phone', 'microwave oven', 'toaster', 'sink',
29+
'refrigerator', 'book', 'clock', 'vase', 'scissors',
30+
'teddy bear', 'hair drier', 'toothbrush')
31+
1032

1133
class COCOAnnotationTransform(object):
12-
"""Transforms a VOC annotation into a Tensor of bbox coords and label index
34+
"""Transforms a COCO annotation into a Tensor of bbox coords and label index
1335
Initilized with a dictionary lookup of classnames to indexes
14-
15-
Arguments:
16-
class_to_ind (dict, optional): dictionary lookup of classnames -> indexes
17-
(default: alphabetic indexing of VOC's 20 classes)
18-
keep_difficult (bool, optional): keep difficult instances or not
19-
(default: False)
20-
height (int): height
21-
width (int): width
2236
"""
2337

24-
# def __init__(self)
25-
2638
def __call__(self, target, width, height):
2739
"""
28-
Arguments:
29-
target (annotation) : the target annotation to be made usable
30-
will be an ET.Element
40+
Args:
41+
target (dict): COCO target json annotation as a python dict
42+
height (int): height
43+
width (int): width
3144
Returns:
32-
a list containing lists of bounding boxes [bbox coords, class name]
45+
a list containing lists of bounding boxes [bbox coords, class idx]
3346
"""
3447
scale = np.array([width, height, width, height])
3548
res = []
@@ -41,35 +54,40 @@ def __call__(self, target, width, height):
4154
label_idx = obj['category_id']
4255
final_box = list(np.array(bbox)/scale)
4356
final_box.append(label_idx)
44-
res += [final_box] # [xmin, ymin, xmax, ymax, label_ind]
45-
return res # [[xmin, ymin, xmax, ymax, label_ind], ... ]
57+
res += [final_box] # [xmin, ymin, xmax, ymax, label_idx]
58+
return res # [[xmin, ymin, xmax, ymax, label_idx], ... ]
4659

4760

4861
class COCODetection(data.Dataset):
4962
"""`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.
5063
Args:
5164
root (string): Root directory where images are downloaded to.
52-
annFile (string): Path to json annotation file.
53-
transform (callable, optional): A function/transform that takes in an PIL image
54-
and returns a transformed version. E.g, ``transforms.ToTensor``
55-
target_transform (callable, optional): A function/transform that takes in the
56-
target and transforms it.
65+
set_name (string): Name of the specific set of COCO images.
66+
transform (callable, optional): A function/transform that augments the
67+
raw images`
68+
target_transform (callable, optional): A function/transform that takes
69+
in the target (bbox) and transforms it.
5770
"""
5871

59-
def __init__(self, root, annFile, transform=None, target_transform=None):
72+
def __init__(self, root, image_set, transform=None,
73+
target_transform=None, dataset_name='COCO2014'):
74+
sys.path.append(os.path.join(root, COCO_API))
6075
from pycocotools.coco import COCO
61-
self.root = root
62-
self.coco = COCO(annFile)
76+
self.root = os.path.join(root, IMAGES, image_set)
77+
self.coco = COCO(os.path.join(root, ANNOTATIONS,
78+
INSTANCES_SET.format(image_set)))
6379
self.ids = list(self.coco.imgs.keys())
6480
self.transform = transform
6581
self.target_transform = target_transform
82+
self.name = dataset_name
6683

6784
def __getitem__(self, index):
6885
"""
6986
Args:
7087
index (int): Index
7188
Returns:
72-
tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
89+
tuple: Tuple (image, target).
90+
target is the object returned by ``coco.loadAnns``.
7391
"""
7492
im, gt, h, w = self.pull_item(index)
7593
return im, gt
@@ -82,26 +100,58 @@ def pull_item(self, index):
82100
Args:
83101
index (int): Index
84102
Returns:
85-
tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
103+
tuple: Tuple (image, target, height, width).
104+
target is the object returned by ``coco.loadAnns``.
86105
"""
87-
coco = self.coco
88106
img_id = self.ids[index]
89-
ann_ids = coco.getAnnIds(imgIds=img_id)
90-
target = coco.loadAnns(ann_ids)
91-
path = coco.loadImgs(img_id)[0]['file_name']
107+
ann_ids = self.coco.getAnnIds(imgIds=img_id)
108+
target = self.coco.loadAnns(ann_ids)
109+
path = self.coco.loadImgs(img_id)[0]['file_name']
92110
img = cv2.imread(os.path.join(self.root, path))
93111
height, width, channels = img.shape
94112
if self.target_transform is not None:
95113
target = self.target_transform(target, width, height)
96114
if self.transform is not None:
97115
target = np.array(target)
98-
img, boxes, labels = self.transform(img, target[:, :4], target[:, 4])
116+
img, boxes, labels = self.transform(img, target[:, :4],
117+
target[:, 4])
99118
# to rgb
100119
img = img[:, :, (2, 1, 0)]
101120
# img = img.transpose(2, 0, 1)
102121
target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
103122
return torch.from_numpy(img).permute(2, 0, 1), target, height, width
104123

124+
def pull_image(self, index):
125+
'''Returns the original image object at index in PIL form
126+
127+
Note: not using self.__getitem__(), as any transformations passed in
128+
could mess up this functionality.
129+
130+
Argument:
131+
index (int): index of img to show
132+
Return:
133+
cv2 img
134+
'''
135+
img_id = self.ids[index]
136+
path = self.coco.loadImgs(img_id)[0]['file_name']
137+
return cv2.imread(os.path.join(self.root, path), cv2.IMREAD_COLOR)
138+
139+
def pull_anno(self, index):
140+
'''Returns the original annotation of image at index
141+
142+
Note: not using self.__getitem__(), as any transformations passed in
143+
could mess up this functionality.
144+
145+
Argument:
146+
index (int): index of img to get annotation of
147+
Return:
148+
list: [img_id, [(label, bbox coords),...]]
149+
eg: ('001718', [('dog', (96, 13, 438, 332))])
150+
'''
151+
img_id = self.ids[index]
152+
ann_ids = self.coco.getAnnIds(imgIds=img_id)
153+
return self.coco.loadAnns(ann_ids)
154+
105155
def __repr__(self):
106156
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
107157
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())

data/config.py

Lines changed: 29 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2,63 +2,35 @@
22
import os.path
33

44
# gets home dir cross platform
5-
home = os.path.expanduser("~")
6-
ddir = os.path.join(home,"data/VOCdevkit/")
7-
8-
# note: if you used our download scripts, this should be right
9-
VOCroot = ddir # path to VOCdevkit root dir
10-
11-
# default batch size
12-
BATCHES = 32
13-
# data reshuffled at every epoch
14-
SHUFFLE = True
15-
# number of subprocesses to use for data loading
16-
WORKERS = 4
17-
18-
19-
#SSD300 CONFIGS
20-
# newer version: use additional conv11_2 layer as last layer before multibox layers
21-
v2 = {
22-
'feature_maps' : [38, 19, 10, 5, 3, 1],
23-
24-
'min_dim' : 300,
25-
26-
'steps' : [8, 16, 32, 64, 100, 300],
27-
28-
'min_sizes' : [30, 60, 111, 162, 213, 264],
29-
30-
'max_sizes' : [60, 111, 162, 213, 264, 315],
31-
32-
# 'aspect_ratios' : [[2, 1/2], [2, 1/2, 3, 1/3], [2, 1/2, 3, 1/3],
33-
# [2, 1/2, 3, 1/3], [2, 1/2], [2, 1/2]],
34-
'aspect_ratios' : [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
35-
36-
'variance' : [0.1, 0.2],
37-
38-
'clip' : True,
39-
40-
'name' : 'v2',
5+
HOME = os.path.expanduser("~")
6+
7+
# for making bounding boxes pretty
8+
COLORS = ((255, 0, 0, 128), (0, 255, 0, 128), (0, 0, 255, 128),
9+
(0, 255, 255, 128), (255, 0, 255, 128), (255, 255, 0, 128))
10+
11+
MEANS = (104, 117, 123)
12+
13+
# SSD300 CONFIGS
14+
voc = {
15+
'feature_maps': [38, 19, 10, 5, 3, 1],
16+
'min_dim': 300,
17+
'steps': [8, 16, 32, 64, 100, 300],
18+
'min_sizes': [30, 60, 111, 162, 213, 264],
19+
'max_sizes': [60, 111, 162, 213, 264, 315],
20+
'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
21+
'variance': [0.1, 0.2],
22+
'clip': True,
23+
'name': 'VOC',
4124
}
4225

43-
# use average pooling layer as last layer before multibox layers
44-
v1 = {
45-
'feature_maps' : [38, 19, 10, 5, 3, 1],
46-
47-
'min_dim' : 300,
48-
49-
'steps' : [8, 16, 32, 64, 100, 300],
50-
51-
'min_sizes' : [30, 60, 114, 168, 222, 276],
52-
53-
'max_sizes' : [-1, 114, 168, 222, 276, 330],
54-
55-
# 'aspect_ratios' : [[2], [2, 3], [2, 3], [2, 3], [2, 3], [2, 3]],
56-
'aspect_ratios' : [[1,1,2,1/2],[1,1,2,1/2,3,1/3],[1,1,2,1/2,3,1/3],
57-
[1,1,2,1/2,3,1/3],[1,1,2,1/2,3,1/3],[1,1,2,1/2,3,1/3]],
58-
59-
'variance' : [0.1, 0.2],
60-
61-
'clip' : True,
62-
63-
'name' : 'v1',
26+
coco = {
27+
'feature_maps': [38, 19, 10, 5, 3, 1],
28+
'min_dim': 300,
29+
'steps': [8, 16, 32, 64, 100, 300],
30+
'min_sizes': [21, 45, 99, 153, 207, 261],
31+
'max_sizes': [45, 99, 153, 207, 261, 315],
32+
'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
33+
'variance': [0.1, 0.2],
34+
'clip': True,
35+
'name': 'COCO',
6436
}

data/voc0712.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
66
Updated by: Ellis Brown, Max deGroot
77
"""
8-
8+
from .config import HOME
99
import os
1010
import os.path
1111
import sys
@@ -27,12 +27,11 @@
2727
'motorbike', 'person', 'pottedplant',
2828
'sheep', 'sofa', 'train', 'tvmonitor')
2929

30-
# for making bounding boxes pretty
31-
COLORS = ((255, 0, 0, 128), (0, 255, 0, 128), (0, 0, 255, 128),
32-
(0, 255, 255, 128), (255, 0, 255, 128), (255, 255, 0, 128))
30+
# note: if you used our download scripts, this should be right
31+
VOC_ROOT = os.path.join(HOME, "data/VOCdevkit/")
3332

3433

35-
class AnnotationTransform(object):
34+
class VOCAnnotationTransform(object):
3635
"""Transforms a VOC annotation into a Tensor of bbox coords and label index
3736
Initilized with a dictionary lookup of classnames to indexes
3837
@@ -115,6 +114,7 @@ def __init__(self, root, image_sets, transform=None, target_transform=None,
115114

116115
def __getitem__(self, index):
117116
im, gt, h, w = self.pull_item(index)
117+
118118
return im, gt
119119

120120
def __len__(self):
@@ -183,23 +183,3 @@ def pull_tensor(self, index):
183183
tensorized version of img, squeezed
184184
'''
185185
return torch.Tensor(self.pull_image(index)).unsqueeze_(0)
186-
187-
188-
def detection_collate(batch):
189-
"""Custom collate fn for dealing with batches of images that have a different
190-
number of associated object annotations (bounding boxes).
191-
192-
Arguments:
193-
batch: (tuple) A tuple of tensor images and lists of annotations
194-
195-
Return:
196-
A tuple containing:
197-
1) (tensor) batch of images stacked on their 0 dim
198-
2) (list of tensors) annotations for a given image are stacked on 0 dim
199-
"""
200-
targets = []
201-
imgs = []
202-
for sample in batch:
203-
imgs.append(sample[0])
204-
targets.append(torch.FloatTensor(sample[1]))
205-
return torch.stack(imgs, 0), targets

layers/functions/detection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
from torch.autograd import Function
33
from ..box_utils import decode, nms
4-
from data import v2 as cfg
4+
from data import voc as cfg
55

66

77
class Detect(Function):

0 commit comments

Comments
 (0)