|
35 | 35 | GPTQ_ENABLED = False |
36 | 36 |
|
37 | 37 | ORIGINAL_HF_HOME = os.environ.get("HF_HOME", None) |
38 | | -MODELS_HOME = os.environ.get("FMS_TEST_SHAPES_MODELS_HOME", "/home/senuser/models") |
| 38 | +MICRO_MODELS_HOME = os.environ.get("FMS_TEST_SHAPES_MICRO_MODELS_HOME", "/mnt/home") |
39 | 39 |
|
40 | 40 | # Add models to test here |
41 | 41 | LLAMA_3p1_8B_INSTRUCT = "meta-llama/Llama-3.1-8B-Instruct" |
|
44 | 44 | LLAMA_3p1_70B_INSTRUCT = "meta-llama/Llama-3.1-70B-Instruct" |
45 | 45 |
|
46 | 46 | micro_model_mapping = { |
47 | | - LLAMA_3p1_8B_INSTRUCT: os.path.join(MODELS_HOME, "llama-8b-layers-3-step-24000"), |
| 47 | + LLAMA_3p1_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "llama-8b-layers-3-step-24000"), |
| 48 | + GRANITE_3p2_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "granite-3.2-8b-layers-3-step-24000") |
48 | 49 | } |
49 | 50 |
|
50 | 51 | SHARE_GPT_DATASET_PATH = os.environ.get( |
|
57 | 58 | ) |
58 | 59 | skip_assertions = os.environ.get("FMS_TEST_SHAPES_SKIP_ASSERTIONS", {}) |
59 | 60 | validation_info_dir = os.environ.get( |
60 | | - "FMS_TEST_SHAPES_VALIDATION_INFO_DIR", "/home/senuser/models/validation_info" |
| 61 | + "FMS_TEST_SHAPES_VALIDATION_INFO_DIR", "/tmp/models/validation_info" |
61 | 62 | ) |
62 | 63 | common_model_paths = os.environ.get( |
63 | 64 | "FMS_TEST_SHAPES_COMMON_MODEL_PATHS", |
|
136 | 137 | # thresholds are chosen based on 1024 tokens per sequence |
137 | 138 | # 1% error threshold rate between cpu fp32 and cuda fp16 |
138 | 139 | # if a model's failure thresholds do not exist in this dict, default to the default_metrics_threshold defined above |
139 | | -# threshold key is model_id |
| 140 | +# threshold key is (model_id, is_tiny_model) |
140 | 141 | fail_thresholds = { |
141 | | - LLAMA_3p1_8B_INSTRUCT: ( |
| 142 | + (LLAMA_3p1_8B_INSTRUCT, False): ( |
142 | 143 | 2.6994638133048965, |
143 | 144 | 0.00047589250549208347, |
144 | 145 | ), |
145 | | - GRANITE_3p2_8B_INSTRUCT: ( |
| 146 | + (GRANITE_3p2_8B_INSTRUCT, False): ( |
146 | 147 | 2.3919514417648315, |
147 | 148 | 0.0005767398688476533, |
148 | 149 | ), |
149 | | - GRANITE_20B_CODE_INSTRUCT_8K: ( |
| 150 | + (GRANITE_20B_CODE_INSTRUCT_8K, False): ( |
150 | 151 | 2.640706129074097, |
151 | 152 | 0.00034344267623964697, |
152 | 153 | ), |
153 | | - LLAMA_3p1_70B_INSTRUCT: ( |
| 154 | + (LLAMA_3p1_70B_INSTRUCT, False): ( |
154 | 155 | 2.841279556751251, |
155 | 156 | 0.0044301633024588115, |
156 | 157 | ), |
@@ -314,7 +315,7 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens): |
314 | 315 | os.environ["COMPILATION_MODE"] = "offline_decoder" |
315 | 316 |
|
316 | 317 | if "HF_HOME" not in os.environ: |
317 | | - os.environ["HF_HOME"] = "/home/senuser/models/hf_cache" |
| 318 | + os.environ["HF_HOME"] = "/tmp/models/hf_cache" |
318 | 319 |
|
319 | 320 | dprint( |
320 | 321 | f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}" |
@@ -431,6 +432,7 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens): |
431 | 432 |
|
432 | 433 | # if level 0 fails validation, validate level 1 |
433 | 434 | if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0: |
| 435 | + |
434 | 436 | if failed_validation_level_0: |
435 | 437 | dprint("failed validation level 0, testing validation level 1") |
436 | 438 | else: |
@@ -519,9 +521,15 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor): |
519 | 521 | ce_threshold, diff_threshold = default_metrics_threshold |
520 | 522 | # if we have real weights, try and get the proper validation metrics threshold |
521 | 523 | else: |
522 | | - ce_threshold, diff_threshold = fail_thresholds.get( |
523 | | - model_path, default_metrics_threshold |
524 | | - ) |
| 524 | + # if we have a micro model with real weights, but no real thresholds, default to the full model thresholds |
| 525 | + if USE_MICRO_MODELS: |
| 526 | + ce_threshold, diff_threshold = fail_thresholds.get( |
| 527 | + (model_path, True), fail_thresholds.get((model_path, False), default_metrics_threshold) |
| 528 | + ) |
| 529 | + else: |
| 530 | + ce_threshold, diff_threshold = fail_thresholds.get( |
| 531 | + (model_path, False), default_metrics_threshold |
| 532 | + ) |
525 | 533 |
|
526 | 534 | # get all failed responses for each metric |
527 | 535 | ce_fail_responses = filter_failed_level_1_cases( |
|
0 commit comments