Skip to content

Commit 0b9f5f4

Browse files
committed
Fix extract.py
1 parent e41082c commit 0b9f5f4

File tree

2 files changed

+59
-31
lines changed

2 files changed

+59
-31
lines changed

extract.py

Lines changed: 11 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
import torchvision.datasets as datasets
1414
import torchvision.models as models
1515

16-
import vqa.lib.utils as utils
1716
import vqa.datasets.coco as coco
1817
from vqa.lib.dataloader import DataLoader
18+
from vqa.models.utils import ResNet
19+
from vqa.lib.logger import AvgMeter
1920

2021
model_names = sorted(name for name in models.__dict__
2122
if name.islower() and name.startswith("resnet")
@@ -33,8 +34,8 @@
3334
' (default: resnet152)')
3435
parser.add_argument('--workers', default=4, type=int, metavar='N',
3536
help='number of data loading workers (default: 8)')
36-
parser.add_argument('--batch_size', default=10, type=int, metavar='N',
37-
help='mini-batch size (default: 10)')
37+
parser.add_argument('--batch_size', '-b', default=80, type=int, metavar='N',
38+
help='mini-batch size (default: 80)')
3839
parser.add_argument('--mode', default='both', type=str,
3940
help='Options: att | noatt | (default) both')
4041

@@ -44,7 +45,7 @@ def main():
4445

4546
print("=> using pre-trained model '{}'".format(args.arch))
4647
model = models.__dict__[args.arch](pretrained=True)
47-
model = ResNet(model)
48+
model = ResNet(model, False)
4849
model = nn.DataParallel(model).cuda()
4950

5051
#extract_name = 'arch,{}_layer,{}_resize,{}'.format()
@@ -54,7 +55,7 @@ def main():
5455
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
5556
std=[0.229, 0.224, 0.225])
5657

57-
dataset = coco.COCOImages(args.data_split, args.dir_data,
58+
dataset = coco.COCOImages(args.data_split, dict(dir=args.dir_data),
5859
transform=transforms.Compose([
5960
transforms.Scale(448),
6061
transforms.CenterCrop(448),
@@ -90,15 +91,13 @@ def extract(data_loader, model, path_file, mode):
9091

9192
model.eval()
9293

93-
batch_time = utils.AverageMeter()
94-
data_time = utils.AverageMeter()
94+
batch_time = AvgMeter()
95+
data_time = AvgMeter()
9596
begin = time.time()
9697
end = time.time()
9798

9899
idx = 0
99-
image_names = []
100100
for i, input in enumerate(data_loader):
101-
image_names.append(input['name'])
102101
input_var = torch.autograd.Variable(input['visual'], volatile=True)
103102
output_att = model(input_var)
104103

@@ -124,34 +123,15 @@ def extract(data_loader, model, path_file, mode):
124123
data_time=data_time,))
125124

126125
hdf5_file.close()
126+
127+
# Saving image names in the same order as extraction
127128
with open(path_txt, 'w') as handle:
128-
for name in image_names:
129+
for name in data_loader.dataset.dataset.imgs:
129130
handle.write(name + '\n')
130131

131132
end = time.time() - begin
132133
print('Finished in {}m and {}s'.format(int(end/60), int(end%60)))
133134

134135

135-
136-
class ResNet(nn.Module):
137-
138-
def __init__(self, resnet):
139-
super(ResNet, self).__init__()
140-
self.resnet = resnet
141-
142-
def forward(self, x):
143-
x = self.resnet.conv1(x)
144-
x = self.resnet.bn1(x)
145-
x = self.resnet.relu(x)
146-
x = self.resnet.maxpool(x)
147-
x = self.resnet.layer1(x)
148-
x = self.resnet.layer2(x)
149-
x = self.resnet.layer3(x)
150-
x = self.resnet.layer4(x)
151-
# x = self.avgpool(x)
152-
# x = x.view(x.size(0), -1)
153-
# x = self.fc(x)
154-
return x
155-
156136
if __name__ == '__main__':
157137
main()

vqa/models/utils.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import sys
2+
import torch
3+
import torch.nn as nn
4+
import torchvision.models as models
5+
6+
from .noatt import MLBNoAtt, MutanNoAtt
7+
from .att import MLBAtt, MutanAtt
8+
9+
class ResNet(nn.Module):
    """Wrap a ResNet-style network and expose its convolutional features.

    The wrapped module must provide the attributes ``conv1``, ``bn1``,
    ``relu``, ``maxpool``, ``layer1`` .. ``layer4`` and, when ``pooling``
    is true, ``avgpool``.  Its final classifier is never applied.

    Args:
        resnet: the network to wrap (e.g. a torchvision ResNet).
        pooling: boolean; when true, the ``layer4`` output is passed
            through ``resnet.avgpool`` and flattened to one vector per
            sample, otherwise the ``layer4`` feature map is returned as is.
        fix_until: ``None`` (nothing is frozen) or the name of a layer in
            ``fixable_layers``; every layer up to and including it gets
            ``requires_grad = False`` on all of its parameters.

    Raises:
        ValueError: if ``fix_until`` is not ``None`` and is not one of
            ``fixable_layers``.
    """

    def __init__(self, resnet, pooling, fix_until=None):
        super(ResNet, self).__init__()
        self.resnet = resnet
        self.pooling = pooling
        # Always defined, so callers can inspect them on any instance
        # (previously they only existed when fix_until was given).
        self.fixable_layers = [
            'conv1', 'bn1', 'relu', 'maxpool',
            'layer1', 'layer2', 'layer3', 'layer4']
        self.fix_until = None
        if fix_until is not None:
            if fix_until not in self.fixable_layers:
                raise ValueError(
                    'fix_until must be None or one of {}, got {!r}'.format(
                        self.fixable_layers, fix_until))
            self.fix_until = fix_until
            self._fix_layers(fix_until)

    def _fix_layers(self, fix_until):
        """Disable gradients for every layer up to ``fix_until`` (included).

        Raises:
            ValueError: if ``fix_until`` is not in ``fixable_layers``
                (raised by ``list.index``).
        """
        last = self.fixable_layers.index(fix_until)
        for layer in self.fixable_layers[:last + 1]:
            print('Warning models/utils.py: Fix cnn layer '+layer)
            for p in getattr(self.resnet, layer).parameters():
                p.requires_grad = False

    def forward(self, x):
        """Run ``x`` through the stem and the four residual stages.

        Returns the ``layer4`` output, or its average-pooled and flattened
        version of shape ``(x.size(0), -1)`` when ``self.pooling`` is true.
        """
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)
        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)
        x = self.resnet.layer4(x)
        if self.pooling:
            x = self.resnet.avgpool(x)
            x = x.view(x.size(0), -1)
        # The wrapped network's classifier is intentionally not applied.
        return x

0 commit comments

Comments
 (0)