Skip to content

Commit 49a326f

Browse files
committed
fix VOC references
1 parent 66faf9c commit 49a326f

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

test.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
import torch.backends.cudnn as cudnn
88
import torchvision.transforms as transforms
99
from torch.autograd import Variable
10-
from data import VOCroot, VOC_CLASSES as labelmap
10+
from data import VOC_ROOT, VOC_CLASSES as labelmap
1111
from PIL import Image
12-
from data import AnnotationTransform, VOCDetection, BaseTransform, VOC_CLASSES
12+
from data import VOCAnnotationTransform, VOCDetection, BaseTransform, VOC_CLASSES
1313
import torch.utils.data as data
1414
from ssd import build_ssd
1515

@@ -20,12 +20,16 @@
2020
help='Dir to save results')
2121
parser.add_argument('--visual_threshold', default=0.6, type=float,
2222
help='Final confidence threshold')
23-
parser.add_argument('--cuda', default=False, type=bool,
23+
parser.add_argument('--cuda', default=True, type=bool,
2424
help='Use cuda to train model')
25-
parser.add_argument('--voc_root', default=VOCroot, help='Location of VOC root directory')
26-
25+
parser.add_argument('--voc_root', default=VOC_ROOT, help='Location of VOC root directory')
2726
args = parser.parse_args()
2827

28+
if args.cuda and torch.cuda.is_available():
29+
torch.set_default_tensor_type('torch.cuda.FloatTensor')
30+
else:
31+
torch.set_default_tensor_type('torch.FloatTensor')
32+
2933
if not os.path.exists(args.save_folder):
3034
os.mkdir(args.save_folder)
3135

@@ -71,19 +75,22 @@ def test_net(save_folder, net, cuda, testset, transform, thresh):
7175
j += 1
7276

7377

74-
if __name__ == '__main__':
78+
def test_voc():
7579
# load net
7680
num_classes = len(VOC_CLASSES) + 1 # +1 background
7781
net = build_ssd('test', 300, num_classes) # initialize SSD
7882
net.load_state_dict(torch.load(args.trained_model))
7983
net.eval()
8084
print('Finished loading model!')
8185
# load data
82-
testset = VOCDetection(args.voc_root, [('2007', 'test')], None, AnnotationTransform())
86+
testset = VOCDetection(args.voc_root, [('2007', 'test')], None, VOCAnnotationTransform())
8387
if args.cuda:
8488
net = net.cuda()
8589
cudnn.benchmark = True
8690
# evaluation
8791
test_net(args.save_folder, net, args.cuda, testset,
8892
BaseTransform(net.size, (104, 117, 123)),
8993
thresh=args.visual_threshold)
94+
95+
if __name__ == '__main__':
96+
test_voc()

0 commit comments

Comments
 (0)