@@ -147,6 +147,15 @@ def test_model_default_cfgs(model_name, batch_size):
147147 # FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
148148 assert outputs .shape [- 1 ] == pool_size [- 1 ] and outputs .shape [- 2 ] == pool_size [- 2 ]
149149
150+ if 'pruned' not in model_name : # FIXME better pruned model handling
151+ # test classifier + global pool deletion via __init__
152+ model = create_model (model_name , pretrained = False , num_classes = 0 , global_pool = '' ).eval ()
153+ outputs = model .forward (input_tensor )
154+ assert len (outputs .shape ) == 4
155+ if not isinstance (model , timm .models .MobileNetV3 ) and not isinstance (model , timm .models .GhostNet ):
156+ # FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
157+ assert outputs .shape [- 1 ] == pool_size [- 1 ] and outputs .shape [- 2 ] == pool_size [- 2 ]
158+
150159 # check classifier name matches default_cfg
151160 classifier = cfg ['classifier' ]
152161 if not isinstance (classifier , (tuple , list )):
@@ -193,6 +202,13 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
193202 assert len (outputs .shape ) == 2
194203 assert outputs .shape [1 ] == model .num_features
195204
205+ model = create_model (model_name , pretrained = False , num_classes = 0 ).eval ()
206+ outputs = model .forward (input_tensor )
207+ if isinstance (outputs , tuple ):
208+ outputs = outputs [0 ]
209+ assert len (outputs .shape ) == 2
210+ assert outputs .shape [1 ] == model .num_features
211+
196212 # check classifier name matches default_cfg
197213 classifier = cfg ['classifier' ]
198214 if not isinstance (classifier , (tuple , list )):
@@ -217,6 +233,7 @@ def test_model_load_pretrained(model_name, batch_size):
217233 """Create that pretrained weights load, verify support for in_chans != 3 while doing so."""
218234 in_chans = 3 if 'pruned' in model_name else 1 # pruning not currently supported with in_chans change
219235 create_model (model_name , pretrained = True , in_chans = in_chans , num_classes = 5 )
236+ create_model (model_name , pretrained = True , in_chans = in_chans , num_classes = 0 )
220237
221238 @pytest .mark .timeout (120 )
222239 @pytest .mark .parametrize ('model_name' , list_models (pretrained = True , exclude_filters = NON_STD_FILTERS ))
0 commit comments