Skip to content

Commit 42305bb

Browse files
use tmp_path fixture for cache test
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
1 parent 18bbf01 commit 42305bb

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

tests/models/test_decoders.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -910,29 +910,35 @@ def _get_cache_test_params():
910910
return [model_path, batch_size, seq_length, max_new_tokens]
911911

912912

913-
def _reset_cache_settings(purge_cache_dir):
913+
def _reset_cache_settings(purge_cache_dir, cache_dir=None):
914914
os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1"
915915
os.environ["COMPILATION_MODE"] = "offline_decoder"
916-
cache_dir = os.environ["TORCH_SENDNN_CACHE_DIR"]
916+
if cache_dir is not None:
917+
# Might be a posixpath
918+
cache_dir = str(cache_dir)
919+
os.environ["TORCH_SENDNN_CACHE_DIR"] = cache_dir
917920

918921
# Ensure we start in clean state
919922
if purge_cache_dir and os.path.isdir(cache_dir):
920923
shutil.rmtree(cache_dir)
921924
os.mkdir(cache_dir)
922925

926+
# NOTE: currently, the cache dir is pulled from
927+
# TORCH_SENDNN_CACHE_DIR at initialization time,
928+
# so this should correctly use the cache_dir
923929
_get_global_state().use_aiu_cache = True
924930
_get_global_state().spyre_graph_cache = SpyreGraphCache()
925931

926932

927933
@pytest.fixture
928-
def use_cached_model(persistent_model, record_property):
934+
def use_cached_model(persistent_model, record_property, tmp_path):
929935
"""Configures the torchsendnn cache and runs the AIU model prior to test execution;
930936
this is computationally expensive and should only be used in situations like testing
931937
cache hit correctness;
932938
"""
933939
torch.manual_seed(42)
934940
torch.set_grad_enabled(False)
935-
_reset_cache_settings(purge_cache_dir=True)
941+
_reset_cache_settings(purge_cache_dir=True, cache_dir=tmp_path)
936942

937943
model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
938944
micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
@@ -1033,10 +1039,10 @@ def test_common_shapes(
10331039
)
10341040

10351041

1036-
def test_cache(use_cached_model, persistent_model, record_property):
1042+
def test_cache(use_cached_model, persistent_model, record_property, tmp_path):
10371043
torch.manual_seed(42)
10381044
torch.set_grad_enabled(False)
1039-
_reset_cache_settings(purge_cache_dir=False)
1045+
_reset_cache_settings(purge_cache_dir=False, cache_dir=tmp_path)
10401046

10411047
model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
10421048
micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)

0 commit comments

Comments
 (0)