@@ -186,6 +186,18 @@ def test_model_forward(model_name, batch_size):
     assert outputs.shape[0] == batch_size
     assert not torch.isnan(outputs).any(), 'Output included NaNs'
 
+    # Test that grad-checkpointing, if supported, doesn't cause model failures or change in output
+    try:
+        model.set_grad_checkpointing()
+    except Exception:
+        # raises if not supported, that's fine
+        pass
+    else:
+        outputs2 = model(inputs)
+        if isinstance(outputs, tuple):
+            outputs2 = torch.cat(outputs2)
+        assert torch.allclose(outputs, outputs2, rtol=1e-4, atol=1e-5), 'Output does not match'
+
 
 @pytest.mark.base
 @pytest.mark.timeout(timeout120)
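For reference, the gating pattern this hunk adds can be exercised standalone. A minimal sketch, assuming a timm-style model; the model name, input shape, and tolerances are illustrative, not taken from the diff:

import torch
import timm

model = timm.create_model('resnet18', pretrained=False)  # illustrative model choice
model.eval()
inputs = torch.randn(2, 3, 224, 224)
outputs = model(inputs)

try:
    model.set_grad_checkpointing()  # raises on models without checkpointing support
except Exception:
    pass
else:
    # checkpointing trades memory for recompute; the forward result should be unchanged
    outputs2 = model(inputs)
    assert torch.allclose(outputs, outputs2, rtol=1e-4, atol=1e-5)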
@@ -529,6 +541,20 @@ def test_model_forward_intermediates(model_name, batch_size):
     output2 = model.forward_features(inpt)
     assert torch.allclose(output, output2)
 
+    # Test that grad-checkpointing, if supported, doesn't cause model failures or change in output
+    try:
+        model.set_grad_checkpointing()
+    except Exception:
+        # raises if not supported, that's fine
+        pass
+    else:
+        output3, _ = model.forward_intermediates(
+            inpt,
+            output_fmt=output_fmt,
+        )
+        assert torch.allclose(output, output3, rtol=1e-4, atol=1e-5), 'Output does not match'
+
+
 
 def _create_fx_model(model, train=False):
     # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
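This second hunk applies the same gating to forward_intermediates(), which returns the final features alongside per-block intermediates. A minimal sketch of that call, assuming a ViT-style timm model that implements it; the model name and input shape are illustrative:

import torch
import timm

model = timm.create_model('vit_tiny_patch16_224', pretrained=False)  # illustrative
model.eval()
x = torch.randn(1, 3, 224, 224)

# final features plus per-block intermediates; output_fmt selects NCHW vs NLC layout
final, intermediates = model.forward_intermediates(x, output_fmt='NCHW')
# the final output should match forward_features(), as the test above asserts
assert torch.allclose(final, model.forward_features(x))
print([t.shape for t in intermediates])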
@@ -717,4 +743,4 @@ def test_model_forward_torchscript_with_features_fx(model_name, batch_size):
 
     for tensor in outputs:
         assert tensor.shape[0] == batch_size
-        assert not torch.isnan(tensor).any(), 'Output included NaNs'
+        assert not torch.isnan(tensor).any(), 'Output included NaNs'