Skip to content

Commit 1566583

Browse files
use request param for setting up use_cache
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
1 parent fd1c20f commit 1566583

File tree

1 file changed

+4
-14
lines changed

1 file changed

+4
-14
lines changed

tests/models/test_decoders.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -905,16 +905,6 @@ def _run_cpu_aiu_validation_test(
905905
)
906906

907907

908-
def _get_cache_test_params():
909-
# NOTE - currently we always use granite 3.3 for the cache test,
910-
# TODO make this configurable as tests are refactored
911-
model_path = GRANITE_3p3_8B_INSTRUCT
912-
batch_size = COMMON_BATCH_SIZES[0]
913-
seq_length = COMMON_SEQ_LENGTHS[0]
914-
max_new_tokens = COMMON_MAX_NEW_TOKENS[0]
915-
return [model_path, batch_size, seq_length, max_new_tokens]
916-
917-
918908
def _reset_cache_settings(purge_cache_dir, cache_dir=None):
919909
os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1"
920910
os.environ["COMPILATION_MODE"] = "offline_decoder"
@@ -937,15 +927,15 @@ def _reset_cache_settings(purge_cache_dir, cache_dir=None):
937927

938928
@pytest.fixture
939929
def use_cached_model(request, persistent_model, record_property, tmp_path):
940-
"""Configures the torchsendnn cache and runs the AIU model prior to test execution;
941-
this is computationally expensive and should only be used in situations like testing
942-
cache hit correctness;
930+
"""Configures the torchsendnn cache and runs the AIU model (warmup)
931+
prior to test execution; this is computationally expensive and should
932+
only be used in situations like testing cache hit correctness.
943933
"""
944934
torch.manual_seed(42)
945935
torch.set_grad_enabled(False)
946936
_reset_cache_settings(purge_cache_dir=True, cache_dir=tmp_path)
947937

948-
model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
938+
model_path, batch_size, seq_length, max_new_tokens = request.param
949939
micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
950940

951941
def verify_cache_miss():

0 commit comments

Comments
 (0)