|
7 | 7 | import torch.backends.cudnn as cudnn |
8 | 8 | import torchvision.transforms as transforms |
9 | 9 | from torch.autograd import Variable |
10 | | -from data import VOCroot, VOC_CLASSES as labelmap |
| 10 | +from data import VOC_ROOT, VOC_CLASSES as labelmap |
11 | 11 | from PIL import Image |
12 | | -from data import AnnotationTransform, VOCDetection, BaseTransform, VOC_CLASSES |
| 12 | +from data import VOCAnnotationTransform, VOCDetection, BaseTransform, VOC_CLASSES |
13 | 13 | import torch.utils.data as data |
14 | 14 | from ssd import build_ssd |
15 | 15 |
|
|
20 | 20 | help='Dir to save results') |
21 | 21 | parser.add_argument('--visual_threshold', default=0.6, type=float, |
22 | 22 | help='Final confidence threshold') |
23 | | -parser.add_argument('--cuda', default=False, type=bool, |
| 23 | +parser.add_argument('--cuda', default=True, type=bool, |
24 | 24 | help='Use cuda to train model') |
25 | | -parser.add_argument('--voc_root', default=VOCroot, help='Location of VOC root directory') |
26 | | - |
| 25 | +parser.add_argument('--voc_root', default=VOC_ROOT, help='Location of VOC root directory') |
27 | 26 | args = parser.parse_args() |
28 | 27 |
|
| 28 | +if args.cuda and torch.cuda.is_available(): |
| 29 | + torch.set_default_tensor_type('torch.cuda.FloatTensor') |
| 30 | +else: |
| 31 | + torch.set_default_tensor_type('torch.FloatTensor') |
| 32 | + |
29 | 33 | if not os.path.exists(args.save_folder): |
30 | 34 | os.mkdir(args.save_folder) |
31 | 35 |
|
@@ -71,19 +75,22 @@ def test_net(save_folder, net, cuda, testset, transform, thresh): |
71 | 75 | j += 1 |
72 | 76 |
|
73 | 77 |
|
74 | | -if __name__ == '__main__': |
| 78 | +def test_voc(): |
75 | 79 | # load net |
76 | 80 | num_classes = len(VOC_CLASSES) + 1 # +1 background |
77 | 81 | net = build_ssd('test', 300, num_classes) # initialize SSD |
78 | 82 | net.load_state_dict(torch.load(args.trained_model)) |
79 | 83 | net.eval() |
80 | 84 | print('Finished loading model!') |
81 | 85 | # load data |
82 | | - testset = VOCDetection(args.voc_root, [('2007', 'test')], None, AnnotationTransform()) |
| 86 | + testset = VOCDetection(args.voc_root, [('2007', 'test')], None, VOCAnnotationTransform()) |
83 | 87 | if args.cuda: |
84 | 88 | net = net.cuda() |
85 | 89 | cudnn.benchmark = True |
86 | 90 | # evaluation |
87 | 91 | test_net(args.save_folder, net, args.cuda, testset, |
88 | 92 | BaseTransform(net.size, (104, 117, 123)), |
89 | 93 | thresh=args.visual_threshold) |
| 94 | + |
| 95 | +if __name__ == '__main__': |
| 96 | + test_voc() |
0 commit comments