Skip to content

Commit fd1c20f

Browse files
parametrize use_cache
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
1 parent 630fe39 commit fd1c20f

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

tests/models/test_decoders.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -936,7 +936,7 @@ def _reset_cache_settings(purge_cache_dir, cache_dir=None):
936936

937937

938938
@pytest.fixture
939-
def use_cached_model(persistent_model, record_property, tmp_path):
939+
def use_cached_model(request, persistent_model, record_property, tmp_path):
940940
"""Configures the torchsendnn cache and runs the AIU model prior to test execution;
941941
this is computationally expensive and should only be used in situations like testing
942942
cache hit correctness;
@@ -990,7 +990,9 @@ def verify_cache_miss():
990990
micro_model_path,
991991
record_property,
992992
verify_cache_state=verify_cache_miss,
993+
warmup_only=True,
993994
)
995+
return request.param
994996

995997

996998
@pytest.mark.parametrize(
@@ -1044,12 +1046,19 @@ def test_common_shapes(
10441046
)
10451047

10461048

1049+
@pytest.mark.parametrize(
1050+
"use_cached_model",
1051+
COMMON_SHAPES,
1052+
indirect=True,
1053+
)
10471054
def test_cache(use_cached_model, persistent_model, record_property, tmp_path):
10481055
torch.manual_seed(42)
10491056
torch.set_grad_enabled(False)
10501057
_reset_cache_settings(purge_cache_dir=False, cache_dir=tmp_path)
10511058

1052-
model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
1059+
# use_cached_model is an indirectly parametrized fixture, and the returned
1060+
# value is an expanded tuple from COMMON_SHAPES, so we unpack it here
1061+
model_path, batch_size, seq_length, max_new_tokens = use_cached_model
10531062
micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
10541063

10551064
def verify_cache_hit():
@@ -1094,4 +1103,5 @@ def verify_cache_hit():
10941103
micro_model_path,
10951104
record_property,
10961105
verify_cache_state=verify_cache_hit,
1106+
warmup_only=True,
10971107
)

0 commit comments

Comments
 (0)