Skip to content

Commit 78f98a4

Browse files
committed
Update torch_pretrained_net.py
1 parent 42b42ca commit 78f98a4

File tree

1 file changed

+42
-5
lines changed

1 file changed

+42
-5
lines changed

pymic/net/cls/torch_pretrained_net.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# pretrained models from pytorch: https://pytorch.org/vision/0.8/models.html
22
from __future__ import print_function, division
33

4+
import itertools
45
import torch
56
import torch.nn as nn
67
import torchvision.models as models
@@ -25,14 +26,19 @@ def __init__(self, params):
2526
super(ResNet18, self).__init__()
2627
self.params = params
2728
cls_num = params['class_num']
28-
in_chns = params.get('input_chns', 3)
29+
self.in_chns = params.get('input_chns', 3)
2930
self.pretrain = params.get('pretrain', True)
3031
self.update_layers = params.get('update_layers', 0)
3132
self.net = models.resnet18(pretrained = self.pretrain)
3233

3334
# replace the last layer
3435
num_ftrs = self.net.fc.in_features
3536
self.net.fc = nn.Linear(num_ftrs, cls_num)
37+
38+
# replace the first layer when in_chns is not 3
39+
if(self.in_chns != 3):
40+
self.net.conv1 = nn.Conv2d(self.in_chns, 64, kernel_size=(7, 7),
41+
stride=(2, 2), padding=(3, 3), bias=False)
3642

3743
def forward(self, x):
3844
return self.net(x)
@@ -41,7 +47,14 @@ def get_parameters_to_update(self):
4147
if(self.pretrain == False or self.update_layers == 0):
4248
return self.net.parameters()
4349
elif(self.update_layers == -1):
44-
return self.net.fc.parameters()
50+
params = self.net.fc.parameters()
51+
if(self.in_chns !=3):
52+
# combining the two iterables into a single one
53+
# see: https://dzone.com/articles/python-joining-multiple
54+
params = itertools.chain()
55+
for pram in [self.net.fc.parameters(), self.net.conv1.parameters()]:
56+
params = itertools.chain(params, pram)
57+
return params
4558
else:
4659
raise(ValueError("update_layers can only be 0 (all layers) " +
4760
"or -1 (the last layer)"))
@@ -51,14 +64,19 @@ def __init__(self, params):
5164
super(VGG16, self).__init__()
5265
self.params = params
5366
cls_num = params['class_num']
54-
in_chns = params.get('input_chns', 3)
67+
self.in_chns = params.get('input_chns', 3)
5568
self.pretrain = params.get('pretrain', True)
5669
self.update_layers = params.get('update_layers', 0)
5770
self.net = models.vgg16(pretrained = self.pretrain)
5871

5972
# replace the last layer
6073
num_ftrs = self.net.classifier[-1].in_features
6174
self.net.classifier[-1] = nn.Linear(num_ftrs, cls_num)
75+
76+
# replace the first layer when in_chns is not 3
77+
if(self.in_chns != 3):
78+
self.net.conv1 = nn.Conv2d(self.in_chns, 64, kernel_size=(7, 7),
79+
stride=(2, 2), padding=(3, 3), bias=False)
6280

6381
def forward(self, x):
6482
return self.net(x)
@@ -67,7 +85,14 @@ def get_parameters_to_update(self):
6785
if(self.pretrain == False or self.update_layers == 0):
6886
return self.net.parameters()
6987
elif(self.update_layers == -1):
70-
return self.net.classifier[-1].parameters()
88+
params = self.net.classifier[-1].parameters()
89+
if(self.in_chns !=3):
90+
# combining the two iterables into a single one
91+
# see: https://dzone.com/articles/python-joining-multiple
92+
params = itertools.chain()
93+
for pram in [self.net.classifier[-1].parameters(), self.net.conv1.parameters()]:
94+
params = itertools.chain(params, pram)
95+
return params
7196
else:
7297
raise(ValueError("update_layers can only be 0 (all layers) " +
7398
"or -1 (the last layer)"))
@@ -85,6 +110,11 @@ def __init__(self, params):
85110
# replace the last layer
86111
num_ftrs = self.net.last_channel
87112
self.net.classifier[-1] = nn.Linear(num_ftrs, cls_num)
113+
114+
# replace the first layer when in_chns is not 3
115+
if(self.in_chns != 3):
116+
self.net.conv1 = nn.Conv2d(self.in_chns, 64, kernel_size=(7, 7),
117+
stride=(2, 2), padding=(3, 3), bias=False)
88118

89119
def forward(self, x):
90120
return self.net(x)
@@ -93,7 +123,14 @@ def get_parameters_to_update(self):
93123
if(self.pretrain == False or self.update_layers == 0):
94124
return self.net.parameters()
95125
elif(self.update_layers == -1):
96-
return self.net.classifier[-1].parameters()
126+
params = self.net.classifier[-1].parameters()
127+
if(self.in_chns !=3):
128+
# combining the two iterables into a single one
129+
# see: https://dzone.com/articles/python-joining-multiple
130+
params = itertools.chain()
131+
for pram in [self.net.classifier[-1].parameters(), self.net.conv1.parameters()]:
132+
params = itertools.chain(params, pram)
133+
return params
97134
else:
98135
raise(ValueError("update_layers can only be 0 (all layers) " +
99136
"or -1 (the last layer)"))

0 commit comments

Comments
 (0)