Skip to content

Commit ba4ab5d

Browse files
JRosenkranz and flaviabeo
authored and committed
added trained micro model path for test_decoders
Signed-off-by: Joshua Rosenkranz <jmrosenk@us.ibm.com>
1 parent c380296 commit ba4ab5d

File tree

1 file changed

+37
-17
lines changed

1 file changed

+37
-17
lines changed

tests/models/test_decoders.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,32 +35,44 @@
3535
GPTQ_ENABLED = False
3636

3737
ORIGINAL_HF_HOME = os.environ.get("HF_HOME", None)
38+
MODELS_HOME = os.environ.get("FMS_TEST_SHAPES_MODELS_HOME", "/home/senuser/models")
3839

3940
# Add models to test here
4041
LLAMA_3p1_8B_INSTRUCT = "meta-llama/Llama-3.1-8B-Instruct"
4142
GRANITE_3p2_8B_INSTRUCT = "ibm-granite/granite-3.2-8b-instruct"
4243
GRANITE_20B_CODE_INSTRUCT_8K = "ibm-granite/granite-20b-code-instruct-8k"
4344
LLAMA_3p1_70B_INSTRUCT = "meta-llama/Llama-3.1-70B-Instruct"
4445

46+
micro_model_mapping = {
47+
LLAMA_3p1_8B_INSTRUCT: os.path.join(MODELS_HOME, "llama-8b-layers-3-step-24000"),
48+
}
49+
4550
SHARE_GPT_DATASET_PATH = os.environ.get(
4651
"SHARE_GPT_DATASET_PATH", os.path.expanduser("~/share_gpt.json")
4752
)
4853
USE_MICRO_MODELS = os.environ.get("FMS_TEST_SHAPES_USE_MICRO_MODELS", "1") == "1"
4954
USE_DISTRIBUTED = os.environ.get("FMS_TEST_SHAPES_DISTRIBUTED", "0") == "1"
50-
FORCE_VALIDATION_LEVEL_1 = os.environ.get("FMS_TEST_SHAPES_FORCE_VALIDATION_LEVEL_1", "0") == "1"
55+
FORCE_VALIDATION_LEVEL_1 = (
56+
os.environ.get("FMS_TEST_SHAPES_FORCE_VALIDATION_LEVEL_1", "0") == "1"
57+
)
5158
skip_assertions = os.environ.get("FMS_TEST_SHAPES_SKIP_ASSERTIONS", {})
5259
validation_info_dir = os.environ.get(
53-
"FMS_TEST_SHAPES_VALIDATION_INFO_DIR", "/tmp/models/validation_info"
60+
"FMS_TEST_SHAPES_VALIDATION_INFO_DIR", "/home/senuser/models/validation_info"
5461
)
5562
common_model_paths = os.environ.get(
5663
"FMS_TEST_SHAPES_COMMON_MODEL_PATHS",
57-
[LLAMA_3p1_8B_INSTRUCT, GRANITE_3p2_8B_INSTRUCT, GRANITE_20B_CODE_INSTRUCT_8K, LLAMA_3p1_70B_INSTRUCT],
64+
[
65+
LLAMA_3p1_8B_INSTRUCT,
66+
GRANITE_3p2_8B_INSTRUCT,
67+
GRANITE_20B_CODE_INSTRUCT_8K,
68+
LLAMA_3p1_70B_INSTRUCT,
69+
],
5870
)
5971
# for validation level 1, the default is a failure rate of 1%
6072
# set this environment variable if you would like to relax that threshold
6173
failure_rate_threshold = os.environ.get("FMS_TEST_SHAPES_FAILURE_THRESHOLD", 0.01)
6274
default_metrics_threshold = os.environ.get(
63-
"FMS_TEST_SHAPES_METRICS_THRESHOLD", (3.0, .001)
75+
"FMS_TEST_SHAPES_METRICS_THRESHOLD", (3.0, 0.001)
6476
)
6577
save_validation_info_outputs = (
6678
os.environ.get("FMS_TEST_SHAPES_SAVE_VALIDATION_INFO_OUTPUTS", "0") == "1"
@@ -86,7 +98,9 @@
8698

8799
# pass custom default metrics threshold as a comma separated str of floats <cross-entropy threshold>,<mean diff threshold>
88100
if isinstance(default_metrics_threshold, str):
89-
default_metrics_threshold = tuple([float(m) for m in default_metrics_threshold.split(",")])
101+
default_metrics_threshold = tuple(
102+
[float(m) for m in default_metrics_threshold.split(",")]
103+
)
90104

91105
# pass custom common batch sizes as a comma separated str of ints
92106
if isinstance(common_batch_sizes, str):
@@ -126,19 +140,19 @@
126140
fail_thresholds = {
127141
(LLAMA_3p1_8B_INSTRUCT, True): (
128142
3.7392955756187423,
129-
.001, # FIXME: compute
143+
0.001, # FIXME: compute
130144
),
131145
(GRANITE_3p2_8B_INSTRUCT, True): (
132146
2.996668996810913,
133-
.001, # FIXME: compute
147+
0.001, # FIXME: compute
134148
),
135149
(GRANITE_20B_CODE_INSTRUCT_8K, True): (
136-
3.7392955756187423, # FIXME: compute -- setting to micro llama 3.1 8b instruct
137-
.001, # FIXME: compute
150+
3.7392955756187423, # FIXME: compute -- setting to micro llama 3.1 8b instruct
151+
0.001, # FIXME: compute
138152
),
139153
(LLAMA_3p1_70B_INSTRUCT, True): (
140154
3.8235735702514626,
141-
.001, # FIXME: compute
155+
0.001, # FIXME: compute
142156
),
143157
(LLAMA_3p1_8B_INSTRUCT, False): (
144158
2.6994638133048965,
@@ -316,7 +330,7 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):
316330
os.environ["COMPILATION_MODE"] = "offline_decoder"
317331

318332
if "HF_HOME" not in os.environ:
319-
os.environ["HF_HOME"] = "/tmp/models/hf_cache"
333+
os.environ["HF_HOME"] = "/home/senuser/models/hf_cache"
320334

321335
dprint(
322336
f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}"
@@ -326,13 +340,18 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):
326340
gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
327341
is_gptq = len(gptq_kwargs_aiu) != 0
328342

329-
if USE_MICRO_MODELS:
343+
micro_model_path = micro_model_mapping.get(model_path, None)
344+
if USE_MICRO_MODELS and micro_model_path is None:
345+
dprint("using randomly initialized model")
330346
micro_model_kwargs = {"architecture": "hf_configured", "nlayers": 3}
331347
else:
348+
dprint("using trained model")
332349
micro_model_kwargs = {"architecture": "hf_pretrained"}
333350

334351
if not USE_MICRO_MODELS and os.path.exists(model_path):
335352
model_path_kwargs = {"model_path": model_path}
353+
elif USE_MICRO_MODELS and micro_model_path is not None:
354+
model_path_kwargs = {"model_path": micro_model_path}
336355
else:
337356
model_path_kwargs = {"variant": model_path}
338357

@@ -428,7 +447,6 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):
428447

429448
# if level 0 fails validation, validate level 1
430449
if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0:
431-
432450
if failed_validation_level_0:
433451
dprint("failed validation level 0, testing validation level 1")
434452
else:
@@ -439,10 +457,12 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
439457
cross_entropy = torch.nn.CrossEntropyLoss()(
440458
r, t.softmax(dim=1).to(dtype=torch.float32)
441459
)
442-
diff = torch.mean(torch.abs(
443-
r.softmax(dim=1).to(dtype=torch.float32)
444-
- t.softmax(dim=1).to(dtype=torch.float32)
445-
))
460+
diff = torch.mean(
461+
torch.abs(
462+
r.softmax(dim=1).to(dtype=torch.float32)
463+
- t.softmax(dim=1).to(dtype=torch.float32)
464+
)
465+
)
446466
return (cross_entropy, diff)
447467

448468
iters = 1024 // max_new_tokens

0 commit comments

Comments (0)