Skip to content

Commit fb4d8f3

Browse files
Rebase fixes, linting
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
1 parent 373939d commit fb4d8f3

File tree

1 file changed

+34
-27
lines changed

1 file changed

+34
-27
lines changed

tests/models/test_decoders.py

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@
132132
if USE_MICRO_MODELS:
133133
VALIDATION_INFO_DIR = os.path.join(VALIDATION_INFO_DIR, "tiny_models")
134134

135-
# pass custom model path list for eg: EXPORT FMS_TESTING_COMMON_MODEL_PATHS="/tmp/models/granite-3-8b-base,/tmp/models/granite-7b-base"
135+
# pass custom model path list for eg: EXPORT FMS_TEST_SHAPES_COMMON_MODEL_PATHS="/tmp/models/granite-3-8b-base,/tmp/models/granite-7b-base"
136136
if isinstance(COMMON_MODEL_PATHS, str):
137137
COMMON_MODEL_PATHS = COMMON_MODEL_PATHS.split(",")
138138

@@ -185,7 +185,7 @@
185185
]
186186
)
187187
os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(max(max(COMMON_BATCH_SIZES), 2))
188-
188+
fx_config.backed_size_oblivious = True
189189

190190
# thresholds are chosen based on 1024 tokens per sequence
191191
# 1% error threshold rate between cpu fp32 and cuda fp16
@@ -220,22 +220,22 @@
220220
)
221221
USE_MICRO_MODELS = False
222222
COMMON_MODEL_PATHS = []
223-
frequency = int(MODEL_CONFIGURATION_FREQUENCY)
223+
FREQUENCY = int(MODEL_CONFIGURATION_FREQUENCY)
224224
with open(MODEL_CONFIGURATION_PATH, "r") as f:
225225
for line in f:
226226
try:
227-
model_config = json.loads(line)
228-
if model_config["frequency"] <= frequency:
229-
COMMON_MODEL_PATHS.append(model_config["model_id"])
227+
MODEL_CONFIG = json.loads(line)
228+
if MODEL_CONFIG["frequency"] <= FREQUENCY:
229+
COMMON_MODEL_PATHS.append(MODEL_CONFIG["model_id"])
230230
# assume fullsize models
231-
FAIL_THRESHOLDS[(model_config["model_id"], USE_MICRO_MODELS)] = (
232-
model_config["ce"],
233-
model_config["mean_diff"],
231+
FAIL_THRESHOLDS[(MODEL_CONFIG["model_id"], USE_MICRO_MODELS)] = (
232+
MODEL_CONFIG["ce"],
233+
MODEL_CONFIG["mean_diff"],
234234
)
235235
except json.JSONDecodeError:
236236
print(f"config contained an improper json line: {line.strip()}")
237237

238-
common_shapes = list(
238+
COMMON_SHAPES = list(
239239
itertools.product(
240240
COMMON_MODEL_PATHS,
241241
COMMON_BATCH_SIZES,
@@ -308,7 +308,7 @@ def __maybe_get_gptq_kwargs(model_path):
308308
return gptq_kwargs_aiu, gptq_kwargs_cpu
309309

310310

311-
def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
311+
def __prepare_inputs(batch_size, seq_length, tokenizer, model_path, seed=0):
312312
if "paged" in ATTN_NAME:
313313
prompts_and_sizes = sample_sharegpt_requests(
314314
SHARE_GPT_DATASET_PATH,
@@ -480,17 +480,21 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
480480

481481

482482
def _check_failure_thresholds(
483-
diff_fail_responses_list, ce_fail_responses_list, total_tokens
483+
diff_fail_responses_list,
484+
ce_fail_responses_list,
485+
total_tokens,
486+
record_property=None,
484487
):
485488
# test the failure rates for across all tokens
486489
diff_failure_rate = len(diff_fail_responses_list) / total_tokens
487490
ce_failure_rate = len(ce_fail_responses_list) / total_tokens
488491
dprint(f"mean diff failure rate: {diff_failure_rate}")
489492
dprint(f"cross entropy loss failure rate: {ce_failure_rate}")
490493

491-
# Add failure rates to xml report
492-
record_property("mean_diff_failure_rate", diff_failure_rate)
493-
record_property("cross_entropy_loss_failure_rate", ce_failure_rate)
494+
if record_property is not None:
495+
# Add failure rates to xml report
496+
record_property("mean_diff_failure_rate", diff_failure_rate)
497+
record_property("cross_entropy_loss_failure_rate", ce_failure_rate)
494498

495499
if "mean_diff" not in SKIP_ASSERTIONS:
496500
assert diff_failure_rate < FAILURE_RATE_THRESHOLD, (
@@ -559,7 +563,6 @@ def _get_cpu_model(is_gptq, is_fp8, micro_model_state_dict=None, **kwargs):
559563
return validation_model
560564

561565

562-
563566
def _get_device_validation_information(
564567
model_path,
565568
batch_size,
@@ -572,7 +575,6 @@ def _get_device_validation_information(
572575
token_iter,
573576
device="aiu",
574577
tokenizer=None,
575-
only_last_token=None,
576578
):
577579
# For CPU, we try to load it from disk first if it exists
578580
if device == "cpu":
@@ -583,7 +585,7 @@ def _get_device_validation_information(
583585
max_new_tokens,
584586
tokenizer,
585587
token_iter,
586-
ATTN_NAME, # TODO checkme
588+
ATTN_NAME,
587589
)
588590

589591
if cpu_validation_info is not None:
@@ -594,8 +596,7 @@ def _get_device_validation_information(
594596
if device == "cpu":
595597
device_dependent_kwargs["attn_algorithm"] = "math"
596598

597-
if device == "aiu" and only_last_token is not None:
598-
device_dependent_kwargs["only_last_token"] = only_last_token
599+
if device == "aiu":
599600
device_dependent_kwargs["last_n_tokens"] = 64 if "paged" in ATTN_NAME else 1
600601

601602
# Otherwise we need to get the AIU / CPU validation info
@@ -617,7 +618,7 @@ def _get_device_validation_information(
617618

618619
validation_info.save(
619620
get_validation_info_path(
620-
validation_info_dir,
621+
VALIDATION_INFO_DIR,
621622
model_path,
622623
batch_size,
623624
seq_length,
@@ -697,7 +698,6 @@ def _run_validation_level_0(
697698
token_iter=0,
698699
device="aiu",
699700
tokenizer=tokenizer,
700-
only_last_token="paged" not in ATTN_NAME,
701701
)
702702
dprint("aiu validation info extracted for validation level 0")
703703

@@ -727,6 +727,7 @@ def _run_validation_level_1(
727727
model,
728728
micro_model_path,
729729
validation_zero_info,
730+
record_property,
730731
):
731732
iters = int(CUMULATIVE_TEST_TOKENS_PER_SEQUENCE) // max_new_tokens
732733
ce_fail_responses_list = []
@@ -775,7 +776,6 @@ def _run_validation_level_1(
775776
token_iter=i,
776777
device="aiu",
777778
tokenizer=tokenizer,
778-
only_last_token=ATTN_TYPE != "paged",
779779
)
780780
dprint(f"aiu validation info extracted for validation level 1 - iter={i}")
781781

@@ -804,7 +804,10 @@ def _run_validation_level_1(
804804
total_tokens += len(level_1_metrics)
805805

806806
_check_failure_thresholds(
807-
diff_fail_responses_list, ce_fail_responses_list, total_tokens
807+
diff_fail_responses_list,
808+
ce_fail_responses_list,
809+
total_tokens,
810+
record_property,
808811
)
809812

810813

@@ -817,12 +820,15 @@ def _run_cpu_aiu_validation_test(
817820
cpu_model,
818821
aiu_model,
819822
micro_model_path,
823+
record_property,
820824
):
821825
# Get the tokenizer and AIU / CPU models to compare
822826
tokenizer = AutoTokenizer.from_pretrained(model_path)
823827

824828
# prepare input_ids
825-
input_ids, extra_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer)
829+
input_ids, extra_kwargs = __prepare_inputs(
830+
batch_size, seq_length, tokenizer, model_path
831+
)
826832

827833
extra_kwargs["attn_name"] = ATTN_NAME
828834
if (
@@ -869,11 +875,12 @@ def _run_cpu_aiu_validation_test(
869875
aiu_model,
870876
micro_model_path,
871877
validation_zero_info,
878+
record_property,
872879
)
873880

874881

875882
@pytest.mark.parametrize(
876-
"model_path,batch_size,seq_length,max_new_tokens", common_shapes
883+
"model_path,batch_size,seq_length,max_new_tokens", COMMON_SHAPES
877884
)
878885
def test_common_shapes(
879886
model_path,
@@ -894,7 +901,6 @@ def test_common_shapes(
894901

895902
# we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
896903
gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
897-
898904
is_gptq = len(gptq_kwargs_aiu) != 0
899905
is_fp8 = "fp8" in ATTN_NAME
900906
model_kwargs = _get_common_model_kwargs(is_gptq, model_path)
@@ -920,4 +926,5 @@ def test_common_shapes(
920926
validation_model,
921927
model,
922928
micro_model_path,
929+
record_property,
923930
)

0 commit comments

Comments
 (0)