Commit 8dd3865
Add COCO training option and cleanup training script
1 parent dd85aff

12 files changed: +302 −163 lines

README.md

Lines changed: 13 additions & 13 deletions
The paired −/+ lines below are identical in visible text; the README changes in this commit are trailing-whitespace cleanup.

````diff
@@ -24,10 +24,10 @@ A [PyTorch](http://pytorch.org/) implementation of [Single Shot MultiBox Detecto
 - Clone this repository.
   * Note: We currently only support Python 3+.
 - Then download the dataset by following the [instructions](#datasets) below.
-- We now support [Visdom](https://github.com/facebookresearch/visdom) for real-time loss visualization during training!
-  * To use Visdom in the browser:
+- We now support [Visdom](https://github.com/facebookresearch/visdom) for real-time loss visualization during training!
+  * To use Visdom in the browser:
   ```Shell
-  # First install Python server and client
+  # First install Python server and client
   pip install visdom
   # Start the server (probably in a screen or tmux)
   python -m visdom.server
@@ -40,7 +40,7 @@ To make things easy, we provide bash scripts to handle the dataset downloads and
 
 
 ### COCO
-Microsoft COCO: Common Objects in Context
+Microsoft COCO: Common Objects in Context
 
 ##### Download COCO 2014
 ```Shell
@@ -83,7 +83,7 @@ python train.py
 * For training, an NVIDIA GPU is strongly recommended for speed.
 * For instructions on Visdom usage/installation, see the <a href='#installation'>Installation</a> section.
 * You can pick-up training from a checkpoint by specifying the path as one of the training parameters (again, see `train.py` for options)
-
+
 ## Evaluation
 To evaluate a trained network:
 
@@ -107,31 +107,31 @@ You can specify the parameters listed in the `eval.py` file by flagging them or
 | 77.2 % | 77.26 % | 58.12% | 77.43 % |
 
 ##### FPS
-**GTX 1060:** ~45.45 FPS
+**GTX 1060:** ~45.45 FPS
 
 ## Demos
 
 ### Use a pre-trained SSD network for detection
 
 #### Download a pre-trained network
 - We are trying to provide PyTorch `state_dicts` (dict of weight tensors) of the latest SSD model definitions trained on different datasets.
-- Currently, we provide the following PyTorch models:
+- Currently, we provide the following PyTorch models:
     * SSD300 trained on VOC0712 (newest PyTorch weights)
       - https://s3.amazonaws.com/amdegroot-models/ssd300_mAP_77.43_v2.pth
     * SSD300 trained on VOC0712 (original Caffe weights)
      - https://s3.amazonaws.com/amdegroot-models/ssd_300_VOC0712.pth
-- Our goal is to reproduce this table from the [original paper](http://arxiv.org/abs/1512.02325)
+- Our goal is to reproduce this table from the [original paper](http://arxiv.org/abs/1512.02325)
 <p align="left">
 <img src="http://www.cs.unc.edu/~wliu/papers/ssd_results.png" alt="SSD results on multiple datasets" width="800px"></p>
 
 ### Try the demo notebook
 - Make sure you have [jupyter notebook](http://jupyter.readthedocs.io/en/latest/install.html) installed.
 - Two alternatives for installing jupyter notebook:
-    1. If you installed PyTorch with [conda](https://www.continuum.io/downloads) (recommended), then you should already have it. (Just navigate to the ssd.pytorch cloned repo and run):
-    `jupyter notebook`
+    1. If you installed PyTorch with [conda](https://www.continuum.io/downloads) (recommended), then you should already have it. (Just navigate to the ssd.pytorch cloned repo and run):
+    `jupyter notebook`
 
     2. If using [pip](https://pypi.python.org/pypi/pip):
-
+
 ```Shell
 # make sure pip is upgraded
 pip3 install --upgrade pip
@@ -169,5 +169,5 @@ We have accumulated the following to-do list, which we hope to complete in the n
 - Wei Liu, et al. "SSD: Single Shot MultiBox Detector." [ECCV2016]((http://arxiv.org/abs/1512.02325)).
 - [Original Implementation (CAFFE)](https://github.com/weiliu89/caffe/tree/ssd)
 - A huge thank you to [Alex Koltun](https://github.com/alexkoltun) and his team at [Webyclip](webyclip.com) for their help in finishing the data augmentation portion.
-- A list of other great SSD ports that were sources of inspiration (especially the Chainer repo):
-  * [Chainer](https://github.com/Hakuyume/chainer-ssd), [Keras](https://github.com/rykov8/ssd_keras), [MXNet](https://github.com/zhreshold/mxnet-ssd), [Tensorflow](https://github.com/balancap/SSD-Tensorflow)
+- A list of other great SSD ports that were sources of inspiration (especially the Chainer repo):
+  * [Chainer](https://github.com/Hakuyume/chainer-ssd), [Keras](https://github.com/rykov8/ssd_keras), [MXNet](https://github.com/zhreshold/mxnet-ssd), [Tensorflow](https://github.com/balancap/SSD-Tensorflow)
````
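Since the Visdom workflow above is referenced throughout, here is a minimal sketch of the plotting pattern it enables; the window title and dummy loss values are illustrative, not what `train.py` actually logs:

```python
import torch
import visdom

viz = visdom.Visdom()  # assumes `python -m visdom.server` is already running

# Create a line plot once, then append one point per iteration.
win = viz.line(X=torch.zeros(1), Y=torch.zeros(1),
               opts=dict(title='SSD training loss', xlabel='iteration'))
for it in range(1, 101):
    loss = 1.0 / it  # stand-in for the real localization + confidence loss
    viz.line(X=torch.ones(1) * it, Y=torch.tensor([loss]),
             win=win, update='append')
```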

data/__init__.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1,5 +1,6 @@
 from .voc0712 import VOCDetection, VOCAnnotationTransform, VOC_CLASSES, VOC_ROOT
-from .coco import COCODetection, COCOAnnotationTransform, COCO_CLASSES, COCO_ROOT
+
+from .coco import COCODetection, COCOAnnotationTransform, COCO_CLASSES, COCO_ROOT, get_label_map
 from .config import *
 import torch
 import cv2
```

data/coco.py

Lines changed: 12 additions & 12 deletions
```diff
@@ -30,6 +30,15 @@
                 'teddy bear', 'hair drier', 'toothbrush')
 
 
+def get_label_map(label_file):
+    label_map = {}
+    labels = open(label_file, 'r')
+    for line in labels:
+        ids = line.split(',')
+        label_map[int(ids[0])] = int(ids[1])
+    return label_map
+
+
 class COCOAnnotationTransform(object):
     """Transforms a COCO annotation into a Tensor of bbox coords and label index
     Initilized with a dictionary lookup of classnames to indexes
@@ -74,8 +83,8 @@ class COCODetection(data.Dataset):
         in the target (bbox) and transforms it.
     """
 
-    def __init__(self, root, image_set, transform=None,
-                 target_transform=None):
+    def __init__(self, root, image_set='trainval35k', transform=None,
+                 target_transform=COCOAnnotationTransform(), dataset_name='MS COCO'):
         sys.path.append(osp.join(root, COCO_API))
         from pycocotools.coco import COCO
         self.root = osp.join(root, IMAGES, image_set)
@@ -84,7 +93,7 @@ def __init__(self, root, image_set, transform=None,
         self.ids = list(self.coco.imgToAnns.keys())
         self.transform = transform
         self.target_transform = target_transform
-        self.name = 'MS COCO ' + image_set
+        self.name = dataset_name
 
     def __getitem__(self, index):
         """
@@ -169,12 +178,3 @@ def __repr__(self):
         tmp = '    Target Transforms (if any): '
         fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
         return fmt_str
-
-
-def get_label_map(label_file):
-    label_map = {}
-    labels = open(label_file, 'r')
-    for line in labels:
-        ids = line.split(',')
-        label_map[int(ids[0])] = int(ids[1])
-    return label_map
```
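For context, the relocated `get_label_map` (now also re-exported from `data/__init__.py`) reads rows of `original_id,contiguous_label,name` from the new `data/coco_labels.txt` and returns a dict, which `COCOAnnotationTransform` can use to re-index COCO's sparse category ids. A minimal sketch, assuming it is run from the repo root so the relative path resolves:

```python
from data import get_label_map  # re-exported by this commit

# {original COCO category id -> contiguous training label}
label_map = get_label_map('data/coco_labels.txt')

# COCO category ids run 1..90 with gaps (e.g. 12, 26, 29 are unused),
# so the file compacts them onto 80 contiguous labels.
assert label_map[1] == 1      # person
assert label_map[13] == 12    # stop sign (id 12 is unused in COCO)
assert label_map[90] == 80    # toothbrush
```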

data/coco_labels.txt

Lines changed: 80 additions & 0 deletions
```diff
@@ -0,0 +1,80 @@
+1,1,person
+2,2,bicycle
+3,3,car
+4,4,motorcycle
+5,5,airplane
+6,6,bus
+7,7,train
+8,8,truck
+9,9,boat
+10,10,traffic light
+11,11,fire hydrant
+13,12,stop sign
+14,13,parking meter
+15,14,bench
+16,15,bird
+17,16,cat
+18,17,dog
+19,18,horse
+20,19,sheep
+21,20,cow
+22,21,elephant
+23,22,bear
+24,23,zebra
+25,24,giraffe
+27,25,backpack
+28,26,umbrella
+31,27,handbag
+32,28,tie
+33,29,suitcase
+34,30,frisbee
+35,31,skis
+36,32,snowboard
+37,33,sports ball
+38,34,kite
+39,35,baseball bat
+40,36,baseball glove
+41,37,skateboard
+42,38,surfboard
+43,39,tennis racket
+44,40,bottle
+46,41,wine glass
+47,42,cup
+48,43,fork
+49,44,knife
+50,45,spoon
+51,46,bowl
+52,47,banana
+53,48,apple
+54,49,sandwich
+55,50,orange
+56,51,broccoli
+57,52,carrot
+58,53,hot dog
+59,54,pizza
+60,55,donut
+61,56,cake
+62,57,chair
+63,58,couch
+64,59,potted plant
+65,60,bed
+67,61,dining table
+70,62,toilet
+72,63,tv
+73,64,laptop
+74,65,mouse
+75,66,remote
+76,67,keyboard
+77,68,cell phone
+78,69,microwave
+79,70,oven
+80,71,toaster
+81,72,sink
+82,73,refrigerator
+84,74,book
+85,75,clock
+86,76,vase
+87,77,scissors
+88,78,teddy bear
+89,79,hair drier
+90,80,toothbrush
```

data/config.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -12,6 +12,9 @@
 
 # SSD300 CONFIGS
 voc = {
+    'num_classes': 21,
+    'lr_steps': (80000, 100000, 120000),
+    'max_iter': 120000,
     'feature_maps': [38, 19, 10, 5, 3, 1],
     'min_dim': 300,
     'steps': [8, 16, 32, 64, 100, 300],
@@ -24,6 +27,9 @@
 }
 
 coco = {
+    'num_classes': 201,
+    'lr_steps': (280000, 360000, 400000),
+    'max_iter': 400000,
     'feature_maps': [38, 19, 10, 5, 3, 1],
     'min_dim': 300,
     'steps': [8, 16, 32, 64, 100, 300],
```
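The new `num_classes`/`lr_steps`/`max_iter` keys let the training script pick a config per dataset and drive its step-decay learning-rate schedule from it. A minimal sketch of that pattern, not the exact `train.py` code; `base_lr`, `gamma`, and the stand-in model are assumptions:

```python
import torch.nn as nn
import torch.optim as optim
from data.config import voc, coco

cfg = coco                      # train.py would pick voc or coco via --dataset
base_lr, gamma = 1e-3, 0.1      # assumed hyperparameters

model = nn.Linear(4, 4)         # stand-in for the SSD network
optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9)

step_index = 0
for iteration in range(cfg['max_iter']):
    if iteration in cfg['lr_steps']:
        step_index += 1         # decay lr by gamma at each configured step
        for group in optimizer.param_groups:
            group['lr'] = base_lr * (gamma ** step_index)
    # forward / backward / optimizer.step() elided
```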

data/voc0712.py

Lines changed: 9 additions & 10 deletions
```diff
@@ -6,13 +6,10 @@
 Updated by: Ellis Brown, Max deGroot
 """
 from .config import HOME
-import os
-import os.path
+import os.path as osp
 import sys
 import torch
 import torch.utils.data as data
-import torchvision.transforms as transforms
-from PIL import Image, ImageDraw, ImageFont
 import cv2
 import numpy as np
 if sys.version_info[0] == 2:
@@ -28,7 +25,7 @@
     'sheep', 'sofa', 'train', 'tvmonitor')
 
 # note: if you used our download scripts, this should be right
-VOC_ROOT = os.path.join(HOME, "data/VOCdevkit/")
+VOC_ROOT = osp.join(HOME, "data/VOCdevkit/")
 
 
 class VOCAnnotationTransform(object):
@@ -97,19 +94,21 @@ class VOCDetection(data.Dataset):
         (default: 'VOC2007')
     """
 
-    def __init__(self, root, image_sets, transform=None, target_transform=None,
+    def __init__(self, root,
+                 image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
+                 transform=None, target_transform=VOCAnnotationTransform(),
                  dataset_name='VOC0712'):
         self.root = root
         self.image_set = image_sets
         self.transform = transform
         self.target_transform = target_transform
         self.name = dataset_name
-        self._annopath = os.path.join('%s', 'Annotations', '%s.xml')
-        self._imgpath = os.path.join('%s', 'JPEGImages', '%s.jpg')
+        self._annopath = osp.join('%s', 'Annotations', '%s.xml')
+        self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')
         self.ids = list()
         for (year, name) in image_sets:
-            rootpath = os.path.join(self.root, 'VOC' + year)
-            for line in open(os.path.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
+            rootpath = osp.join(self.root, 'VOC' + year)
+            for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
                 self.ids.append((rootpath, line.strip()))
 
     def __getitem__(self, index):
```
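With `image_sets` and `target_transform` now defaulting to the 2007+2012 trainval splits and `VOCAnnotationTransform()`, the common case needs only a root path. A small sketch, assuming the VOC data sits under `VOC_ROOT` as laid out by the repo's download scripts:

```python
from data import VOCDetection, VOC_ROOT

# Defaults now cover the usual training setup: VOC2007 + VOC2012 trainval,
# with XML annotations parsed by VOCAnnotationTransform().
trainset = VOCDetection(VOC_ROOT)

# Explicit arguments still work, e.g. the 2007 test set for evaluation:
testset = VOCDetection(VOC_ROOT, [('2007', 'test')])
image = testset.pull_image(0)   # raw BGR image, as in the demo notebook
```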

demo/demo.ipynb

Lines changed: 14 additions & 10 deletions
```diff
@@ -26,16 +26,12 @@
     "import torch.nn as nn\n",
     "import torch.backends.cudnn as cudnn\n",
     "from torch.autograd import Variable\n",
-    "import torch.utils.data as data\n",
-    "import torchvision.transforms as transforms\n",
-    "from torch.utils.serialization import load_lua\n",
     "import numpy as np\n",
     "import cv2\n",
     "if torch.cuda.is_available():\n",
     "    torch.set_default_tensor_type('torch.cuda.FloatTensor')\n",
     "\n",
-    "from ssd import build_ssd\n",
-    "# from models import build_ssd as build_ssd_v1 # uncomment for older pool6 model"
+    "from ssd import build_ssd"
    ]
   },
   {
@@ -52,6 +48,7 @@
    "cell_type": "code",
    "execution_count": 2,
    "metadata": {
+    "collapsed": false,
     "scrolled": false
    },
    "outputs": [
@@ -80,7 +77,9 @@
   {
    "cell_type": "code",
    "execution_count": 3,
-   "metadata": {},
+   "metadata": {
+    "collapsed": false
+   },
    "outputs": [
     {
      "data": {
@@ -97,9 +96,9 @@
     "# image = cv2.imread('./data/example.jpg', cv2.IMREAD_COLOR)  # uncomment if dataset not downloaded\n",
     "%matplotlib inline\n",
     "from matplotlib import pyplot as plt\n",
-    "from data import VOCDetection, VOCroot, AnnotationTransform\n",
+    "from data import VOCDetection, VOC_ROOT, VOCAnnotationTransform\n",
     "# here we specify year (07 or 12) and dataset ('test', 'val', 'train') \n",
-    "testset = VOCDetection(VOCroot, [('2007', 'val')], None, AnnotationTransform())\n",
+    "testset = VOCDetection(VOC_ROOT, [('2007', 'val')], None, VOCAnnotationTransform())\n",
     "img_id = 60\n",
     "image = testset.pull_image(img_id)\n",
     "rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
@@ -123,7 +122,9 @@
   {
    "cell_type": "code",
    "execution_count": 4,
-   "metadata": {},
+   "metadata": {
+    "collapsed": false
+   },
    "outputs": [
     {
      "data": {
@@ -157,6 +158,7 @@
    "cell_type": "code",
    "execution_count": 5,
    "metadata": {
+    "collapsed": true,
     "scrolled": true
    },
    "outputs": [],
@@ -179,7 +181,9 @@
   {
    "cell_type": "code",
    "execution_count": 6,
-   "metadata": {},
+   "metadata": {
+    "collapsed": false
+   },
    "outputs": [
     {
      "data": {
```
