Skip to content

Commit fe8c61f

Browse files
fix cache_dir in cache checks
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
1 parent 42305bb commit fe8c61f

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/models/test_decoders.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -944,7 +944,7 @@ def use_cached_model(persistent_model, record_property, tmp_path):
944944
micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
945945

946946
def verify_cache_miss():
947-
cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
947+
cache_dir = str(tmp_path)
948948
updated_cache_len = (
949949
len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
950950
)
@@ -1048,7 +1048,7 @@ def test_cache(use_cached_model, persistent_model, record_property, tmp_path):
10481048
micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
10491049

10501050
def verify_cache_hit():
1051-
cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
1051+
cache_dir = str(tmp_path)
10521052
updated_cache_len = (
10531053
len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
10541054
)

0 commit comments

Comments
 (0)