Skip to content

Commit e1bb0d0

Browse files
committed
Operation Coco is underway
1 parent 58f5af5 commit e1bb0d0

File tree

5 files changed

+90
-79
lines changed

5 files changed

+90
-79
lines changed

data/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import cv2
66
import numpy as np
77

8-
98
def detection_collate(batch):
109
"""Custom collate fn for dealing with batches of images that have a different
1110
number of associated object annotations (bounding boxes).

data/coco.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from .config import HOME
22
import os
3-
import os.path
3+
import os.path as osp
44
import sys
55
import torch
66
import torch.utils.data as data
77
import torchvision.transforms as transforms
88
import cv2
99
import numpy as np
1010

11-
COCO_ROOT = os.path.join(HOME, 'data/coco/')
11+
COCO_ROOT = osp.join(HOME, 'data/coco/')
1212
IMAGES = 'images'
1313
ANNOTATIONS = 'annotations'
1414
COCO_API = 'PythonAPI'
@@ -34,6 +34,8 @@ class COCOAnnotationTransform(object):
3434
"""Transforms a COCO annotation into a Tensor of bbox coords and label index
3535
Initialized with a dictionary lookup of classnames to indexes
3636
"""
37+
def __init__(self):
38+
self.label_map = get_label_map(osp.join(COCO_ROOT, 'coco_labels.txt'))
3739

3840
def __call__(self, target, width, height):
3941
"""
@@ -51,10 +53,13 @@ def __call__(self, target, width, height):
5153
bbox = obj['bbox']
5254
bbox[2] += bbox[0]
5355
bbox[3] += bbox[1]
54-
label_idx = obj['category_id']
56+
label_idx = self.label_map[obj['category_id']] - 1
5557
final_box = list(np.array(bbox)/scale)
5658
final_box.append(label_idx)
5759
res += [final_box] # [xmin, ymin, xmax, ymax, label_idx]
60+
else:
61+
print("no bbox problem!")
62+
5863
return res # [[xmin, ymin, xmax, ymax, label_idx], ... ]
5964

6065

@@ -70,16 +75,16 @@ class COCODetection(data.Dataset):
7075
"""
7176

7277
def __init__(self, root, image_set, transform=None,
73-
target_transform=None, dataset_name='COCO2014'):
74-
sys.path.append(os.path.join(root, COCO_API))
78+
target_transform=None):
79+
sys.path.append(osp.join(root, COCO_API))
7580
from pycocotools.coco import COCO
76-
self.root = os.path.join(root, IMAGES, image_set)
77-
self.coco = COCO(os.path.join(root, ANNOTATIONS,
78-
INSTANCES_SET.format(image_set)))
79-
self.ids = list(self.coco.imgs.keys())
81+
self.root = osp.join(root, IMAGES, image_set)
82+
self.coco = COCO(osp.join(root, ANNOTATIONS,
83+
INSTANCES_SET.format(image_set)))
84+
self.ids = list(self.coco.imgToAnns.keys())
8085
self.transform = transform
8186
self.target_transform = target_transform
82-
self.name = dataset_name
87+
self.name = 'MS COCO ' + image_set
8388

8489
def __getitem__(self, index):
8590
"""
@@ -104,11 +109,14 @@ def pull_item(self, index):
104109
target is the object returned by ``coco.loadAnns``.
105110
"""
106111
img_id = self.ids[index]
112+
target = self.coco.imgToAnns[img_id]
107113
ann_ids = self.coco.getAnnIds(imgIds=img_id)
114+
108115
target = self.coco.loadAnns(ann_ids)
109-
path = self.coco.loadImgs(img_id)[0]['file_name']
110-
img = cv2.imread(os.path.join(self.root, path))
111-
height, width, channels = img.shape
116+
path = osp.join(self.root, self.coco.loadImgs(img_id)[0]['file_name'])
117+
assert osp.exists(path), 'Image path does not exist: {}'.format(path)
118+
img = cv2.imread(osp.join(self.root, path))
119+
height, width, _ = img.shape
112120
if self.target_transform is not None:
113121
target = self.target_transform(target, width, height)
114122
if self.transform is not None:
@@ -117,7 +125,7 @@ def pull_item(self, index):
117125
target[:, 4])
118126
# to rgb
119127
img = img[:, :, (2, 1, 0)]
120-
# img = img.transpose(2, 0, 1)
128+
121129
target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
122130
return torch.from_numpy(img).permute(2, 0, 1), target, height, width
123131

@@ -134,7 +142,7 @@ def pull_image(self, index):
134142
'''
135143
img_id = self.ids[index]
136144
path = self.coco.loadImgs(img_id)[0]['file_name']
137-
return cv2.imread(os.path.join(self.root, path), cv2.IMREAD_COLOR)
145+
return cv2.imread(osp.join(self.root, path), cv2.IMREAD_COLOR)
138146

139147
def pull_anno(self, index):
140148
'''Returns the original annotation of image at index
@@ -161,3 +169,12 @@ def __repr__(self):
161169
tmp = ' Target Transforms (if any): '
162170
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
163171
return fmt_str
172+
173+
174+
def get_label_map(label_file):
    """Build an integer lookup table from a comma-separated label file.

    Every non-blank line is split on ``','``; the first field becomes the
    dictionary key and the second field becomes the value, both parsed with
    ``int``. Any further fields on the line are ignored.

    Args:
        label_file: Path of the text file to read.

    Returns:
        dict: ``{int(first_field): int(second_field)}`` for each parsed line.

    Raises:
        ValueError: If one of the first two fields is not an integer.
        IndexError: If a non-blank line has fewer than two fields.
    """
    label_map = {}
    # The context manager closes the handle even when a line fails to parse;
    # the previous bare ``open`` leaked it.
    with open(label_file, 'r') as labels:
        for line in labels:
            # Skip blank lines (e.g. a trailing newline at end of file)
            # instead of crashing in ``int('')``.
            if not line.strip():
                continue
            ids = line.split(',')
            label_map[int(ids[0])] = int(ids[1])
    return label_map

layers/box_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# -*- coding: utf-8 -*-
22
import torch
33

4+
45
def point_form(boxes):
56
""" Convert prior_boxes to (xmin, ymin, xmax, ymax)
67
representation for comparison to point form ground truth data.

layers/modules/multibox_loss.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def forward(self, predictions, targets):
5454
loc shape: torch.size(batch_size,num_priors,4)
5555
priors shape: torch.size(num_priors,4)
5656
57-
ground_truth (tensor): Ground truth boxes and labels for a batch,
57+
targets (tensor): Ground truth boxes and labels for a batch,
5858
shape: [batch_size,num_objs,5] (last idx is the label).
5959
"""
6060
loc_data, conf_data, priors = predictions
@@ -91,12 +91,11 @@ def forward(self, predictions, targets):
9191

9292
# Compute max conf across batch for hard negative mining
9393
batch_conf = conf_data.view(-1, self.num_classes)
94-
9594
loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
9695

9796
# Hard Negative Mining
98-
loss_c = loss_c.view(num, -1)
9997
loss_c[pos] = 0 # filter out pos boxes for now
98+
loss_c = loss_c.view(num, -1)
10099
_, loss_idx = loss_c.sort(1, descending=True)
101100
_, idx_rank = loss_idx.sort(1)
102101
num_pos = pos.long().sum(1, keepdim=True)

train.py

Lines changed: 55 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ def str2bool(v):
3434
parser.add_argument('--weight_decay', default=5e-4, type=float, help='Weight decay for SGD')
3535
parser.add_argument('--gamma', default=0.1, type=float, help='Gamma update for SGD')
3636
parser.add_argument('--log_iters', default=True, type=bool, help='Print the loss at each iteration')
37-
parser.add_argument('--visdom', default=True, type=str2bool, help='Use visdom for loss visualization')
38-
parser.add_argument('--send_images_to_visdom', type=str2bool, default=True, help='Sample a random image from every 10th batch, send it to visdom after augmentations step')
37+
parser.add_argument('--visdom', default=False, type=str2bool, help='Use visdom for loss visualization')
38+
parser.add_argument('--send_images_to_visdom', type=str2bool, default=False, help='Sample a random image from every 10th batch, send it to visdom after augmentations step')
3939
parser.add_argument('--save_folder', default='weights/', help='Directory for saving checkpoint models')
4040
parser.add_argument('--dataset_root', default=COCO_ROOT, help='Dataset root directory path')
4141
parser.add_argument('-f', default=None, type=str, help="Dummy arg so we can load in Jupyter Notebooks")
@@ -104,63 +104,41 @@ def weights_init(m):
104104
def train():
105105
net.train()
106106
# loss counters
107-
loc_loss = 0 # epoch
107+
loc_loss = 0
108108
conf_loss = 0
109109
epoch = 0
110110
print('Loading Dataset...')
111111
dataset = COCODetection(args.dataset_root, args.image_set, SSDAugmentation(
112112
SSD_DIM, MEANS), COCOAnnotationTransform())
113113

114114
epoch_size = len(dataset) // args.batch_size
115-
print('Training SSD on ', dataset.name)
115+
print('Training SSD on', dataset.name)
116116
step_index = 0
117+
117118
if args.visdom:
118-
# initialize visdom loss plot
119-
lot = viz.line(
120-
X=torch.zeros((1,)).cpu(),
121-
Y=torch.zeros((1, 3)).cpu(),
122-
opts=dict(
123-
xlabel='Iteration',
124-
ylabel='Loss',
125-
title='Current SSD Training Loss',
126-
legend=['Loc Loss', 'Conf Loss', 'Loss']
127-
)
128-
)
129-
epoch_lot = viz.line(
130-
X=torch.zeros((1,)).cpu(),
131-
Y=torch.zeros((1, 3)).cpu(),
132-
opts=dict(
133-
xlabel='Epoch',
134-
ylabel='Loss',
135-
title='Epoch SSD Training Loss',
136-
legend=['Loc Loss', 'Conf Loss', 'Loss']
137-
)
138-
)
139-
batch_iterator = None
119+
vis_title = 'SSD.PyTorch on ' + args.image_set
120+
vis_legend = ['Loc Loss', 'Conf Loss', 'Total Loss']
121+
iter_plot = create_vis_plot('Iteration', 'Loss', vis_title, vis_legend)
122+
epoch_plot = create_vis_plot('Epoch', 'Loss', vis_title, vis_legend)
140123
data_loader = data.DataLoader(dataset, args.batch_size,
141124
num_workers=args.num_workers,
142125
shuffle=True, collate_fn=detection_collate,
143126
pin_memory=True)
127+
# create batch iterator
128+
batch_iterator = iter(data_loader)
144129
for iteration in range(args.start_iter, args.max_iter):
145-
if (not batch_iterator) or (iteration % epoch_size == 0):
146-
# create batch iterator
147-
batch_iterator = iter(data_loader)
148-
if iteration in STEP_VALUES:
149-
step_index += 1
150-
adjust_learning_rate(optimizer, args.gamma, step_index)
151-
if args.visdom:
152-
viz.line(
153-
X=torch.ones((1, 3)).cpu() * epoch,
154-
Y=torch.Tensor([loc_loss, conf_loss,
155-
loc_loss + conf_loss]).unsqueeze(0).cpu() / epoch_size,
156-
win=epoch_lot,
157-
update='append'
158-
)
130+
if iteration != 0 and (iteration % epoch_size == 0) and args.visdom:
131+
update_vis_plot(epoch, loc_loss, conf_loss, epoch_plot, None,
132+
'append', epoch_size)
159133
# reset epoch loss counters
160134
loc_loss = 0
161135
conf_loss = 0
162136
epoch += 1
163137

138+
if iteration in STEP_VALUES:
139+
step_index += 1
140+
adjust_learning_rate(optimizer, args.gamma, step_index)
141+
164142
# load train data
165143
images, targets = next(batch_iterator)
166144

@@ -182,29 +160,15 @@ def train():
182160
t1 = time.time()
183161
loc_loss += loss_l.data[0]
184162
conf_loss += loss_c.data[0]
163+
185164
if iteration % 10 == 0:
186-
print('Timer: %.4f sec.' % (t1 - t0))
165+
print('timer: %.4f sec.' % (t1 - t0))
187166
print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.data[0]), end=' ')
188-
if args.visdom and args.send_images_to_visdom:
189-
random_batch_index = np.random.randint(images.size(0))
190-
viz.image(images.data[random_batch_index].cpu().numpy())
167+
191168
if args.visdom:
192-
viz.line(
193-
X=torch.ones((1, 3)).cpu() * iteration,
194-
Y=torch.Tensor([loss_l.data[0], loss_c.data[0],
195-
loss_l.data[0] + loss_c.data[0]]).unsqueeze(0).cpu(),
196-
win=lot,
197-
update='append'
198-
)
199-
# hacky fencepost solution for 0th epoch plot
200-
if iteration == 0:
201-
viz.line(
202-
X=torch.zeros((1, 3)).cpu(),
203-
Y=torch.Tensor([loc_loss, conf_loss,
204-
loc_loss + conf_loss]).unsqueeze(0).cpu(),
205-
win=epoch_lot,
206-
update=True
207-
)
169+
update_vis_plot(iteration, loss_l.data[0], loss_c.data[0],
170+
iter_plot, epoch_plot, 'append')
171+
208172
if iteration % 5000 == 0:
209173
print('Saving state, iter:', iteration)
210174
torch.save(ssd_net.state_dict(), 'weights/ssd300_COCO_' +
@@ -224,5 +188,36 @@ def adjust_learning_rate(optimizer, gamma, step):
224188
param_group['lr'] = lr
225189

226190

191+
def create_vis_plot(_xlabel, _ylabel, _title, _legend):
    """Create a visdom line plot seeded with a single all-zero sample.

    Args:
        _xlabel: Text used for the ``xlabel`` plot option.
        _ylabel: Text used for the ``ylabel`` plot option.
        _title: Text used for the ``title`` plot option.
        _legend: Value used for the ``legend`` plot option.

    Returns:
        Whatever ``viz.line`` returns for the newly created plot.
    """
    # One x position and a 1x3 row of zeros as the initial sample.
    start_x = torch.zeros((1,)).cpu()
    start_y = torch.zeros((1, 3)).cpu()
    plot_opts = dict(
        xlabel=_xlabel,
        ylabel=_ylabel,
        title=_title,
        legend=_legend
    )
    return viz.line(X=start_x, Y=start_y, opts=plot_opts)
202+
203+
204+
def update_vis_plot(iteration, loc, conf, window1, window2, update_type,
                    epoch_size=1):
    """Send one (loc, conf, loc + conf) loss sample to a visdom line plot.

    Args:
        iteration: X coordinate of the sample (an iteration or epoch index).
        loc: Localization loss value.
        conf: Confidence loss value.
        window1: Plot window that receives the sample.
        window2: Second plot window that is seeded when ``iteration == 0``,
            or ``None`` when there is no second plot to seed.
        update_type: Passed through as the ``update`` argument of
            ``viz.line`` for ``window1``.
        epoch_size: Divisor applied to the three values sent to ``window1``;
            the default of 1 leaves them unchanged.
    """
    viz.line(
        X=torch.ones((1, 3)).cpu() * iteration,
        Y=torch.Tensor([loc, conf, loc + conf]).unsqueeze(0).cpu() / epoch_size,
        win=window1,
        update=update_type
    )
    # Initialize the second plot on the first iteration. Callers that update
    # only one plot pass window2=None, and their x value can also be 0 (the
    # first epoch index), so the None check keeps this branch from calling
    # viz.line without a target window.
    if iteration == 0 and window2 is not None:
        viz.line(
            X=torch.zeros((1, 3)).cpu(),
            Y=torch.Tensor([loc, conf, loc + conf]).unsqueeze(0).cpu(),
            win=window2,
            update=True
        )
220+
221+
227222
if __name__ == '__main__':
228223
train()

0 commit comments

Comments
 (0)