1313import torchvision .datasets as datasets
1414import torchvision .models as models
1515
16- import vqa .lib .utils as utils
1716import vqa .datasets .coco as coco
1817from vqa .lib .dataloader import DataLoader
18+ from vqa .models .utils import ResNet
19+ from vqa .lib .logger import AvgMeter
1920
2021model_names = sorted (name for name in models .__dict__
2122 if name .islower () and name .startswith ("resnet" )
3334 ' (default: resnet152)' )
3435parser .add_argument ('--workers' , default = 4 , type = int , metavar = 'N' ,
3536 help = 'number of data loading workers (default: 8)' )
36- parser .add_argument ('--batch_size' , default = 10 , type = int , metavar = 'N' ,
37- help = 'mini-batch size (default: 10 )' )
37+ parser .add_argument ('--batch_size' , '-b' , default = 80 , type = int , metavar = 'N' ,
38+ help = 'mini-batch size (default: 80 )' )
3839parser .add_argument ('--mode' , default = 'both' , type = str ,
3940 help = 'Options: att | noatt | (default) both' )
4041
@@ -44,7 +45,7 @@ def main():
4445
4546 print ("=> using pre-trained model '{}'" .format (args .arch ))
4647 model = models .__dict__ [args .arch ](pretrained = True )
47- model = ResNet (model )
48+ model = ResNet (model , False )
4849 model = nn .DataParallel (model ).cuda ()
4950
5051 #extract_name = 'arch,{}_layer,{}_resize,{}'.format()
@@ -54,7 +55,7 @@ def main():
5455 normalize = transforms .Normalize (mean = [0.485 , 0.456 , 0.406 ],
5556 std = [0.229 , 0.224 , 0.225 ])
5657
57- dataset = coco .COCOImages (args .data_split , args .dir_data ,
58+ dataset = coco .COCOImages (args .data_split , dict ( dir = args .dir_data ) ,
5859 transform = transforms .Compose ([
5960 transforms .Scale (448 ),
6061 transforms .CenterCrop (448 ),
@@ -90,15 +91,13 @@ def extract(data_loader, model, path_file, mode):
9091
9192 model .eval ()
9293
93- batch_time = utils . AverageMeter ()
94- data_time = utils . AverageMeter ()
94+ batch_time = AvgMeter ()
95+ data_time = AvgMeter ()
9596 begin = time .time ()
9697 end = time .time ()
9798
9899 idx = 0
99- image_names = []
100100 for i , input in enumerate (data_loader ):
101- image_names .append (input ['name' ])
102101 input_var = torch .autograd .Variable (input ['visual' ], volatile = True )
103102 output_att = model (input_var )
104103
@@ -124,34 +123,15 @@ def extract(data_loader, model, path_file, mode):
124123 data_time = data_time ,))
125124
126125 hdf5_file .close ()
126+
127+ # Saving image names in the same order than extraction
127128 with open (path_txt , 'w' ) as handle :
128- for name in image_names :
129+ for name in data_loader . dataset . dataset . imgs :
129130 handle .write (name + '\n ' )
130131
131132 end = time .time () - begin
132133 print ('Finished in {}m and {}s' .format (int (end / 60 ), int (end % 60 )))
133134
134135
135-
136- class ResNet (nn .Module ):
137-
138- def __init__ (self , resnet ):
139- super (ResNet , self ).__init__ ()
140- self .resnet = resnet
141-
142- def forward (self , x ):
143- x = self .resnet .conv1 (x )
144- x = self .resnet .bn1 (x )
145- x = self .resnet .relu (x )
146- x = self .resnet .maxpool (x )
147- x = self .resnet .layer1 (x )
148- x = self .resnet .layer2 (x )
149- x = self .resnet .layer3 (x )
150- x = self .resnet .layer4 (x )
151- # x = self.avgpool(x)
152- # x = x.view(x.size(0), -1)
153- # x = self.fc(x)
154- return x
155-
156136if __name__ == '__main__' :
157137 main ()
0 commit comments