162162 action = "store_true" ,
163163 help = "set to true to save cpu validation outputs for later consumption" ,
164164)
165+ parser .add_argument (
166+ "--only_save_validation_output" ,
167+ action = "store_true" ,
168+ help = "set to true to ONLY save cpu validation outputs for later consumption" ,
169+ )
165170parser .add_argument (
166171 "--prioritize_large_batch_sizes" ,
167172 action = "store_true" ,
@@ -391,14 +396,16 @@ def __load_validation_info(
391396 and dist .get_world_size () == 4
392397):
393398 extra_kwargs ["_kvcache_num_blocks_hint" ] = KVCACHE_NUM_BLOCKS_HINT
394- warmup_model (
395- model ,
396- input_ids ,
397- max_new_tokens = max_new_tokens ,
398- compile_dynamic_sendnn = True ,
399- stagger_update_lazyhandle = args .stagger_update_lazyhandle ,
400- ** extra_kwargs ,
401- )
399+
400+ if not args .only_save_validation_output :
401+ warmup_model (
402+ model ,
403+ input_ids ,
404+ max_new_tokens = max_new_tokens ,
405+ compile_dynamic_sendnn = True ,
406+ stagger_update_lazyhandle = args .stagger_update_lazyhandle ,
407+ ** extra_kwargs ,
408+ )
402409
403410if USE_DISTRIBUTED :
404411 # wait for rank0 to be finished as it is the only one generating the criteria json
@@ -659,7 +666,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
659666 ** extra_kwargs ,
660667 )
661668 # save the cpu validation info for later consumption
662- if save_validation_info_outputs :
669+ if save_validation_info_outputs or args . only_save_validation_output :
663670 cpu_validation_info .save (
664671 get_validation_info_path (
665672 args .validation_info_outputs_dir ,
@@ -674,7 +681,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
674681 )
675682 )
676683
677- if args .test_type == "metrics" :
684+ if args .test_type == "metrics" and not args . only_save_validation_output :
678685 aiu_validation_info = extract_validation_information (
679686 model ,
680687 input_ids ,
@@ -718,7 +725,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
718725 if failure_rate >= args .failure_rate_threshold :
719726 failed_cases .append ((program_id , valid_prompt , failure_rate ))
720727
721- elif args .test_type == "tokens" :
728+ elif args .test_type == "tokens" and not args . only_save_validation_output :
722729 aiu_validation_info = extract_validation_information (
723730 model ,
724731 input_ids ,
@@ -758,9 +765,11 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
758765 dprint (f"AIU tokens:\n { aiu_tokens_generated } " )
759766 dprint (f"CPU output:\n { tokenizer .decode (cpu_tokens_generated )} " )
760767 dprint (f"AIU output:\n { tokenizer .decode (aiu_tokens_generated )} " )
768+ elif args .only_save_validation_output :
769+ pass
761770 else :
762771 raise ValueError ("test type must be one of metrics or tokens" )
763- else :
772+ elif not args . only_save_validation_output :
764773 aiu_validation_info = extract_validation_information (
765774 model ,
766775 input_ids ,
@@ -784,7 +793,11 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
784793 dprint (f"AIU tokens:\n { aiu_tokens_generated } " )
785794 dprint (f"AIU output:\n { tokenizer .decode (aiu_tokens_generated )} " )
786795
787- if not args .skip_validation and local_rank == 0 :
796+ if (
797+ not args .skip_validation
798+ and local_rank == 0
799+ and not args .only_save_validation_output
800+ ):
788801 if len (failed_cases ) != 0 :
789802 dprint ("the test failed with the following cases:" )
790803 for failed_case in failed_cases :
0 commit comments