
Commit bcc4c05

JRosenkranz authored and flaviabeo committed
added granite micro model; keyed thresholds by (model_id, is_tiny_model) rather than model_id alone -- if the key is not found, we default to the full-size model thresholds
Signed-off-by: Joshua Rosenkranz <jmrosenk@us.ibm.com>
1 parent ccb9391 commit bcc4c05
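
The commit message describes a fallback for the new threshold key: thresholds are looked up by (model_id, is_tiny_model), and when a micro-model entry is missing the lookup falls back to the full-size entry and then to the global default. Below is a minimal, self-contained sketch of that behavior; fail_thresholds and default_metrics_threshold mirror names from the diff, while resolve_thresholds and the placeholder values are illustrative only.

# Sketch of the (model_id, is_tiny_model) threshold lookup with fallback.
# fail_thresholds / default_metrics_threshold mirror names in
# tests/models/test_decoders.py; the helper and placeholder values do not.
GRANITE_3p2_8B_INSTRUCT = "ibm-granite/granite-3.2-8b-instruct"  # assumed HF id

default_metrics_threshold = (3.0, 0.001)  # placeholder (ce, diff) default

fail_thresholds = {
    # threshold key is (model_id, is_tiny_model); only the full-size entry exists so far
    (GRANITE_3p2_8B_INSTRUCT, False): (2.3919514417648315, 0.0005767398688476533),
}

def resolve_thresholds(model_id, use_micro_models):
    # prefer the micro-model entry, then the full-size entry, then the global default
    if use_micro_models:
        return fail_thresholds.get(
            (model_id, True),
            fail_thresholds.get((model_id, False), default_metrics_threshold),
        )
    return fail_thresholds.get((model_id, False), default_metrics_threshold)

# no (model_id, True) entry yet, so a micro-model run reuses the full-size thresholds
assert resolve_thresholds(GRANITE_3p2_8B_INSTRUCT, True) == (2.3919514417648315, 0.0005767398688476533)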

File tree

1 file changed (+20, -12 lines)


tests/models/test_decoders.py

Lines changed: 20 additions & 12 deletions
@@ -35,7 +35,7 @@
 GPTQ_ENABLED = False
 
 ORIGINAL_HF_HOME = os.environ.get("HF_HOME", None)
-MODELS_HOME = os.environ.get("FMS_TEST_SHAPES_MODELS_HOME", "/home/senuser/models")
+MICRO_MODELS_HOME = os.environ.get("FMS_TEST_SHAPES_MICRO_MODELS_HOME", "/mnt/home")
 
 # Add models to test here
 LLAMA_3p1_8B_INSTRUCT = "meta-llama/Llama-3.1-8B-Instruct"
@@ -44,7 +44,8 @@
 LLAMA_3p1_70B_INSTRUCT = "meta-llama/Llama-3.1-70B-Instruct"
 
 micro_model_mapping = {
-    LLAMA_3p1_8B_INSTRUCT: os.path.join(MODELS_HOME, "llama-8b-layers-3-step-24000"),
+    LLAMA_3p1_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "llama-8b-layers-3-step-24000"),
+    GRANITE_3p2_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "granite-3.2-8b-layers-3-step-24000")
 }
 
 SHARE_GPT_DATASET_PATH = os.environ.get(
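
The hunk above moves the micro-model checkpoints under a dedicated FMS_TEST_SHAPES_MICRO_MODELS_HOME environment variable. The sketch below shows how the mapping resolves under the assumption that the checkpoint directories live directly beneath that root; the final lookup only illustrates the resulting path and is not the test file's own selection logic.

import os

# root for micro-model checkpoints; overridable via the environment
MICRO_MODELS_HOME = os.environ.get("FMS_TEST_SHAPES_MICRO_MODELS_HOME", "/mnt/home")

LLAMA_3p1_8B_INSTRUCT = "meta-llama/Llama-3.1-8B-Instruct"
GRANITE_3p2_8B_INSTRUCT = "ibm-granite/granite-3.2-8b-instruct"  # assumed HF id

micro_model_mapping = {
    LLAMA_3p1_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "llama-8b-layers-3-step-24000"),
    GRANITE_3p2_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "granite-3.2-8b-layers-3-step-24000"),
}

# e.g. with FMS_TEST_SHAPES_MICRO_MODELS_HOME=/data/micro set before running pytest,
# this prints /data/micro/granite-3.2-8b-layers-3-step-24000
print(micro_model_mapping[GRANITE_3p2_8B_INSTRUCT])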
@@ -57,7 +58,7 @@
 )
 skip_assertions = os.environ.get("FMS_TEST_SHAPES_SKIP_ASSERTIONS", {})
 validation_info_dir = os.environ.get(
-    "FMS_TEST_SHAPES_VALIDATION_INFO_DIR", "/home/senuser/models/validation_info"
+    "FMS_TEST_SHAPES_VALIDATION_INFO_DIR", "/tmp/models/validation_info"
 )
 common_model_paths = os.environ.get(
     "FMS_TEST_SHAPES_COMMON_MODEL_PATHS",
@@ -136,21 +137,21 @@
 # thresholds are chosen based on 1024 tokens per sequence
 # 1% error threshold rate between cpu fp32 and cuda fp16
 # if a models failure thresholds do not exist in this dict, default to the default_metrics_threshold defined above
-# threshold key is model_id
+# threshold key is (model_id, is_tiny_model)
 fail_thresholds = {
-    LLAMA_3p1_8B_INSTRUCT: (
+    (LLAMA_3p1_8B_INSTRUCT, False): (
         2.6994638133048965,
         0.00047589250549208347,
     ),
-    GRANITE_3p2_8B_INSTRUCT: (
+    (GRANITE_3p2_8B_INSTRUCT, False): (
         2.3919514417648315,
         0.0005767398688476533,
     ),
-    GRANITE_20B_CODE_INSTRUCT_8K: (
+    (GRANITE_20B_CODE_INSTRUCT_8K, False): (
         2.640706129074097,
         0.00034344267623964697,
     ),
-    LLAMA_3p1_70B_INSTRUCT: (
+    (LLAMA_3p1_70B_INSTRUCT, False): (
         2.841279556751251,
         0.0044301633024588115,
     ),
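
Because the keys now carry is_tiny_model, micro-model-specific thresholds can later be registered alongside the full-size ones without touching the lookup code. A hypothetical example follows; the True-keyed numbers are placeholders, not measured values.

# Hypothetical future entry keyed with is_tiny_model=True; the full-size tuple
# is taken from the diff, the micro-model tuple is a placeholder.
GRANITE_3p2_8B_INSTRUCT = "ibm-granite/granite-3.2-8b-instruct"  # assumed HF id

fail_thresholds = {
    (GRANITE_3p2_8B_INSTRUCT, False): (2.3919514417648315, 0.0005767398688476533),
    (GRANITE_3p2_8B_INSTRUCT, True): (2.5, 0.001),  # placeholder micro-model thresholds
}

# with a True entry present, a micro-model run no longer falls back to the full-size entry
assert fail_thresholds.get((GRANITE_3p2_8B_INSTRUCT, True)) == (2.5, 0.001)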
@@ -314,7 +315,7 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):
     os.environ["COMPILATION_MODE"] = "offline_decoder"
 
     if "HF_HOME" not in os.environ:
-        os.environ["HF_HOME"] = "/home/senuser/models/hf_cache"
+        os.environ["HF_HOME"] = "/tmp/models/hf_cache"
 
     dprint(
         f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}"
@@ -431,6 +432,7 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):
 
     # if level 0 fails validation, validate level 1
     if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0:
+
         if failed_validation_level_0:
             dprint("failed validation level 0, testing validation level 1")
         else:
@@ -519,9 +521,15 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
         ce_threshold, diff_threshold = default_metrics_threshold
     # if we have real weights, try and get the proper validation metrics threshold
     else:
-        ce_threshold, diff_threshold = fail_thresholds.get(
-            model_path, default_metrics_threshold
-        )
+        # if we have a micro model with real weights, but no real thresholds, default to the full model thresholds
+        if USE_MICRO_MODELS:
+            ce_threshold, diff_threshold = fail_thresholds.get(
+                (model_path, True), fail_thresholds.get((model_path, False), default_metrics_threshold)
+            )
+        else:
+            ce_threshold, diff_threshold = fail_thresholds.get(
+                (model_path, False), default_metrics_threshold
+            )
 
     # get all failed responses for each metric
     ce_fail_responses = filter_failed_level_1_cases(
