@@ -910,29 +910,35 @@ def _get_cache_test_params():
910910 return [model_path , batch_size , seq_length , max_new_tokens ]
911911
912912
913- def _reset_cache_settings (purge_cache_dir ):
913+ def _reset_cache_settings (purge_cache_dir , cache_dir = None ):
914914 os .environ ["TORCH_SENDNN_CACHE_ENABLE" ] = "1"
915915 os .environ ["COMPILATION_MODE" ] = "offline_decoder"
916- cache_dir = os .environ ["TORCH_SENDNN_CACHE_DIR" ]
916+ if cache_dir is not None :
917+ # Might be a posixpath
918+ cache_dir = str (cache_dir )
919+ os .environ ["TORCH_SENDNN_CACHE_DIR" ] = cache_dir
917920
918921 # Ensure we start in clean state
919922 if purge_cache_dir and os .path .isdir (cache_dir ):
920923 shutil .rmtree (cache_dir )
921924 os .mkdir (cache_dir )
922925
926+ # NOTE: currently, the cache dir is pulled from
927+ # TORCH_SENDNN_CACHE_DIR at initialization time,
928+ # so this should correctly use the cache_dir
923929 _get_global_state ().use_aiu_cache = True
924930 _get_global_state ().spyre_graph_cache = SpyreGraphCache ()
925931
926932
927933@pytest .fixture
928- def use_cached_model (persistent_model , record_property ):
934+ def use_cached_model (persistent_model , record_property , tmp_path ):
929935 """Configures the torchsendnn cache and runs the AIU model prior to test execution;
930936 this is computationally expensive and should only be used in situations like testing
931937 cache hit correctness;
932938 """
933939 torch .manual_seed (42 )
934940 torch .set_grad_enabled (False )
935- _reset_cache_settings (purge_cache_dir = True )
941+ _reset_cache_settings (purge_cache_dir = True , cache_dir = tmp_path )
936942
937943 model_path , batch_size , seq_length , max_new_tokens = _get_cache_test_params ()
938944 micro_model_path = MICRO_MODEL_MAPPING .get (model_path , None )
@@ -1033,10 +1039,10 @@ def test_common_shapes(
10331039 )
10341040
10351041
1036- def test_cache (use_cached_model , persistent_model , record_property ):
1042+ def test_cache (use_cached_model , persistent_model , record_property , tmp_path ):
10371043 torch .manual_seed (42 )
10381044 torch .set_grad_enabled (False )
1039- _reset_cache_settings (purge_cache_dir = False )
1045+ _reset_cache_settings (purge_cache_dir = False , cache_dir = tmp_path )
10401046
10411047 model_path , batch_size , seq_length , max_new_tokens = _get_cache_test_params ()
10421048 micro_model_path = MICRO_MODEL_MAPPING .get (model_path , None )
0 commit comments