@@ -169,6 +169,18 @@ def test_model_backward(model_name, batch_size):
     assert not torch.isnan(outputs).any(), 'Output included NaNs'
 
 
+# models with extra conv/linear layers after pooling
+EARLY_POOL_MODELS = (
+    timm.models.EfficientVit,
+    timm.models.EfficientVitLarge,
+    timm.models.HighPerfGpuNet,
+    timm.models.GhostNet,
+    timm.models.MetaNeXt,  # InceptionNeXt
+    timm.models.MobileNetV3,
+    timm.models.RepGhostNet,
+    timm.models.VGG,
+)
+
 @pytest.mark.cfg
 @pytest.mark.timeout(timeout300)
 @pytest.mark.parametrize('model_name', list_models(
@@ -179,6 +191,9 @@ def test_model_default_cfgs(model_name, batch_size):
     model = create_model(model_name, pretrained=False)
     model.eval()
     model.to(torch_device)
+    assert getattr(model, 'num_classes') >= 0
+    assert getattr(model, 'num_features') > 0
+    assert getattr(model, 'head_hidden_size') > 0
     state_dict = model.state_dict()
     cfg = model.default_cfg
 
@@ -195,37 +210,37 @@ def test_model_default_cfgs(model_name, batch_size):
     input_size = tuple([min(x, MAX_FWD_OUT_SIZE) for x in input_size])
     input_tensor = torch.randn((batch_size, *input_size), device=torch_device)
 
-    # test forward_features (always unpooled)
+    # test forward_features (always unpooled) & forward_head w/ pre_logits
     outputs = model.forward_features(input_tensor)
-    assert outputs.shape[spatial_axis[0]] == pool_size[0], 'unpooled feature shape != config'
-    assert outputs.shape[spatial_axis[1]] == pool_size[1], 'unpooled feature shape != config'
-    if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)):
-        assert outputs.shape[feat_axis] == model.num_features
+    outputs_pre = model.forward_head(outputs, pre_logits=True)
+    assert outputs.shape[spatial_axis[0]] == pool_size[0], f'unpooled feature shape {outputs.shape} != config'
+    assert outputs.shape[spatial_axis[1]] == pool_size[1], f'unpooled feature shape {outputs.shape} != config'
+    assert outputs.shape[feat_axis] == model.num_features, f'unpooled feature dim {outputs.shape[feat_axis]} != model.num_features {model.num_features}'
+    assert outputs_pre.shape[1] == model.head_hidden_size, f'pre_logits feature dim {outputs_pre.shape[1]} != model.head_hidden_size {model.head_hidden_size}'
 
     # test forward after deleting the classifier, output should be pooled, size(-1) == model.head_hidden_size
     model.reset_classifier(0)
     model.to(torch_device)
     outputs = model.forward(input_tensor)
     assert len(outputs.shape) == 2
-    assert outputs.shape[1] == model.num_features
+    assert outputs.shape[1] == model.head_hidden_size, f'feature dim w/ removed classifier {outputs.shape[1]} != model.head_hidden_size {model.head_hidden_size}'
+    assert outputs.shape == outputs_pre.shape, f'output shape of pre_logits {outputs_pre.shape} does not match reset_head(0) {outputs.shape}'
 
-    # test model forward without pooling and classifier
-    model.reset_classifier(0, '')  # reset classifier and set global pooling to pass-through
-    model.to(torch_device)
-    outputs = model.forward(input_tensor)
-    assert len(outputs.shape) == 4
-    if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)):
-        # mobilenetv3/ghostnet/repghostnet/vgg forward_features vs removed pooling differ due to location or lack of GAP
+    # test model forward after removing pooling and classifier
+    if not isinstance(model, EARLY_POOL_MODELS):
+        model.reset_classifier(0, '')  # reset classifier and disable global pooling
+        model.to(torch_device)
+        outputs = model.forward(input_tensor)
+        assert len(outputs.shape) == 4
         assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]
 
-    if 'pruned' not in model_name:  # FIXME better pruned model handling
-        # test classifier + global pool deletion via __init__
+    # test classifier + global pool deletion via __init__
+    if 'pruned' not in model_name and not isinstance(model, EARLY_POOL_MODELS):
         model = create_model(model_name, pretrained=False, num_classes=0, global_pool='').eval()
         model.to(torch_device)
         outputs = model.forward(input_tensor)
         assert len(outputs.shape) == 4
-        if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)):
-            assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]
+        assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]
 
     # check classifier name matches default_cfg
     if cfg.get('num_classes', None):
@@ -253,6 +268,9 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
     model = create_model(model_name, pretrained=False)
     model.eval()
     model.to(torch_device)
+    assert getattr(model, 'num_classes') >= 0
+    assert getattr(model, 'num_features') > 0
+    assert getattr(model, 'head_hidden_size') > 0
     state_dict = model.state_dict()
     cfg = model.default_cfg
 
@@ -264,13 +282,15 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
     feat_dim = getattr(model, 'feature_dim', None)
 
     outputs = model.forward_features(input_tensor)
+    outputs_pre = model.forward_head(outputs, pre_logits=True)
     if isinstance(outputs, (tuple, list)):
         # cannot currently verify multi-tensor output.
         pass
     else:
         if feat_dim is None:
             feat_dim = -1 if outputs.ndim == 3 else 1
         assert outputs.shape[feat_dim] == model.num_features
+        assert outputs_pre.shape[1] == model.head_hidden_size
 
     # test forward after deleting the classifier, output should be pooled, size(-1) == model.head_hidden_size
     model.reset_classifier(0)
@@ -280,7 +300,8 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
         outputs = outputs[0]
     if feat_dim is None:
         feat_dim = -1 if outputs.ndim == 3 else 1
-    assert outputs.shape[feat_dim] == model.num_features, 'pooled num_features != config'
+    assert outputs.shape[feat_dim] == model.head_hidden_size, 'pooled feature dim != model.head_hidden_size'
+    assert outputs.shape == outputs_pre.shape
 
     model = create_model(model_name, pretrained=False, num_classes=0).eval()
     model.to(torch_device)
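
For reference, a minimal sketch of the contract these tests now enforce: `forward_features()` stays unpooled with channel dim `num_features`, while `forward_head(pre_logits=True)` and a model with its classifier reset both produce `head_hidden_size` features (the two differ for models like MobileNetV3 that apply conv/linear layers after pooling). This assumes a timm version that exposes `head_hidden_size`; the model name and input size are just illustrative.

```python
import timm
import torch

model = timm.create_model('mobilenetv3_small_100', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

# forward_features() is always unpooled; channel dim == model.num_features
feats = model.forward_features(x)
assert feats.shape[1] == model.num_features

# forward_head(pre_logits=True) pools and runs any extra pre-classifier
# layers; feature dim == model.head_hidden_size
pre = model.forward_head(feats, pre_logits=True)
assert pre.shape == (1, model.head_hidden_size)

# reset_classifier(0) removes the classifier but keeps pooling, so a plain
# forward() now matches the pre_logits output
model.reset_classifier(0)
assert model(x).shape == pre.shape
```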