Skip to content

Commit 642e49d

Browse files
committed
reset environment variables after each test
Signed-off-by: Joshua Rosenkranz <jmrosenk@us.ibm.com>
1 parent 089315f commit 642e49d

File tree: 2 files changed (+19 −4 lines)

tests/models/test_decoders.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
from aiu_fms_testing_utils.utils.aiu_setup import dprint
1010
import os
1111

12-
if "HF_HOME" not in os.environ:
13-
os.environ["HF_HOME"] = "/tmp/models/hf_cache"
12+
ORIGINAL_HF_HOME = os.environ.get("HF_HOME", None)
1413

1514
# Add models to test here
1615
LLAMA_3p1_8B_INSTRUCT = "meta-llama/Llama-3.1-8B-Instruct"
@@ -72,6 +71,11 @@ def reset_compiler():
7271
yield # run the test
7372
torch.compiler.reset()
7473
torch._dynamo.reset()
74+
os.environ.pop('COMPILATION_MODE', None)
75+
if ORIGINAL_HF_HOME is None:
76+
os.environ.pop('HF_HOME', None)
77+
else:
78+
os.environ['HF_HOME'] = ORIGINAL_HF_HOME
7579

7680
def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
7781
prompts_and_sizes = sample_sharegpt_requests(SHARE_GPT_DATASET_PATH, batch_size, tokenizer, int(seq_length / 2), seq_length, seed)
@@ -116,6 +120,10 @@ def __load_validation_info(model_path, batch_size, seq_length, max_new_tokens, t
116120
@pytest.mark.parametrize("model_path,batch_size,seq_length,max_new_tokens", common_shapes)
117121
def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):
118122
os.environ["COMPILATION_MODE"] = "offline_decoder"
123+
124+
if "HF_HOME" not in os.environ:
125+
os.environ["HF_HOME"] = "/tmp/models/hf_cache"
126+
119127
dprint(f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}")
120128

121129
if USE_MICRO_MODELS:

tests/models/test_encoders.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
from aiu_fms_testing_utils.utils.aiu_setup import dprint
1010
import os
1111

12-
if "HF_HOME" not in os.environ:
13-
os.environ["HF_HOME"] = "/tmp/models/hf_cache"
12+
ORIGINAL_HF_HOME = os.environ.get("HF_HOME", None)
1413

1514
# Add models to test here
1615
ROBERTA_SQUAD_V2 = "deepset/roberta-base-squad2"
@@ -49,13 +48,21 @@ def reset_compiler():
4948
yield # run the test
5049
torch.compiler.reset()
5150
torch._dynamo.reset()
51+
os.environ.pop('COMPILATION_MODE', None)
52+
if ORIGINAL_HF_HOME is None:
53+
os.environ.pop('HF_HOME', None)
54+
else:
55+
os.environ['HF_HOME'] = ORIGINAL_HF_HOME
5256

5357
encoder_paths = ["deepset/roberta-base-squad2"]
5458
common_encoder_shapes = list(itertools.product(encoder_paths, common_batch_sizes, common_seq_lengths))
5559

5660
@pytest.mark.parametrize("model_path,batch_size,seq_length", common_encoder_shapes)
5761
def test_common_shapes(model_path, batch_size, seq_length):
5862
os.environ["COMPILATION_MODE"] = "offline"
63+
64+
if "HF_HOME" not in os.environ:
65+
os.environ["HF_HOME"] = "/tmp/models/hf_cache"
5966

6067
dprint(f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}")
6168

Commit comments (0)