File tree Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Original file line number Diff line number Diff line change @@ -210,6 +210,7 @@ def test_model_backward(model_name, batch_size):
210210 pytest .skip ("Fixed input size model > limit." )
211211
212212 model = create_model (model_name , pretrained = False , num_classes = 42 )
213+ encoder_only = model .num_classes == 0 # FIXME better approach?
213214 num_params = sum ([x .numel () for x in model .parameters ()])
214215 model .train ()
215216
@@ -224,7 +225,12 @@ def test_model_backward(model_name, batch_size):
224225 assert x .grad is not None , f'No gradient for { n } '
225226 num_grad = sum ([x .grad .numel () for x in model .parameters () if x .grad is not None ])
226227
227- assert outputs .shape [- 1 ] == 42
228+ if encoder_only :
229+ output_fmt = getattr (model , 'output_fmt' , 'NCHW' )
230+ feat_axis = get_channel_dim (output_fmt )
231+ assert outputs .shape [feat_axis ] == model .num_features , f'unpooled feature dim { outputs .shape [feat_axis ]} != model.num_features { model .num_features } '
232+ else :
233+ assert outputs .shape [- 1 ] == 42
228234 assert num_params == num_grad , 'Some parameters are missing gradients'
229235 assert not torch .isnan (outputs ).any (), 'Output included NaNs'
230236
You can’t perform that action at this time.
0 commit comments