 if USE_MICRO_MODELS:
     VALIDATION_INFO_DIR = os.path.join(VALIDATION_INFO_DIR, "tiny_models")
 
-# pass custom model path list for eg: EXPORT FMS_TESTING_COMMON_MODEL_PATHS="/tmp/models/granite-3-8b-base,/tmp/models/granite-7b-base"
+# pass a custom model path list, e.g.: export FMS_TEST_SHAPES_COMMON_MODEL_PATHS="/tmp/models/granite-3-8b-base,/tmp/models/granite-7b-base"
 if isinstance(COMMON_MODEL_PATHS, str):
     COMMON_MODEL_PATHS = COMMON_MODEL_PATHS.split(",")
 
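Note: this override assumes COMMON_MODEL_PATHS was first read from the environment earlier in the file. A minimal sketch of that pattern, assuming an os.environ.get read and an invented default path (neither is shown in this diff):

    import os

    # read the comma-separated override if set; otherwise keep a default list
    COMMON_MODEL_PATHS = os.environ.get(
        "FMS_TEST_SHAPES_COMMON_MODEL_PATHS", "/tmp/models/granite-3-8b-base"
    )
    if isinstance(COMMON_MODEL_PATHS, str):
        COMMON_MODEL_PATHS = COMMON_MODEL_PATHS.split(",")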
         ]
     )
     os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(max(max(COMMON_BATCH_SIZES), 2))
-
+    fx_config.backed_size_oblivious = True
 
 # thresholds are chosen based on 1024 tokens per sequence
 # 1% error threshold rate between cpu fp32 and cuda fp16
     )
     USE_MICRO_MODELS = False
     COMMON_MODEL_PATHS = []
-    frequency = int(MODEL_CONFIGURATION_FREQUENCY)
+    FREQUENCY = int(MODEL_CONFIGURATION_FREQUENCY)
     with open(MODEL_CONFIGURATION_PATH, "r") as f:
         for line in f:
             try:
-                model_config = json.loads(line)
-                if model_config["frequency"] <= frequency:
-                    COMMON_MODEL_PATHS.append(model_config["model_id"])
+                MODEL_CONFIG = json.loads(line)
+                if MODEL_CONFIG["frequency"] <= FREQUENCY:
+                    COMMON_MODEL_PATHS.append(MODEL_CONFIG["model_id"])
                     # assume fullsize models
-                    FAIL_THRESHOLDS[(model_config["model_id"], USE_MICRO_MODELS)] = (
-                        model_config["ce"],
-                        model_config["mean_diff"],
+                    FAIL_THRESHOLDS[(MODEL_CONFIG["model_id"], USE_MICRO_MODELS)] = (
+                        MODEL_CONFIG["ce"],
+                        MODEL_CONFIG["mean_diff"],
                     )
             except json.JSONDecodeError:
                 print(f"config contained an improper json line: {line.strip()}")
 
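Note: the loop above treats MODEL_CONFIGURATION_PATH as a JSONL file with one model configuration per line. A hypothetical entry, using only the keys the code reads (all values invented for illustration):

    {"model_id": "/tmp/models/granite-3-8b-base", "frequency": 0, "ce": 2.5, "mean_diff": 0.001}

Entries whose "frequency" is at or below MODEL_CONFIGURATION_FREQUENCY join the run, their "ce" and "mean_diff" values become that model's failure thresholds, and malformed lines are skipped with a warning.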
-common_shapes = list(
+COMMON_SHAPES = list(
     itertools.product(
         COMMON_MODEL_PATHS,
         COMMON_BATCH_SIZES,
@@ -308,7 +308,7 @@ def __maybe_get_gptq_kwargs(model_path):
     return gptq_kwargs_aiu, gptq_kwargs_cpu
 
 
-def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
+def __prepare_inputs(batch_size, seq_length, tokenizer, model_path, seed=0):
     if "paged" in ATTN_NAME:
         prompts_and_sizes = sample_sharegpt_requests(
             SHARE_GPT_DATASET_PATH,
@@ -480,17 +480,21 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
 
 
 def _check_failure_thresholds(
-    diff_fail_responses_list, ce_fail_responses_list, total_tokens
+    diff_fail_responses_list,
+    ce_fail_responses_list,
+    total_tokens,
+    record_property=None,
 ):
     # test the failure rates across all tokens
     diff_failure_rate = len(diff_fail_responses_list) / total_tokens
     ce_failure_rate = len(ce_fail_responses_list) / total_tokens
     dprint(f"mean diff failure rate: {diff_failure_rate}")
     dprint(f"cross entropy loss failure rate: {ce_failure_rate}")
 
-    # Add failure rates to xml report
-    record_property("mean_diff_failure_rate", diff_failure_rate)
-    record_property("cross_entropy_loss_failure_rate", ce_failure_rate)
+    if record_property is not None:
+        # Add failure rates to xml report
+        record_property("mean_diff_failure_rate", diff_failure_rate)
+        record_property("cross_entropy_loss_failure_rate", ce_failure_rate)
 
     if "mean_diff" not in SKIP_ASSERTIONS:
         assert diff_failure_rate < FAILURE_RATE_THRESHOLD, (
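Note: record_property is pytest's built-in fixture for attaching key/value pairs to the JUnit XML report; defaulting it to None lets this helper also run outside a test context. A minimal sketch of a caller threading it through (the test itself is invented):

    def test_example(record_property):
        # pytest injects the fixture; passing it down records the failure
        # rates as <property> entries in the XML report
        _check_failure_thresholds(
            [], [], total_tokens=1, record_property=record_property
        )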
@@ -559,7 +563,6 @@ def _get_cpu_model(is_gptq, is_fp8, micro_model_state_dict=None, **kwargs):
     return validation_model
 
 
-
 def _get_device_validation_information(
     model_path,
     batch_size,
@@ -572,7 +575,6 @@ def _get_device_validation_information(
     token_iter,
     device="aiu",
     tokenizer=None,
-    only_last_token=None,
 ):
     # For CPU, we try to load it from disk first if it exists
     if device == "cpu":
@@ -583,7 +585,7 @@ def _get_device_validation_information(
             max_new_tokens,
             tokenizer,
             token_iter,
-            ATTN_NAME,  # TODO checkme
+            ATTN_NAME,
         )
 
         if cpu_validation_info is not None:
@@ -594,8 +596,7 @@ def _get_device_validation_information(
     if device == "cpu":
         device_dependent_kwargs["attn_algorithm"] = "math"
 
-    if device == "aiu" and only_last_token is not None:
-        device_dependent_kwargs["only_last_token"] = only_last_token
+    if device == "aiu":
         device_dependent_kwargs["last_n_tokens"] = 64 if "paged" in ATTN_NAME else 1
 
     # Otherwise we need to get the AIU / CPU validation info
@@ -617,7 +618,7 @@ def _get_device_validation_information(
 
     validation_info.save(
         get_validation_info_path(
-            validation_info_dir,
+            VALIDATION_INFO_DIR,
             model_path,
             batch_size,
             seq_length,
@@ -697,7 +698,6 @@ def _run_validation_level_0(
         token_iter=0,
         device="aiu",
         tokenizer=tokenizer,
-        only_last_token="paged" not in ATTN_NAME,
     )
     dprint("aiu validation info extracted for validation level 0")
 
@@ -727,6 +727,7 @@ def _run_validation_level_1(
     model,
     micro_model_path,
     validation_zero_info,
+    record_property,
 ):
     iters = int(CUMULATIVE_TEST_TOKENS_PER_SEQUENCE) // max_new_tokens
     ce_fail_responses_list = []
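Note: the iteration count above is plain integer division of the cumulative token budget by the tokens generated per pass. With invented values:

    # CUMULATIVE_TEST_TOKENS_PER_SEQUENCE arrives as a string (hence int());
    # e.g. a budget of "1024" with max_new_tokens = 128 gives
    iters = int("1024") // 128  # -> 8 level-1 generation passes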
@@ -775,7 +776,6 @@ def _run_validation_level_1(
             token_iter=i,
             device="aiu",
             tokenizer=tokenizer,
-            only_last_token=ATTN_TYPE != "paged",
         )
         dprint(f"aiu validation info extracted for validation level 1 - iter={i}")
 
@@ -804,7 +804,10 @@ def _run_validation_level_1(
         total_tokens += len(level_1_metrics)
 
     _check_failure_thresholds(
-        diff_fail_responses_list, ce_fail_responses_list, total_tokens
+        diff_fail_responses_list,
+        ce_fail_responses_list,
+        total_tokens,
+        record_property,
     )
 
 
@@ -817,12 +820,15 @@ def _run_cpu_aiu_validation_test(
     cpu_model,
     aiu_model,
     micro_model_path,
+    record_property,
 ):
     # Get the tokenizer and AIU / CPU models to compare
     tokenizer = AutoTokenizer.from_pretrained(model_path)
 
     # prepare input_ids
-    input_ids, extra_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer)
+    input_ids, extra_kwargs = __prepare_inputs(
+        batch_size, seq_length, tokenizer, model_path
+    )
 
     extra_kwargs["attn_name"] = ATTN_NAME
     if (
@@ -869,11 +875,12 @@ def _run_cpu_aiu_validation_test(
         aiu_model,
         micro_model_path,
         validation_zero_info,
+        record_property,
     )
 
 
 @pytest.mark.parametrize(
-    "model_path,batch_size,seq_length,max_new_tokens", common_shapes
+    "model_path,batch_size,seq_length,max_new_tokens", COMMON_SHAPES
 )
 def test_common_shapes(
     model_path,
@@ -894,7 +901,6 @@ def test_common_shapes(
 
     # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
     gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
-
     is_gptq = len(gptq_kwargs_aiu) != 0
     is_fp8 = "fp8" in ATTN_NAME
     model_kwargs = _get_common_model_kwargs(is_gptq, model_path)
@@ -920,4 +926,5 @@ def test_common_shapes(
         validation_model,
         model,
         micro_model_path,
+        record_property,
     )
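Note: each element of COMMON_SHAPES is one (model_path, batch_size, seq_length, max_new_tokens) tuple from the itertools.product call earlier, so the parametrize decorator above expands into one test case per combination. A sketch with invented values:

    import itertools

    # two models x two batch sizes x one seq length x one decode budget
    # -> 4 generated test cases
    shapes = list(
        itertools.product(["/tmp/models/a", "/tmp/models/b"], [1, 2], [64], [8])
    )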