@@ -1,6 +1,7 @@
 # pretrained models from pytorch: https://pytorch.org/vision/0.8/models.html
 from __future__ import print_function, division
 
+import itertools
 import torch
 import torch.nn as nn
 import torchvision.models as models
@@ -25,14 +26,19 @@ def __init__(self, params):
         super(ResNet18, self).__init__()
         self.params = params
         cls_num = params['class_num']
-        in_chns = params.get('input_chns', 3)
+        self.in_chns = params.get('input_chns', 3)
         self.pretrain = params.get('pretrain', True)
         self.update_layers = params.get('update_layers', 0)
         self.net = models.resnet18(pretrained = self.pretrain)
 
         # replace the last layer
         num_ftrs = self.net.fc.in_features
         self.net.fc = nn.Linear(num_ftrs, cls_num)
+
+        # replace the first layer when in_chns is not 3
+        if(self.in_chns != 3):
+            self.net.conv1 = nn.Conv2d(self.in_chns, 64, kernel_size = (7, 7),
+                stride = (2, 2), padding = (3, 3), bias = False)
 
     def forward(self, x):
         return self.net(x)
@@ -41,7 +47,14 @@ def get_parameters_to_update(self):
         if(self.pretrain == False or self.update_layers == 0):
             return self.net.parameters()
         elif(self.update_layers == -1):
-            return self.net.fc.parameters()
+            params = self.net.fc.parameters()
+            if(self.in_chns != 3):
+                # combine the two parameter iterables into a single one
+                # see: https://dzone.com/articles/python-joining-multiple
+                params = itertools.chain()
+                for param in [self.net.fc.parameters(), self.net.conv1.parameters()]:
+                    params = itertools.chain(params, param)
+            return params
         else:
             raise(ValueError("update_layers can only be 0 (all layers) " +
                 "or -1 (the last layer)"))
@@ -51,14 +64,19 @@ def __init__(self, params):
         super(VGG16, self).__init__()
         self.params = params
         cls_num = params['class_num']
-        in_chns = params.get('input_chns', 3)
+        self.in_chns = params.get('input_chns', 3)
        self.pretrain = params.get('pretrain', True)
         self.update_layers = params.get('update_layers', 0)
         self.net = models.vgg16(pretrained = self.pretrain)
 
         # replace the last layer
         num_ftrs = self.net.classifier[-1].in_features
         self.net.classifier[-1] = nn.Linear(num_ftrs, cls_num)
+
+        # replace the first conv layer (features[0] in vgg16) when in_chns is not 3
+        if(self.in_chns != 3):
+            self.net.features[0] = nn.Conv2d(self.in_chns, 64, kernel_size = (3, 3),
+                stride = (1, 1), padding = (1, 1))
 
     def forward(self, x):
         return self.net(x)
@@ -67,7 +85,14 @@ def get_parameters_to_update(self):
         if(self.pretrain == False or self.update_layers == 0):
             return self.net.parameters()
         elif(self.update_layers == -1):
-            return self.net.classifier[-1].parameters()
+            params = self.net.classifier[-1].parameters()
+            if(self.in_chns != 3):
+                # combine the two parameter iterables into a single one
+                # see: https://dzone.com/articles/python-joining-multiple
+                params = itertools.chain()
+                for param in [self.net.classifier[-1].parameters(), self.net.features[0].parameters()]:
+                    params = itertools.chain(params, param)
+            return params
         else:
             raise(ValueError("update_layers can only be 0 (all layers) " +
                 "or -1 (the last layer)"))
@@ -85,6 +110,11 @@ def __init__(self, params):
         # replace the last layer
         num_ftrs = self.net.last_channel
         self.net.classifier[-1] = nn.Linear(num_ftrs, cls_num)
+
+        # replace the first conv layer (features[0][0] in mobilenet_v2) when in_chns is not 3
+        if(self.in_chns != 3):
+            self.net.features[0][0] = nn.Conv2d(self.in_chns, 32, kernel_size = (3, 3),
+                stride = (2, 2), padding = (1, 1), bias = False)
 
     def forward(self, x):
         return self.net(x)
@@ -93,7 +123,14 @@ def get_parameters_to_update(self):
         if(self.pretrain == False or self.update_layers == 0):
             return self.net.parameters()
         elif(self.update_layers == -1):
-            return self.net.classifier[-1].parameters()
+            params = self.net.classifier[-1].parameters()
+            if(self.in_chns != 3):
+                # combine the two parameter iterables into a single one
+                # see: https://dzone.com/articles/python-joining-multiple
+                params = itertools.chain()
+                for param in [self.net.classifier[-1].parameters(), self.net.features[0][0].parameters()]:
+                    params = itertools.chain(params, param)
+            return params
         else:
             raise(ValueError("update_layers can only be 0 (all layers) " +
                 "or -1 (the last layer)"))