Skip to content

Commit 1f9eb66

Browse files
committed
Add basic grad checkpointing tests
1 parent b6692ed commit 1f9eb66

File tree

1 file changed

+27
-1
lines changed

1 file changed

+27
-1
lines changed

tests/test_models.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,18 @@ def test_model_forward(model_name, batch_size):
186186
assert outputs.shape[0] == batch_size
187187
assert not torch.isnan(outputs).any(), 'Output included NaNs'
188188

189+
# Test that grad-checkpointing, if supported, doesn't cause model failures or change in output
190+
try:
191+
model.set_grad_checkpointing()
192+
except:
193+
# throws if not supported, that's fine
194+
pass
195+
else:
196+
outputs2 = model(inputs)
197+
if isinstance(outputs, tuple):
198+
outputs2 = torch.cat(outputs2)
199+
assert torch.allclose(outputs, outputs2, rtol=1e-4, atol=1e-5), 'Output does not match'
200+
189201

190202
@pytest.mark.base
191203
@pytest.mark.timeout(timeout120)
@@ -529,6 +541,20 @@ def test_model_forward_intermediates(model_name, batch_size):
529541
output2 = model.forward_features(inpt)
530542
assert torch.allclose(output, output2)
531543

544+
# Test that grad-checkpointing, if supported
545+
try:
546+
model.set_grad_checkpointing()
547+
except:
548+
# throws if not supported, that's fine
549+
pass
550+
else:
551+
output3, _ = model.forward_intermediates(
552+
inpt,
553+
output_fmt=output_fmt,
554+
)
555+
assert torch.allclose(output, output3, rtol=1e-4, atol=1e-5), 'Output does not match'
556+
557+
532558

533559
def _create_fx_model(model, train=False):
534560
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
@@ -717,4 +743,4 @@ def test_model_forward_torchscript_with_features_fx(model_name, batch_size):
717743

718744
for tensor in outputs:
719745
assert tensor.shape[0] == batch_size
720-
assert not torch.isnan(tensor).any(), 'Output included NaNs'
746+
assert not torch.isnan(tensor).any(), 'Output included NaNs'

0 commit comments

Comments
 (0)