@@ -936,7 +936,7 @@ def _reset_cache_settings(purge_cache_dir, cache_dir=None):
 
 
 @pytest.fixture
-def use_cached_model(persistent_model, record_property, tmp_path):
+def use_cached_model(request, persistent_model, record_property, tmp_path):
     """Configures the torchsendnn cache and runs the AIU model prior to test execution;
     this is computationally expensive and should only be used in situations like testing
     cache hit correctness;
@@ -990,7 +990,9 @@ def verify_cache_miss():
         micro_model_path,
         record_property,
         verify_cache_state=verify_cache_miss,
+        warmup_only=True,
     )
+    return request.param
 
 
 @pytest.mark.parametrize(
@@ -1044,12 +1046,19 @@ def test_common_shapes(
     )
 
 
+@pytest.mark.parametrize(
+    "use_cached_model",
+    COMMON_SHAPES,
+    indirect=True,
+)
 def test_cache(use_cached_model, persistent_model, record_property, tmp_path):
     torch.manual_seed(42)
     torch.set_grad_enabled(False)
     _reset_cache_settings(purge_cache_dir=False, cache_dir=tmp_path)
 
-    model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
+    # use_cached_model is an indirectly parametrized fixture, and the returned
+    # value is an expanded tuple from COMMON_SHAPES, so we unpack it here
+    model_path, batch_size, seq_length, max_new_tokens = use_cached_model
     micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
 
     def verify_cache_hit():
@@ -1094,4 +1103,5 @@ def verify_cache_hit():
         micro_model_path,
         record_property,
         verify_cache_state=verify_cache_hit,
+        warmup_only=True,
     )
0 commit comments