diff --git a/aiu_fms_testing_utils/scripts/drive_paged_programs.py b/aiu_fms_testing_utils/scripts/drive_paged_programs.py
index 033a8efe..b3aff457 100644
--- a/aiu_fms_testing_utils/scripts/drive_paged_programs.py
+++ b/aiu_fms_testing_utils/scripts/drive_paged_programs.py
@@ -15,13 +15,13 @@
 from torch import distributed as dist
 from torch.fx.experimental import _config as fx_config
 from transformers import AutoTokenizer
+import numpy as np
 
 from aiu_fms_testing_utils.testing.validation import (
     GoldenTokenHook,
     LogitsExtractorHook,
     capture_level_1_metrics,
     extract_validation_information,
-    filter_failed_level_1_cases,
     find_validation_info_path,
     get_validation_info_path,
     load_validation_information,
@@ -108,12 +108,26 @@
 )
 
 parser.add_argument(
-    "--cross_entropy_threshold",
+    "--default_cross_entropy_threshold",
     type=float,
     default=2.5,
     help="threshold to denote passing/failing a given iteration",
 )
+parser.add_argument(
+    "--cross_entropy_threshold_path",
+    type=str,
+    default=None,
+    help="path to a file with all expected cross-entropy loss thresholds per program, per sequence",
+)
+
+parser.add_argument(
+    "--per_sequence_failure_rate_threshold",
+    type=float,
+    default=0.1,
+    help="the threshold that denotes whether to pass or fail the test for a given sequence",
+)
+
 parser.add_argument(
     "--failure_rate_threshold",
     type=float,
@@ -172,6 +186,12 @@
     action="store_true",
     help="set to true ensure that all prompts hit the same prompt program for a given test",
 )
+parser.add_argument(
+    "--generate_metrics_path",
+    type=str,
+    default=None,
+    help="if set, will bypass AIU model processing and generate the cross-entropy loss thresholds used for testing, saving the metrics to the given path",
+)
 
 args = parser.parse_args()
 
@@ -180,9 +200,19 @@
 model_variant = args.model_variant
 DATASET_PATH = args.dataset_path
 save_validation_info_outputs = args.save_validation_info_outputs
+generate_metrics = args.generate_metrics_path is not None
 tokenizer = AutoTokenizer.from_pretrained(model_variant)
 custom_shape = None
 
+default_cross_entropy_threshold = float(args.default_cross_entropy_threshold)
+program_threshold_dict = {}
+# if the path exists, load it as a json
+if args.cross_entropy_threshold_path is not None and os.path.exists(
+    args.cross_entropy_threshold_path
+):
+    with open(args.cross_entropy_threshold_path, "r") as f:
+        program_threshold_dict = json.load(f)
+
 if args.dataset_type == "custom":
     if local_rank == 0:
         dprint(
@@ -332,13 +362,17 @@ def __load_validation_info(
 
 distributed_kwargs = {}
 if USE_DISTRIBUTED:
+    if generate_metrics:
+        torch.cuda.set_device(local_rank)
     if args.dist_timeout > 0:
         # Default timeout:
         # https://docs.pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
-        dist.init_process_group(timeout=datetime.timedelta(minutes=args.dist_timeout))
+        dist.init_process_group(
+            timeout=datetime.timedelta(minutes=args.dist_timeout), backend="gloo"
+        )
         dprint(f"NOTICE: init_process_group timeout set to {args.dist_timeout} minutes")
     else:
-        dist.init_process_group()
+        dist.init_process_group(backend="gloo")
     aiu_dist_setup(dist.get_rank(), dist.get_world_size())
     distributed_kwargs["distributed_strategy"] = "tp"
     distributed_kwargs["group"] = dist.group.WORLD
@@ -349,7 +383,7 @@ def __load_validation_info(
 with stagger_region(args.stagger_load):
     model = get_model(
         architecture="hf_pretrained",
-        device_type="cpu",
+        device_type="cuda" if generate_metrics else "cpu",
        data_type=None if is_fp8 else torch.float16,
         fused_weights=False,
         **model_path_kwargs,
@@ -358,7 +392,8 @@ def __load_validation_info(
 
 model.eval()
 fx_config.backed_size_oblivious = True
-model.compile(backend="sendnn", options={"sendnn.dynamic": True})
+if not generate_metrics:
+    model.compile(backend="sendnn", options={"sendnn.dynamic": True})
 
 __maybe_prepare_fp8_weights(model, is_fp8)
 
@@ -391,14 +426,16 @@ def __load_validation_info(
     and dist.get_world_size() == 4
 ):
     extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
-warmup_model(
-    model,
-    input_ids,
-    max_new_tokens=max_new_tokens,
-    compile_dynamic_sendnn=True,
-    stagger_update_lazyhandle=args.stagger_update_lazyhandle,
-    **extra_kwargs,
-)
+
+if not generate_metrics:
+    warmup_model(
+        model,
+        input_ids,
+        max_new_tokens=max_new_tokens,
+        compile_dynamic_sendnn=True,
+        stagger_update_lazyhandle=args.stagger_update_lazyhandle,
+        **extra_kwargs,
+    )
 
 if USE_DISTRIBUTED:
     # wait for rank0 to be finished as it is the only one generating the criteria json
@@ -608,18 +645,19 @@ def parse_program_limit(limit_str: str) -> tuple[int, str]:
 
 # metric calculator based on the cross-entropy and mean diff for each decode step
 def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
     cross_entropy = torch.nn.CrossEntropyLoss()(
-        r, t.softmax(dim=1).to(dtype=torch.float32)
+        r, t.softmax(dim=1).to(device="cpu", dtype=torch.float32)
     )
     diff = torch.mean(
         torch.abs(
             r.softmax(dim=1).to(dtype=torch.float32)
-            - t.softmax(dim=1).to(dtype=torch.float32)
+            - t.softmax(dim=1).to(device="cpu", dtype=torch.float32)
         )
     )
     return (cross_entropy, diff)
 
-failed_cases = []
+per_sentence_failed_cases = []
+aggregate_failed_cases = []
 # for each program and valid prompt (batch size, sequence length)
 for program_id, valid_prompt, input_ids, extra_kwargs, sample_key in valid_prompts:
     extra_kwargs["attn_name"] = ATTN_NAME
@@ -675,6 +713,14 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
         )
 
     if args.test_type == "metrics":
+        # if we are generating metrics, all inputs should be on cuda device
+        if generate_metrics:
+            input_ids = input_ids.to("cuda")
+            extra_kwargs = {
+                k: v.to("cuda") if isinstance(v, torch.Tensor) else v
+                for k, v in extra_kwargs.items()
+            }
+
         aiu_validation_info = extract_validation_information(
             model,
             input_ids,
@@ -711,12 +757,70 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
                     f'For Program {program_id} in sentence {sentence_idx + 1}: the metric for token {token_idx} is {metrics_value}, AIU ID="{aiu_token.item()}" | STR="{aiu_str}" -- CPU ID="{cpu_token.item()}" | CPU STR="{cpu_str}"'
                 )
 
-        ce_fail_responses = filter_failed_level_1_cases(
-            level_1_metrics, lambda m: m[0] >= args.cross_entropy_threshold
-        )
-        failure_rate = len(ce_fail_responses) / len(level_1_metrics)
-        if failure_rate >= args.failure_rate_threshold:
-            failed_cases.append((program_id, valid_prompt, failure_rate))
+        # if generating metrics, get the 99th percentile ce threshold per sentence
+        # otherwise test the thresholds
+        if generate_metrics:
+            sentence_ce_dict = {}
+            for sentence_idx, token_idx, metrics_value in level_1_metrics:
+                sentence_ce_dict.setdefault(sentence_idx, [])
+                sentence_ce_dict[sentence_idx].append(metrics_value[0])
+
+            sentence_ce_threshold = {
+                k: np.percentile(v, 99) for k, v in sentence_ce_dict.items()
+            }
+            if local_rank == 0:
+                dprint(
+                    f"Program {str(program_id.program_id)} produced the following thresholds:\n{sentence_ce_threshold}"
+                )
+            program_threshold_dict[(str(program_id.program_id), sample_key)] = (
+                sentence_ce_threshold
+            )
+        else:
+            sentence_failures_dict = {}
+            for sentence_idx, token_idx, metrics_value in level_1_metrics:
+                program_threshold_key = f"{str(program_id.program_id)},{sample_key}"
+                if (
+                    len(program_threshold_dict) != 0
+                    and program_threshold_key not in program_threshold_dict
+                    and local_rank == 0
+                ):
+                    dprint(
+                        f"could not find the following key {program_threshold_key}, defaulting to {default_cross_entropy_threshold}"
+                    )
+                ce_threshold = program_threshold_dict.get(
+                    program_threshold_key,
+                    {str(sentence_idx): default_cross_entropy_threshold},
+                )[str(sentence_idx)]
+                sentence_failures_dict.setdefault(sentence_idx, 0)
+                if metrics_value[0].item() >= ce_threshold:
+                    sentence_failures_dict[sentence_idx] += 1
+
+            for sentence_idx, failure_count in sentence_failures_dict.items():
+                per_sentence_failure_rate = failure_count / max_new_tokens
+                if (
+                    per_sentence_failure_rate
+                    >= args.per_sequence_failure_rate_threshold
+                ):
+                    per_sentence_failed_cases.append(
+                        (
+                            program_id,
+                            valid_prompt,
+                            sentence_idx,
+                            per_sentence_failure_rate,
+                        )
+                    )
+
+            aggregate_failure_rate = sum(sentence_failures_dict.values()) / (
+                max_new_tokens * len(sentence_failures_dict)
+            )
+            if aggregate_failure_rate >= args.failure_rate_threshold:
+                aggregate_failed_cases.append(
+                    (
+                        program_id,
+                        valid_prompt,
+                        aggregate_failure_rate,
+                    )
+                )
 
     elif args.test_type == "tokens":
         aiu_validation_info = extract_validation_information(
@@ -784,12 +888,30 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
             dprint(f"AIU tokens:\n{aiu_tokens_generated}")
             dprint(f"AIU output:\n{tokenizer.decode(aiu_tokens_generated)}")
 
+if generate_metrics and local_rank == 0:
+    with open(args.generate_metrics_path, "w") as f:
+        json_dict = {}
+        for program_seq, sentence_ce_threshold_dict in program_threshold_dict.items():
+            program_seq_key = ",".join(program_seq)
+            json_dict[program_seq_key] = {}
+            for sentence_i, ce_threshold in sentence_ce_threshold_dict.items():
+                json_dict[program_seq_key][sentence_i] = float(ce_threshold)
+
+        json.dump(json_dict, f, indent=4)
+
 if not args.skip_validation and local_rank == 0:
-    if len(failed_cases) != 0:
-        dprint("the test failed with the following cases:")
-        for failed_case in failed_cases:
+    if len(aggregate_failed_cases) != 0:
+        dprint("the test failed with the following aggregate cases:")
+        for failed_case in aggregate_failed_cases:
             dprint(
                 f"Program ID: {failed_case[0]}, Prompt Shape: {failed_case[1]}, Failure Rate: {failed_case[2]}"
             )
-    else:
+    if len(per_sentence_failed_cases) != 0:
+        dprint("the test failed with the following per sentence cases:")
+        for failed_case in per_sentence_failed_cases:
+            dprint(
+                f"Program ID: {failed_case[0]}, Prompt Shape: {failed_case[1]}, Sentence Index: {failed_case[2]}, Failure Rate: {failed_case[3]}"
+            )
+
+    if len(aggregate_failed_cases) == 0 and len(per_sentence_failed_cases) == 0:
         dprint("all tests passed")
diff --git a/aiu_fms_testing_utils/testing/validation.py b/aiu_fms_testing_utils/testing/validation.py
index 5bf120a0..9cc9584c 100644
--- a/aiu_fms_testing_utils/testing/validation.py
+++ b/aiu_fms_testing_utils/testing/validation.py
@@ -218,6 +218,7 @@ def load_validation_information(
             f"Not enough validation files at {validation_files_path} for a batch size of {batch_size}"
         )
 
+    validation_files_paths.sort(key=lambda p: int(p.name.split(".pt")[0]))
     validation_info = []
     for i, validation_file_path in enumerate(validation_files_paths):
         if i == batch_size:
diff --git a/aiu_fms_testing_utils/utils/paged.py b/aiu_fms_testing_utils/utils/paged.py
index edf0c548..ea6453d5 100644
--- a/aiu_fms_testing_utils/utils/paged.py
+++ b/aiu_fms_testing_utils/utils/paged.py
@@ -189,10 +189,10 @@ def generate(
         (
             torch.zeros(
                 NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=model_dtype
-            ),
+            ).to(input_ids.device),
             torch.zeros(
                 NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=model_dtype
-            ),
+            ).to(input_ids.device),
         )
         for _ in range(model.config.nlayers)
     ]
@@ -304,9 +304,9 @@ def generate(
             last_n_tokens = kwargs.get("last_n_tokens", 0)
             output, current_kv_cache = model(
                 input_ids_i,
-                slot_mapping=slot_mapping_i,
-                position_ids=position_ids_i,
-                mask=mask_i,
+                slot_mapping=slot_mapping_i.to(input_ids.device),
+                position_ids=position_ids_i.to(input_ids.device),
+                mask=mask_i.to(input_ids.device),
                 past_key_value_states=current_kv_cache,
                 use_cache=kwargs["use_cache"],
                 last_n_tokens=last_n_tokens,
@@ -342,8 +342,10 @@ def generate(
         # mask is no longer used here
        kwargs["mask"] = None
         kwargs["position_ids"] = kwargs["position_ids"][:, -1:] + 1
-        kwargs["position_ids"] = kwargs["position_ids"].clone(
-            memory_format=torch.contiguous_format
+        kwargs["position_ids"] = (
+            kwargs["position_ids"]
+            .clone(memory_format=torch.contiguous_format)
+            .to(device=input_ids.device)
         )
         kwargs["last_n_tokens"] = 1
 
@@ -371,14 +373,20 @@ def generate(
                 for b_seq in block_table
             ],
             dtype=torch.int64,
+        ).to(device=input_ids.device)
+        kwargs["left_padded_prompt_mask"] = left_padded_prompt_mask.to(
+            device=input_ids.device
         )
-        kwargs["left_padded_prompt_mask"] = left_padded_prompt_mask
         current_tkv_mask = current_tkv_mask + 1
-        kwargs["current_tkv_mask"] = current_tkv_mask
-        kwargs["slot_mapping"] = torch.tensor(slot_mapping, dtype=torch.int64)
+        kwargs["current_tkv_mask"] = current_tkv_mask.to(device=input_ids.device)
+        kwargs["slot_mapping"] = torch.tensor(slot_mapping, dtype=torch.int64).to(
+            device=input_ids.device
+        )
 
         # batch
-        input_ids = input_ids.clone(memory_format=torch.contiguous_format)
+        input_ids = input_ids.clone(memory_format=torch.contiguous_format).to(
+            device=input_ids.device
+        )
         torch._dynamo.mark_dynamic(input_ids, 0)
         torch._dynamo.mark_dynamic(kwargs["block_table"], 0)
         torch._dynamo.mark_dynamic(kwargs["slot_mapping"], 0)
diff --git a/tests/testing/test_validation.py b/tests/testing/test_validation.py
index 95f2ff4e..dc832055 100644
--- a/tests/testing/test_validation.py
+++ b/tests/testing/test_validation.py
@@ -37,7 +37,7 @@ def test_validation_info_round_trip(validation_type, post_iteration_hook):
     model.reset_parameters()
 
     seq_length = 64
-    batch_size = 8
+    batch_size = 16
     max_new_tokens = 128
 
     # prepare input_ids
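
Note (not part of the diff): a minimal sketch of the per-program, per-sequence threshold file that --generate_metrics_path writes and --cross_entropy_threshold_path reads, inferred from the json.dump logic above. The program ids, sample keys, file name, and threshold values shown are hypothetical placeholders, not values produced by the script.

    # Illustrative only: top-level keys are "<program_id>,<sample_key>"; each value
    # maps a sentence index (serialized as a string) to that sentence's
    # 99th-percentile cross-entropy threshold.
    import json

    example_thresholds = {
        "0,sample_key_a": {"0": 2.31, "1": 2.27},
        "1,sample_key_b": {"0": 2.45},
    }

    with open("ce_thresholds.json", "w") as f:
        json.dump(example_thresholds, f, indent=4)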