11import torch
2- import torch .nn as nn
3- import torch .backends .cudnn as cudnn
42from torch .autograd import Function
5- from torch .autograd import Variable
63from ..box_utils import decode , nms
74from data import v2 as cfg
85
@@ -23,7 +20,6 @@ def __init__(self, num_classes, bkg_label, top_k, conf_thresh, nms_thresh):
2320 raise ValueError ('nms_threshold must be non negative.' )
2421 self .conf_thresh = conf_thresh
2522 self .variance = cfg ['variance' ]
26- self .output = torch .zeros (1 , self .num_classes , self .top_k , 5 )
2723
2824 def forward (self , loc_data , conf_data , prior_data ):
2925 """
@@ -37,21 +33,16 @@ def forward(self, loc_data, conf_data, prior_data):
3733 """
3834 num = loc_data .size (0 ) # batch size
3935 num_priors = prior_data .size (0 )
40- self .output .zero_ ()
41- if num == 1 :
42- # size batch x num_classes x num_priors
43- conf_preds = conf_data .t ().contiguous ().unsqueeze (0 )
44- else :
45- conf_preds = conf_data .view (num , num_priors ,
46- self .num_classes ).transpose (2 , 1 )
47- self .output .expand_ (num , self .num_classes , self .top_k , 5 )
36+ output = torch .zeros (num , self .num_classes , self .top_k , 5 )
37+ conf_preds = conf_data .view (num , num_priors ,
38+ self .num_classes ).transpose (2 , 1 )
4839
4940 # Decode predictions into bboxes.
5041 for i in range (num ):
5142 decoded_boxes = decode (loc_data [i ], prior_data , self .variance )
5243 # For each class, perform nms
5344 conf_scores = conf_preds [i ].clone ()
54- num_det = 0
45+
5546 for cl in range (1 , self .num_classes ):
5647 c_mask = conf_scores [cl ].gt (self .conf_thresh )
5748 scores = conf_scores [cl ][c_mask ]
@@ -61,11 +52,11 @@ def forward(self, loc_data, conf_data, prior_data):
6152 boxes = decoded_boxes [l_mask ].view (- 1 , 4 )
6253 # idx of highest scoring and non-overlapping boxes per class
6354 ids , count = nms (boxes , scores , self .nms_thresh , self .top_k )
64- self . output [i , cl , :count ] = \
55+ output [i , cl , :count ] = \
6556 torch .cat ((scores [ids [:count ]].unsqueeze (1 ),
6657 boxes [ids [:count ]]), 1 )
67- flt = self . output .view (- 1 , 5 )
68- _ , idx = flt [:, 0 ].sort (0 )
69- _ , rank = idx .sort (0 )
70- flt [(rank >= self .top_k ).unsqueeze (1 ).expand_as (flt )].fill_ (0 )
71- return self . output
58+ flt = output .contiguous (). view (num , - 1 , 5 )
59+ _ , idx = flt [:, :, 0 ].sort (1 , descending = True )
60+ _ , rank = idx .sort (1 )
61+ flt [(rank < self .top_k ).unsqueeze (- 1 ).expand_as (flt )].fill_ (0 )
62+ return output
0 commit comments