Skip to content

Commit 27d4e4b

Browse files
committed
Fix sum() bug in multibox_loss that was due to old pytorch compatibility
1 parent 3009727 commit 27d4e4b

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

layers/modules/multibox_loss.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from data import v2 as cfg
77
from ..box_utils import match, log_sum_exp
88

9+
910
class MultiBoxLoss(nn.Module):
1011
"""SSD Weighted Loss Function
1112
Compute Targets:
@@ -79,7 +80,7 @@ def forward(self, predictions, targets):
7980
conf_t = Variable(conf_t, requires_grad=False)
8081

8182
pos = conf_t > 0
82-
num_pos = pos.sum(keepdim=True)
83+
num_pos = pos.sum(dim=1, keepdim=True)
8384

8485
# Localization Loss (Smooth L1)
8586
# Shape: [batch,num_priors,4]

ssd.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ def load_weights(self, base_file):
112112
other, ext = os.path.splitext(base_file)
113113
if ext == '.pkl' or '.pth':
114114
print('Loading weights into state dict...')
115-
self.load_state_dict(torch.load(base_file, map_location=lambda storage, loc: storage))
115+
self.load_state_dict(torch.load(base_file,
116+
map_location=lambda storage, loc: storage))
116117
print('Finished!')
117118
else:
118119
print('Sorry only .pth and .pkl files supported.')
@@ -199,7 +200,7 @@ def build_ssd(phase, size=300, num_classes=21):
199200
if size != 300:
200201
print("Error: Sorry only SSD300 is supported currently!")
201202
return
202-
base_,extras_,head_=multibox(vgg(base[str(size)], 3),
203-
add_extras(extras[str(size)], 1024),
204-
mbox[str(size)], num_classes)
205-
return SSD(phase,base_,extras_,head_, num_classes)
203+
base_, extras_, head_ = multibox(vgg(base[str(size)], 3),
204+
add_extras(extras[str(size)], 1024),
205+
mbox[str(size)], num_classes)
206+
return SSD(phase, base_, extras_, head_, num_classes)

0 commit comments

Comments
 (0)