@@ -905,16 +905,6 @@ def _run_cpu_aiu_validation_test(
     )


-def _get_cache_test_params():
-    # NOTE - currently we always use granite 3.3 for the cache test,
-    # TODO make this configurable as tests are refactored
-    model_path = GRANITE_3p3_8B_INSTRUCT
-    batch_size = COMMON_BATCH_SIZES[0]
-    seq_length = COMMON_SEQ_LENGTHS[0]
-    max_new_tokens = COMMON_MAX_NEW_TOKENS[0]
-    return [model_path, batch_size, seq_length, max_new_tokens]
-
-
 def _reset_cache_settings(purge_cache_dir, cache_dir=None):
     os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1"
     os.environ["COMPILATION_MODE"] = "offline_decoder"
@@ -937,15 +927,15 @@ def _reset_cache_settings(purge_cache_dir, cache_dir=None):

 @pytest.fixture
 def use_cached_model(request, persistent_model, record_property, tmp_path):
-    """Configures the torchsendnn cache and runs the AIU model prior to test execution;
-    this is computationally expensive and should only be used in situations like testing
-    cache hit correctness;
+    """Configures the torchsendnn cache and runs the AIU model (warmup)
+    prior to test execution; this is computationally expensive and should
+    only be used in situations like testing cache hit correctness.
     """
     torch.manual_seed(42)
     torch.set_grad_enabled(False)
     _reset_cache_settings(purge_cache_dir=True, cache_dir=tmp_path)

-    model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
+    model_path, batch_size, seq_length, max_new_tokens = request.param
     micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)

     def verify_cache_miss():
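Since the fixture now reads its parameters from request.param instead of the removed _get_cache_test_params() helper, tests must supply them via pytest's indirect parametrization. A minimal sketch of such a caller, assuming the module's existing constants; the test name and parameter list below are illustrative, not taken from the diff:

import pytest

# Hypothetical parameter set mirroring the values the deleted helper returned:
# (model_path, batch_size, seq_length, max_new_tokens).
CACHE_TEST_PARAMS = [
    (
        GRANITE_3p3_8B_INSTRUCT,   # model_path
        COMMON_BATCH_SIZES[0],     # batch_size
        COMMON_SEQ_LENGTHS[0],     # seq_length
        COMMON_MAX_NEW_TOKENS[0],  # max_new_tokens
    ),
]

# indirect=True routes each tuple to the use_cached_model fixture,
# where it arrives as request.param before the test body runs.
@pytest.mark.parametrize("use_cached_model", CACHE_TEST_PARAMS, indirect=True)
def test_cache_hit_correctness(use_cached_model):
    # The fixture has already warmed the torchsendnn cache for these
    # parameters; the test body would exercise the cache-hit path.
    ...

This keeps the expensive warmup in the fixture while letting each test choose its own model and generation settings, which is what the TODO in the removed helper asked for.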