Skip to content

Commit 630fe39

Browse files
only warmup on cache tests
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
1 parent fe8c61f commit 630fe39

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

tests/models/test_decoders.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -836,6 +836,7 @@ def _run_cpu_aiu_validation_test(
836836
micro_model_path,
837837
record_property,
838838
verify_cache_state=None,
839+
warmup_only=False,
839840
):
840841
# Get the tokenizer and AIU / CPU models to compare
841842
tokenizer = AutoTokenizer.from_pretrained(model_path)
@@ -859,6 +860,16 @@ def _run_cpu_aiu_validation_test(
859860
aiu_model, input_ids, max_new_tokens, COMPILE_DYNAMIC_SENDNN, **extra_kwargs
860861
)
861862

863+
# Used only for cache tests; this is a nonparametric closure that
864+
# should assert the cache for torch sendnn is in the correct state
865+
# for this test
866+
if verify_cache_state is not None:
867+
verify_cache_state()
868+
869+
# For some tests, e.g., cache checks, we only need to run the warmup
870+
if warmup_only:
871+
return
872+
862873
# Run validation level 0
863874
failed_validation_level_0, validation_zero_info = _run_validation_level_0(
864875
model_path,
@@ -872,12 +883,6 @@ def _run_cpu_aiu_validation_test(
872883
aiu_model,
873884
)
874885

875-
# Used only for cache tests; this is a nonparametric closure that
876-
# should assert the cache for torch sendnn is in the correct state
877-
# for this test
878-
if verify_cache_state is not None:
879-
verify_cache_state()
880-
881886
# if level 0 fails validation, validate level 1
882887
if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0:
883888
if failed_validation_level_0:

0 commit comments

Comments
 (0)