def get_program_prompt_list():
    """Lazily yield the prepared prompts that fit the batch size criteria.

    Inputs are prepared one prompt at a time, as the caller iterates, instead
    of all being materialized up front in a list.

    Reads the module-level ``custom_shape``, ``program_map``, ``programs``,
    ``program_criteria_list``, ``args``, ``tokenizer``, ``pad_multiple`` and
    ``local_rank``.

    Two modes:

    * ``custom_shape`` is truthy: yield a single prompt for the first shape in
      ``program_map`` equal to ``custom_shape``, then stop.
    * otherwise: for each entry of ``programs``, walk the matching
      ``program_map`` entries and yield the first shape per program key that
      satisfies the batch-size and prompt-length limits and for which
      ``__prepare_inputs`` does not raise ``ValueError``.

    Yields:
        Tuples ``(program, shape, input_ids, extra_kwargs, sample_key)`` where
        ``shape`` is indexed as ``(batch size, sequence length)``.

    NOTE(review): the first tuple element is ``program_criteria_seq[0].program_id``
    in the custom-shape mode but ``program_seq_key[0]`` (no ``.program_id``) in
    the other mode -- confirm the consumer expects both forms.
    """
    if custom_shape:
        for program_criteria_seq, valid_prompt_shapes in program_map.items():
            for valid_prompt_shape in valid_prompt_shapes:
                if valid_prompt_shape != custom_shape:
                    continue
                input_ids, extra_kwargs, sample_key = __prepare_inputs(
                    valid_prompt_shape[0],
                    valid_prompt_shape[1],
                    tokenizer,
                    enforce_sizes=[valid_prompt_shape[1]],
                    pad_multiple=pad_multiple,
                )
                yield (
                    program_criteria_seq[0].program_id,
                    custom_shape,
                    input_ids,
                    extra_kwargs,
                    sample_key,
                )
                # only the first matching shape is ever used
                return
        return

    for program_info in programs:
        program_id = program_info.program_id
        batch_size_limit = program_info.batch_size_limit
        batch_size_limit_type = program_info.batch_size_limit_type
        prompt_length_limit = program_info.prompt_length_limit
        prompt_length_limit_type = program_info.prompt_length_limit_type

        is_numeric_id = program_id.isnumeric()
        # "?" and numeric ids want a single program; "*" wants one prompt per program
        stop_after_first = program_id == "?" or is_numeric_id

        filtered_program_map = program_map
        if is_numeric_id:
            filtered_program_map = {
                k: v
                for k, v in program_map.items()
                if k[0] == program_criteria_list[int(program_id)]
            }

        used_keys = set()
        # for each program, we need to check if we have a shape that satisfies the --programs request
        for program_seq_key, valid_prompt_shapes in filtered_program_map.items():
            # if ? or numeric => stop as soon as one valid key has been found
            if stop_after_first and used_keys:
                break
            # if * => skip keys whose leading program was already served
            if program_id == "*" and program_seq_key[0] in used_keys:
                continue

            for valid_prompt_shape in valid_prompt_shapes:
                # make sure the criteria for batch limit and prompt limit is satisfied
                # NOTE(review): eval() is kept as-is; the original author states the
                # limit type/value are restricted before reaching here (parser not
                # visible in this view). An operator lookup table would remove the
                # need for eval -- confirm the allowed operators first.
                batch_check = eval(
                    f"valid_prompt_shape[0] {batch_size_limit_type} {batch_size_limit}"
                )
                prompt_check = eval(
                    f"valid_prompt_shape[1] {prompt_length_limit_type} {prompt_length_limit}"
                )
                if not (batch_check and prompt_check):
                    continue

                # when we enforce homogeneous prompt programs, we will cycle through all sizes between the min of a program and the valid prompt sequence length
                # if there does not exist enough sequence sizes between this range, we will cycle back to the beginning
                # in the event we don't have enough sequences that satisfy the enforce_sizes, we will repeat sequences and warn the user
                enforce_sizes = [valid_prompt_shape[1]]
                if (
                    args.enforce_homogeneous_prompt_programs
                    or args.prefill_chunk_size > 0
                ):
                    # if enforcing homogeneous prompt programs, this will get the number of bits for the sequence length and shift to get the power of 2 that is less than or equal to the sequence length
                    tkv_cutoff = (
                        1 << (valid_prompt_shape[1].bit_length() - 1)
                        if args.enforce_homogeneous_prompt_programs
                        else pad_multiple
                    )
                    # favor sequences that are close to the valid prompt length
                    possible_seq_lengths = list(
                        reversed(range(tkv_cutoff, valid_prompt_shape[1], pad_multiple))
                    )
                    # add the valid prompt size to the end since it will already exist in the above enforce_sizes
                    possible_seq_lengths.append(valid_prompt_shape[1])
                    enforce_sizes = enforce_sizes + list(
                        itertools.islice(
                            itertools.cycle(possible_seq_lengths),
                            valid_prompt_shape[0] - 1,
                        )
                    )

                # keep the try body to the one call whose ValueError we handle, so
                # the yield below is not covered by this handler
                try:
                    input_ids, extra_kwargs, sample_key = __prepare_inputs(
                        valid_prompt_shape[0],
                        valid_prompt_shape[1],
                        tokenizer,
                        enforce_sizes=enforce_sizes,
                        pad_multiple=64,  # this should be the smallest granularity to ensure we get the largest enforce_size (if we choose chunked prefill, we want to make sure we pad to the full enforced size)
                    )
                except ValueError:
                    dprint(
                        f"No valid sample exists in dataset for this input shape {valid_prompt_shape}"
                    )
                    continue

                used_keys.add(program_seq_key[0])
                yield (
                    program_seq_key[0],
                    valid_prompt_shape,
                    input_ids,
                    extra_kwargs,
                    sample_key,
                )
                # one prompt per program key is enough
                break

        if not used_keys and local_rank == 0:
            dprint(
                f"no valid prompt shape was found which would result in program {program_id} that satisfied batch{batch_size_limit_type}{batch_size_limit} and prompt_length{prompt_length_limit_type}{prompt_length_limit}"
            )