diff --git a/layers/modules/multibox_loss.py b/layers/modules/multibox_loss.py index fb49cf439..40436de12 100644 --- a/layers/modules/multibox_loss.py +++ b/layers/modules/multibox_loss.py @@ -32,8 +32,10 @@ class MultiBoxLoss(nn.Module): def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target, - use_gpu=True): + use_gpu=None): super(MultiBoxLoss, self).__init__() + if use_gpu is None: + use_gpu = torch.cuda.is_available() self.use_gpu = use_gpu self.num_classes = num_classes self.threshold = overlap_thresh