
Commit 13c3bb3

Merge branch 'coco' into develop

2 parents: 2328f2b + e1bb0d0

File tree: 16 files changed (+481 -267 lines)

.gitignore
Lines changed: 3 additions & 0 deletions

@@ -123,3 +123,6 @@ test_data_aug.py
 # temp checkout soln
 data/datasets/
 data/ssd_dataloader.py
+
+# pylint
+.pylintrc

README.md
Lines changed: 22 additions & 11 deletions

@@ -1,12 +1,14 @@
 # SSD: Single Shot MultiBox Object Detector, in PyTorch
 A [PyTorch](http://pytorch.org/) implementation of [Single Shot MultiBox Detector](http://arxiv.org/abs/1512.02325) from the 2016 paper by Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, Scott Reed, Cheng-Yang Fu, and Alexander C. Berg. The official and original Caffe code can be found [here](https://github.com/weiliu89/caffe/tree/ssd).
 
+***UPDATE:*** We have just added support for MS COCO! Check it out [below](#coco).
+
 ## Authors
 
 * [**Max deGroot**](https://github.com/amdegroot)
 * [**Ellis Brown**](http://github.com/ellisbrown)
 
-***Note:*** Unfortunately, this is just a hobby of ours and not a full-time job, so we'll do our best to keep things up to date, but no guarantees. That being said, thanks to everyone for your continued help and feedback as it is really appreciated. We will try to address everything as soon as possible.
+***Note:*** Unfortunately, this is just a hobby for us and not a full-time job, so we'll do our best to keep things up to date, but no guarantees. That being said, thanks to everyone for your continued help and feedback, as it is really appreciated. We will try to address everything as soon as possible.
 
 
 <img align="right" src="https://github.com/amdegroot/ssd.pytorch/blob/master/doc/ssd.png" height=400/>

@@ -30,7 +32,7 @@
 - Install [PyTorch](http://pytorch.org/) by selecting your environment on the website and running the appropriate command.
 - Clone this repository.
   * Note: We currently only support Python 3+.
-- Then download the dataset by following the [instructions](#download-voc2007-trainval--test) below.
+- Then download the dataset by following the [instructions](#datasets) below.
 - We now support [Visdom](https://github.com/facebookresearch/visdom) for real-time loss visualization during training!
   * To use Visdom in the browser:
   ```Shell

@@ -40,21 +42,31 @@
   python -m visdom.server
   ```
   * Then (during training) navigate to http://localhost:8097/ (see the Train section below for training details).
-- Note: For training, we currently only support [VOC](http://host.robots.ox.ac.uk/pascal/VOC/), but are adding [COCO](http://mscoco.org/) and hopefully [ImageNet](http://www.image-net.org/) soon.
+- Note: For training, we currently support [VOC](http://host.robots.ox.ac.uk/pascal/VOC/) and [COCO](http://mscoco.org/), and aim to add [ImageNet](http://www.image-net.org/) support soon.
 
 ## Datasets
-To make things easy, we provide a simple VOC dataset loader that inherits `torch.utils.data.Dataset` making it fully compatible with the `torchvision.datasets` [API](http://pytorch.org/docs/torchvision/datasets.html).
+To make things easy, we provide bash scripts to handle the dataset downloads and setup for you. We also provide simple dataset loaders that inherit `torch.utils.data.Dataset`, making them fully compatible with the `torchvision.datasets` [API](http://pytorch.org/docs/torchvision/datasets.html).
+
+
+### COCO
+Microsoft COCO: Common Objects in Context
+
+##### Download COCO 2014
+```Shell
+# specify a directory for dataset to be downloaded into, else default is ~/data/
+sh data/scripts/COCO2014.sh
+```
 
 ### VOC Dataset
-##### Download VOC2007 trainval & test
+PASCAL VOC: Visual Object Classes
 
+##### Download VOC2007 trainval & test
 ```Shell
 # specify a directory for dataset to be downloaded into, else default is ~/data/
 sh data/scripts/VOC2007.sh # <directory>
 ```
 
 ##### Download VOC2012 trainval
-
 ```Shell
 # specify a directory for dataset to be downloaded into, else default is ~/data/
 sh data/scripts/VOC2012.sh # <directory>

@@ -149,12 +161,11 @@
 - Running `python -m demo.live` opens the webcam and begins detecting!
 
 ## TODO
-We have accumulated the following to-do list, which you can expect to be done in the very near future
+We have accumulated the following to-do list, which we hope to complete in the near future:
 - Still to come:
-  * Support for the MS COCO dataset
-  * Support for SSD512 training and testing
-  * Support for training on custom datasets
-
+  * [x] Support for the MS COCO dataset
+  * [ ] Support for SSD512 training and testing
+  * [ ] Support for training on custom datasets
 
 ## References
 - Wei Liu, et al. "SSD: Single Shot MultiBox Detector." [ECCV2016](http://arxiv.org/abs/1512.02325).
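
The README's claim of `torchvision.datasets` compatibility boils down to the loaders implementing `__len__` and `__getitem__`. A minimal sketch of that usage, assuming `VOCDetection` in `voc0712.py` (not shown in this commit) keeps a `(root, image_sets, ...)` constructor where `image_sets` is a list of (year, split) pairs:

```Python
# Illustrative only: the VOCDetection signature is an assumption, since
# voc0712.py is not part of this diff.
from data import VOCDetection, VOC_ROOT

dataset = VOCDetection(VOC_ROOT, [('2007', 'trainval')])
print(len(dataset))        # standard torch.utils.data.Dataset protocol
img, target = dataset[0]   # image plus its list of box/label annotations
```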

data/__init__.py
Lines changed: 23 additions & 2 deletions

@@ -1,12 +1,33 @@
-from .voc0712 import VOCDetection, AnnotationTransform, detection_collate, VOC_CLASSES
+from .voc0712 import VOCDetection, VOCAnnotationTransform, VOC_CLASSES, VOC_ROOT
+from .coco import COCODetection, COCOAnnotationTransform, COCO_CLASSES, COCO_ROOT
 from .config import *
+import torch
 import cv2
 import numpy as np
 
+def detection_collate(batch):
+    """Custom collate fn for dealing with batches of images that have a
+    different number of associated object annotations (bounding boxes).
+
+    Arguments:
+        batch: (tuple) A tuple of tensor images and lists of annotations
+
+    Return:
+        A tuple containing:
+            1) (tensor) batch of images stacked on their 0 dim
+            2) (list of tensors) annotations for a given image are stacked
+               on 0 dim
+    """
+    targets = []
+    imgs = []
+    for sample in batch:
+        imgs.append(sample[0])
+        targets.append(torch.FloatTensor(sample[1]))
+    return torch.stack(imgs, 0), targets
+
 
 def base_transform(image, size, mean):
     x = cv2.resize(image, (size, size)).astype(np.float32)
-    # x = cv2.resize(np.array(image), (size, size)).astype(np.float32)
     x -= mean
     x = x.astype(np.float32)
     return x
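
`detection_collate` exists because images carry different numbers of boxes, so the default stacking collate would fail. A minimal sketch of wiring it into a `DataLoader`; the `Resize300` wrapper and the `'trainval35k'` split name are illustrative assumptions, not part of this commit:

```Python
# A sketch, not the commit's training code. Resize300 and 'trainval35k'
# are assumptions made for illustration.
import torch.utils.data as data
from data import (COCODetection, COCOAnnotationTransform, detection_collate,
                  base_transform, COCO_ROOT)

class Resize300:
    """Resize + mean-subtract via base_transform; boxes/labels pass through,
    which is safe here because COCOAnnotationTransform normalizes them."""
    def __init__(self, size=300, mean=(104, 117, 123)):
        self.size, self.mean = size, mean
    def __call__(self, img, boxes, labels):
        return base_transform(img, self.size, self.mean), boxes, labels

dataset = COCODetection(COCO_ROOT, 'trainval35k', transform=Resize300(),
                        target_transform=COCOAnnotationTransform())
loader = data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4,
                         collate_fn=detection_collate)
images, targets = next(iter(loader))
# images: FloatTensor of shape (32, 3, 300, 300)
# targets: list of 32 tensors, each (num_objects, 5) = [xmin, ymin, xmax, ymax, label]
```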

data/coco.py
Lines changed: 180 additions & 0 deletions

@@ -0,0 +1,180 @@
+from .config import HOME
+import os
+import os.path as osp
+import sys
+import torch
+import torch.utils.data as data
+import torchvision.transforms as transforms
+import cv2
+import numpy as np
+
+COCO_ROOT = osp.join(HOME, 'data/coco/')
+IMAGES = 'images'
+ANNOTATIONS = 'annotations'
+COCO_API = 'PythonAPI'
+INSTANCES_SET = 'instances_{}.json'
+COCO_CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+                'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
+                'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
+                'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
+                'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
+                'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
+                'kite', 'baseball bat', 'baseball glove', 'skateboard',
+                'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
+                'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
+                'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog',
+                'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
+                'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
+                'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
+                'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
+                'scissors', 'teddy bear', 'hair drier', 'toothbrush')
+
+
+class COCOAnnotationTransform(object):
+    """Transforms a COCO annotation into a Tensor of bbox coords and label
+    index. Initialized with a dictionary lookup of class names to indexes.
+    """
+    def __init__(self):
+        self.label_map = get_label_map(osp.join(COCO_ROOT, 'coco_labels.txt'))
+
+    def __call__(self, target, width, height):
+        """
+        Args:
+            target (dict): COCO target json annotation as a python dict
+            width (int): image width
+            height (int): image height
+        Returns:
+            a list containing lists of bounding boxes [bbox coords, class idx]
+        """
+        scale = np.array([width, height, width, height])
+        res = []
+        for obj in target:
+            if 'bbox' in obj:
+                bbox = obj['bbox']
+                # COCO boxes are [xmin, ymin, w, h]; convert to corner form
+                bbox[2] += bbox[0]
+                bbox[3] += bbox[1]
+                label_idx = self.label_map[obj['category_id']] - 1
+                final_box = list(np.array(bbox) / scale)
+                final_box.append(label_idx)
+                res += [final_box]  # [xmin, ymin, xmax, ymax, label_idx]
+            else:
+                print("no bbox problem!")
+
+        return res  # [[xmin, ymin, xmax, ymax, label_idx], ... ]
+
+
+class COCODetection(data.Dataset):
+    """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.
+    Args:
+        root (string): Root directory where images are downloaded to.
+        image_set (string): Name of the specific set of COCO images.
+        transform (callable, optional): A function/transform that augments
+            the raw images.
+        target_transform (callable, optional): A function/transform that
+            takes in the target (bbox) and transforms it.
+    """
+
+    def __init__(self, root, image_set, transform=None,
+                 target_transform=None):
+        sys.path.append(osp.join(root, COCO_API))
+        from pycocotools.coco import COCO
+        self.root = osp.join(root, IMAGES, image_set)
+        self.coco = COCO(osp.join(root, ANNOTATIONS,
+                                  INSTANCES_SET.format(image_set)))
+        self.ids = list(self.coco.imgToAnns.keys())
+        self.transform = transform
+        self.target_transform = target_transform
+        self.name = 'MS COCO ' + image_set
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+        Returns:
+            tuple: Tuple (image, target).
+                   target is the object returned by ``coco.loadAnns``.
+        """
+        im, gt, h, w = self.pull_item(index)
+        return im, gt
+
+    def __len__(self):
+        return len(self.ids)
+
+    def pull_item(self, index):
+        """
+        Args:
+            index (int): Index
+        Returns:
+            tuple: Tuple (image, target, height, width).
+                   target is the object returned by ``coco.loadAnns``.
+        """
+        img_id = self.ids[index]
+        ann_ids = self.coco.getAnnIds(imgIds=img_id)
+        target = self.coco.loadAnns(ann_ids)
+        path = osp.join(self.root, self.coco.loadImgs(img_id)[0]['file_name'])
+        assert osp.exists(path), 'Image path does not exist: {}'.format(path)
+        img = cv2.imread(path)
+        height, width, _ = img.shape
+        if self.target_transform is not None:
+            target = self.target_transform(target, width, height)
+        if self.transform is not None:
+            target = np.array(target)
+            img, boxes, labels = self.transform(img, target[:, :4],
+                                                target[:, 4])
+            # to rgb
+            img = img[:, :, (2, 1, 0)]
+
+            target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
+        return torch.from_numpy(img).permute(2, 0, 1), target, height, width
+
+    def pull_image(self, index):
+        '''Returns the original image object at index in cv2 form
+
+        Note: not using self.__getitem__(), as any transformations passed in
+        could mess up this functionality.
+
+        Argument:
+            index (int): index of img to show
+        Return:
+            cv2 img (BGR, no normalization applied)
+        '''
+        img_id = self.ids[index]
+        path = self.coco.loadImgs(img_id)[0]['file_name']
+        return cv2.imread(osp.join(self.root, path), cv2.IMREAD_COLOR)
+
+    def pull_anno(self, index):
+        '''Returns the original annotation of image at index
+
+        Note: not using self.__getitem__(), as any transformations passed in
+        could mess up this functionality.
+
+        Argument:
+            index (int): index of img to get annotation of
+        Return:
+            list: the raw COCO annotation dicts returned by ``coco.loadAnns``
+        '''
+        img_id = self.ids[index]
+        ann_ids = self.coco.getAnnIds(imgIds=img_id)
+        return self.coco.loadAnns(ann_ids)
+
+    def __repr__(self):
+        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
+        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
+        fmt_str += '    Root Location: {}\n'.format(self.root)
+        tmp = '    Transforms (if any): '
+        fmt_str += '{0}{1}\n'.format(
+            tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
+        tmp = '    Target Transforms (if any): '
+        fmt_str += '{0}{1}'.format(
+            tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
+        return fmt_str
+
+
+def get_label_map(label_file):
+    # maps COCO's sparse category ids (1-90, with gaps) onto contiguous ids
+    label_map = {}
+    with open(label_file, 'r') as labels:
+        for line in labels:
+            ids = line.split(',')
+            label_map[int(ids[0])] = int(ids[1])
+    return label_map
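
The box arithmetic in `COCOAnnotationTransform.__call__` is the subtle part of this file: COCO stores boxes as `[xmin, ymin, width, height]`, and the transform converts them to corner form normalized by image size. A hand-worked example with made-up numbers:

```Python
# Worked example of the conversion in COCOAnnotationTransform.__call__;
# the annotation dict and image size are made up.
import numpy as np

width, height = 640, 480
obj = {'category_id': 18, 'bbox': [10.0, 20.0, 300.0, 200.0]}  # xmin, ymin, w, h

bbox = obj['bbox']
bbox[2] += bbox[0]   # xmax = 10 + 300 = 310
bbox[3] += bbox[1]   # ymax = 20 + 200 = 220
scale = np.array([width, height, width, height])
print(list(np.array(bbox) / scale))
# [0.015625, 0.041666..., 0.484375, 0.458333...] -> normalized corners
```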

data/config.py
Lines changed: 29 additions & 57 deletions

@@ -2,63 +2,35 @@
 import os.path
 
 # gets home dir cross platform
-home = os.path.expanduser("~")
-ddir = os.path.join(home, "data/VOCdevkit/")
-
-# note: if you used our download scripts, this should be right
-VOCroot = ddir  # path to VOCdevkit root dir
-
-# default batch size
-BATCHES = 32
-# data reshuffled at every epoch
-SHUFFLE = True
-# number of subprocesses to use for data loading
-WORKERS = 4
-
-
-# SSD300 CONFIGS
-# newer version: use additional conv11_2 layer as last layer before multibox layers
-v2 = {
-    'feature_maps' : [38, 19, 10, 5, 3, 1],
-    'min_dim' : 300,
-    'steps' : [8, 16, 32, 64, 100, 300],
-    'min_sizes' : [30, 60, 111, 162, 213, 264],
-    'max_sizes' : [60, 111, 162, 213, 264, 315],
-    # 'aspect_ratios' : [[2, 1/2], [2, 1/2, 3, 1/3], [2, 1/2, 3, 1/3],
-    #                    [2, 1/2, 3, 1/3], [2, 1/2], [2, 1/2]],
-    'aspect_ratios' : [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
-    'variance' : [0.1, 0.2],
-    'clip' : True,
-    'name' : 'v2',
+HOME = os.path.expanduser("~")
+
+# for making bounding boxes pretty
+COLORS = ((255, 0, 0, 128), (0, 255, 0, 128), (0, 0, 255, 128),
+          (0, 255, 255, 128), (255, 0, 255, 128), (255, 255, 0, 128))
+
+MEANS = (104, 117, 123)
+
+# SSD300 CONFIGS
+voc = {
+    'feature_maps': [38, 19, 10, 5, 3, 1],
+    'min_dim': 300,
+    'steps': [8, 16, 32, 64, 100, 300],
+    'min_sizes': [30, 60, 111, 162, 213, 264],
+    'max_sizes': [60, 111, 162, 213, 264, 315],
+    'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
+    'variance': [0.1, 0.2],
+    'clip': True,
+    'name': 'VOC',
 }
 
-# use average pooling layer as last layer before multibox layers
-v1 = {
-    'feature_maps' : [38, 19, 10, 5, 3, 1],
-    'min_dim' : 300,
-    'steps' : [8, 16, 32, 64, 100, 300],
-    'min_sizes' : [30, 60, 114, 168, 222, 276],
-    'max_sizes' : [-1, 114, 168, 222, 276, 330],
-    # 'aspect_ratios' : [[2], [2, 3], [2, 3], [2, 3], [2, 3], [2, 3]],
-    'aspect_ratios' : [[1, 1, 2, 1/2], [1, 1, 2, 1/2, 3, 1/3], [1, 1, 2, 1/2, 3, 1/3],
-                       [1, 1, 2, 1/2, 3, 1/3], [1, 1, 2, 1/2, 3, 1/3], [1, 1, 2, 1/2, 3, 1/3]],
-    'variance' : [0.1, 0.2],
-    'clip' : True,
-    'name' : 'v1',
+coco = {
+    'feature_maps': [38, 19, 10, 5, 3, 1],
+    'min_dim': 300,
+    'steps': [8, 16, 32, 64, 100, 300],
+    'min_sizes': [21, 45, 99, 153, 207, 261],
+    'max_sizes': [45, 99, 153, 207, 261, 315],
+    'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
+    'variance': [0.1, 0.2],
+    'clip': True,
+    'name': 'COCO',
 }
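
As a sanity check on the new config dicts (not part of the commit), the `feature_maps` and `aspect_ratios` entries determine the SSD300 prior-box count: each feature-map location gets two square boxes plus an (ar, 1/ar) pair per listed aspect ratio, which recovers the canonical 8732 priors:

```Python
# Sanity check, not from the commit: derive the SSD300 prior count from voc.
from data.config import voc

total = 0
for k, f in enumerate(voc['feature_maps']):
    per_loc = 2 + 2 * len(voc['aspect_ratios'][k])  # 2 squares + (ar, 1/ar) pairs
    total += f * f * per_loc
print(total)  # 8732
```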
