39 changes: 26 additions & 13 deletions aiu_fms_testing_utils/scripts/drive_paged_programs.py
@@ -162,6 +162,11 @@
action="store_true",
help="set to true to save cpu validation outputs for later consumption",
)
parser.add_argument(
"--only_save_validation_output",
action="store_true",
help="set to true to ONLY save cpu validation outputs for later consumption",
)
parser.add_argument(
"--prioritize_large_batch_sizes",
action="store_true",
@@ -391,14 +396,16 @@ def __load_validation_info(
and dist.get_world_size() == 4
):
extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
warmup_model(
model,
input_ids,
max_new_tokens=max_new_tokens,
compile_dynamic_sendnn=True,
stagger_update_lazyhandle=args.stagger_update_lazyhandle,
**extra_kwargs,
)

if not args.only_save_validation_output:
warmup_model(
model,
input_ids,
max_new_tokens=max_new_tokens,
compile_dynamic_sendnn=True,
stagger_update_lazyhandle=args.stagger_update_lazyhandle,
**extra_kwargs,
)

if USE_DISTRIBUTED:
# wait for rank0 to be finished as it is the only one generating the criteria json
@@ -659,7 +666,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
**extra_kwargs,
)
# save the cpu validation info for later consumption
if save_validation_info_outputs:
if save_validation_info_outputs or args.only_save_validation_output:
cpu_validation_info.save(
get_validation_info_path(
args.validation_info_outputs_dir,
@@ -674,7 +681,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
)
)

if args.test_type == "metrics":
if args.test_type == "metrics" and not args.only_save_validation_output:
aiu_validation_info = extract_validation_information(
model,
input_ids,
@@ -718,7 +725,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
if failure_rate >= args.failure_rate_threshold:
failed_cases.append((program_id, valid_prompt, failure_rate))

elif args.test_type == "tokens":
elif args.test_type == "tokens" and not args.only_save_validation_output:
aiu_validation_info = extract_validation_information(
model,
input_ids,
@@ -758,9 +765,11 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
dprint(f"AIU tokens:\n{aiu_tokens_generated}")
dprint(f"CPU output:\n{tokenizer.decode(cpu_tokens_generated)}")
dprint(f"AIU output:\n{tokenizer.decode(aiu_tokens_generated)}")
elif args.only_save_validation_output:
pass
else:
raise ValueError("test type must be one of metrics or tokens")
else:
elif not args.only_save_validation_output:
aiu_validation_info = extract_validation_information(
model,
input_ids,
@@ -784,7 +793,11 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
dprint(f"AIU tokens:\n{aiu_tokens_generated}")
dprint(f"AIU output:\n{tokenizer.decode(aiu_tokens_generated)}")

if not args.skip_validation and local_rank == 0:
if (
not args.skip_validation
and local_rank == 0
and not args.only_save_validation_output
):
if len(failed_cases) != 0:
dprint("the test failed with the following cases:")
for failed_case in failed_cases:
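
Note (for review only, not part of the patch): a condensed, self-contained sketch of the control flow this change introduces. The flag --only_save_validation_output and the save/skip logic come from the hunks above; every other name here (including --save_validation_info_outputs and the stub functions) is a hypothetical stand-in for the real script.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--save_validation_info_outputs", action="store_true")  # pre-existing save flag; exact name assumed
parser.add_argument("--only_save_validation_output", action="store_true")   # new flag added by this PR
args = parser.parse_args()

def warmup_model_stub():
    print("warming up model on device")            # stands in for warmup_model(...)

def extract_cpu_validation_stub():
    print("computing CPU validation info")         # stands in for extract_validation_information(...) on CPU
    return object()

def save_cpu_validation_stub(info):
    print("saving CPU validation info to disk")    # stands in for cpu_validation_info.save(...)

def run_device_validation_stub():
    print("running AIU validation and comparing")  # stands in for the metrics/tokens comparison paths

if not args.only_save_validation_output:
    warmup_model_stub()                            # device warmup is skipped in save-only mode

cpu_info = extract_cpu_validation_stub()           # CPU validation info is always computed
if args.save_validation_info_outputs or args.only_save_validation_output:
    save_cpu_validation_stub(cpu_info)             # the new flag forces saving of CPU outputs

if not args.only_save_validation_output:
    run_device_validation_stub()                   # device-side validation and the failure report are skipped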