@@ -9,8 +9,7 @@
 from aiu_fms_testing_utils.utils.aiu_setup import dprint
 import os
 
-if "HF_HOME" not in os.environ:
-    os.environ["HF_HOME"] = "/tmp/models/hf_cache"
+ORIGINAL_HF_HOME = os.environ.get("HF_HOME", None)
 
 # Add models to test here
 LLAMA_3p1_8B_INSTRUCT = "meta-llama/Llama-3.1-8B-Instruct"
@@ -72,6 +71,11 @@ def reset_compiler():
     yield # run the test
     torch.compiler.reset()
     torch._dynamo.reset()
+    os.environ.pop('COMPILATION_MODE', None)
+    if ORIGINAL_HF_HOME is None:
+        os.environ.pop('HF_HOME', None)
+    else:
+        os.environ['HF_HOME'] = ORIGINAL_HF_HOME
 
 def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
     prompts_and_sizes = sample_sharegpt_requests(SHARE_GPT_DATASET_PATH, batch_size, tokenizer, int(seq_length / 2), seq_length, seed)
@@ -116,6 +120,10 @@ def __load_validation_info(model_path, batch_size, seq_length, max_new_tokens, t
 @pytest.mark.parametrize("model_path,batch_size,seq_length,max_new_tokens", common_shapes)
 def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):
     os.environ["COMPILATION_MODE"] = "offline_decoder"
+
+    if "HF_HOME" not in os.environ:
+        os.environ["HF_HOME"] = "/tmp/models/hf_cache"
+
     dprint(f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}")
 
     if USE_MICRO_MODELS:
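Taken together, the diff captures the original HF_HOME at import time, restores it (and clears COMPILATION_MODE) in the fixture's teardown, and only falls back to /tmp/models/hf_cache inside the test itself. A minimal standalone sketch of that save/restore pattern is below; the autouse setting and the fixture/test names are illustrative assumptions, not taken from this file.

import os
import pytest

# Remember whatever HF_HOME was set to before the test session started.
ORIGINAL_HF_HOME = os.environ.get("HF_HOME", None)

@pytest.fixture(autouse=True)  # assumption: applied to every test for illustration
def restore_environment():
    yield  # run the test
    # Teardown: undo per-test environment changes.
    os.environ.pop("COMPILATION_MODE", None)
    if ORIGINAL_HF_HOME is None:
        os.environ.pop("HF_HOME", None)
    else:
        os.environ["HF_HOME"] = ORIGINAL_HF_HOME

def test_example():
    os.environ["COMPILATION_MODE"] = "offline_decoder"
    if "HF_HOME" not in os.environ:
        # Fall back to a scratch cache only when the caller has not set one.
        os.environ["HF_HOME"] = "/tmp/models/hf_cache"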