@@ -130,18 +130,19 @@ def forward(self, x):
130130
131131
132132def make_stem (cfg : TModelCfg ) -> nn .Sequential : # type: ignore
133+ len_stem = len (cfg .stem_sizes )
133134 stem : List [tuple [str , nn .Module ]] = [
134135 (f"conv_{ i } " , cfg .conv_layer (
135- cfg .stem_sizes [i ] , # type: ignore
136- cfg .stem_sizes [i + 1 ],
136+ cfg .stem_sizes [i - 1 ] if i else cfg . in_chans , # type: ignore
137+ cfg .stem_sizes [i ],
137138 stride = 2 if i == cfg .stem_stride_on else 1 ,
138139 bn_layer = (not cfg .stem_bn_end )
139- if i == (len ( cfg . stem_sizes ) - 2 )
140+ if i == (len_stem - 1 )
140141 else True ,
141142 act_fn = cfg .act_fn ,
142143 bn_1st = cfg .bn_1st ,
143144 ),)
144- for i in range (len ( cfg . stem_sizes ) - 1 )
145+ for i in range (len_stem )
145146 ]
146147 if cfg .stem_pool :
147148 stem .append (("stem_pool" , cfg .stem_pool ()))
@@ -262,8 +263,6 @@ class ModelConstructor(ModelCfg):
262263
263264 @root_validator
264265 def post_init (cls , values ): # pylint: disable=E0213
265- if values ["stem_sizes" ][0 ] != values ["in_chans" ]:
266- values ["stem_sizes" ] = [values ["in_chans" ]] + values ["stem_sizes" ]
267266 if values ["se" ] and isinstance (values ["se" ], (bool , int )): # if se=1 or se=True
268267 values ["se" ] = SEModule
269268 if values ["sa" ] and isinstance (values ["sa" ], (bool , int )): # if sa=1 or sa=True
0 commit comments