Skip to content

Commit 4701d83

Browse files
tiltamdegroot
authored andcommitted
Fix batch-wise inference (amdegroot#98)
* Fix batch-wise inference
1 parent 27d4e4b commit 4701d83

File tree

2 files changed

+12
-20
lines changed

2 files changed

+12
-20
lines changed

layers/functions/detection.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
import torch
2-
import torch.nn as nn
3-
import torch.backends.cudnn as cudnn
42
from torch.autograd import Function
5-
from torch.autograd import Variable
63
from ..box_utils import decode, nms
74
from data import v2 as cfg
85

@@ -23,7 +20,6 @@ def __init__(self, num_classes, bkg_label, top_k, conf_thresh, nms_thresh):
2320
raise ValueError('nms_threshold must be non negative.')
2421
self.conf_thresh = conf_thresh
2522
self.variance = cfg['variance']
26-
self.output = torch.zeros(1, self.num_classes, self.top_k, 5)
2723

2824
def forward(self, loc_data, conf_data, prior_data):
2925
"""
@@ -37,21 +33,16 @@ def forward(self, loc_data, conf_data, prior_data):
3733
"""
3834
num = loc_data.size(0) # batch size
3935
num_priors = prior_data.size(0)
40-
self.output.zero_()
41-
if num == 1:
42-
# size batch x num_classes x num_priors
43-
conf_preds = conf_data.t().contiguous().unsqueeze(0)
44-
else:
45-
conf_preds = conf_data.view(num, num_priors,
46-
self.num_classes).transpose(2, 1)
47-
self.output.expand_(num, self.num_classes, self.top_k, 5)
36+
output = torch.zeros(num, self.num_classes, self.top_k, 5)
37+
conf_preds = conf_data.view(num, num_priors,
38+
self.num_classes).transpose(2, 1)
4839

4940
# Decode predictions into bboxes.
5041
for i in range(num):
5142
decoded_boxes = decode(loc_data[i], prior_data, self.variance)
5243
# For each class, perform nms
5344
conf_scores = conf_preds[i].clone()
54-
num_det = 0
45+
5546
for cl in range(1, self.num_classes):
5647
c_mask = conf_scores[cl].gt(self.conf_thresh)
5748
scores = conf_scores[cl][c_mask]
@@ -61,11 +52,11 @@ def forward(self, loc_data, conf_data, prior_data):
6152
boxes = decoded_boxes[l_mask].view(-1, 4)
6253
# idx of highest scoring and non-overlapping boxes per class
6354
ids, count = nms(boxes, scores, self.nms_thresh, self.top_k)
64-
self.output[i, cl, :count] = \
55+
output[i, cl, :count] = \
6556
torch.cat((scores[ids[:count]].unsqueeze(1),
6657
boxes[ids[:count]]), 1)
67-
flt = self.output.view(-1, 5)
68-
_, idx = flt[:, 0].sort(0)
69-
_, rank = idx.sort(0)
70-
flt[(rank >= self.top_k).unsqueeze(1).expand_as(flt)].fill_(0)
71-
return self.output
58+
flt = output.contiguous().view(num, -1, 5)
59+
_, idx = flt[:, :, 0].sort(1, descending=True)
60+
_, rank = idx.sort(1)
61+
flt[(rank < self.top_k).unsqueeze(-1).expand_as(flt)].fill_(0)
62+
return output

ssd.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ def forward(self, x):
9797
if self.phase == "test":
9898
output = self.detect(
9999
loc.view(loc.size(0), -1, 4), # loc preds
100-
self.softmax(conf.view(-1, self.num_classes)), # conf preds
100+
self.softmax(conf.view(-1, self.num_classes)) \
101+
.view(conf.size(0), -1, self.num_classes), # conf preds
101102
self.priors.type(type(x.data)) # default boxes
102103
)
103104
else:

0 commit comments

Comments
 (0)