
Commit 76ede50

dliu-ibm authored and kcirred committed
[dpp] add option to only save cpu results without running aiu
Signed-off-by: kcirred <16872435+kcirred@users.noreply.github.com>
1 parent 281ff22 commit 76ede50

File tree

1 file changed: +26 −13 lines changed


scripts/drive_paged_programs.py

Lines changed: 26 additions & 13 deletions
@@ -162,6 +162,11 @@
     action="store_true",
     help="set to true to save cpu validation outputs for later consumption",
 )
+parser.add_argument(
+    "--only_save_validation_output",
+    action="store_true",
+    help="set to true to ONLY save cpu validation outputs for later consumption",
+)
 parser.add_argument(
     "--prioritize_large_batch_sizes",
     action="store_true",
@@ -391,14 +396,16 @@ def __load_validation_info(
     and dist.get_world_size() == 4
 ):
     extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
-warmup_model(
-    model,
-    input_ids,
-    max_new_tokens=max_new_tokens,
-    compile_dynamic_sendnn=True,
-    stagger_update_lazyhandle=args.stagger_update_lazyhandle,
-    **extra_kwargs,
-)
+
+if not args.only_save_validation_output:
+    warmup_model(
+        model,
+        input_ids,
+        max_new_tokens=max_new_tokens,
+        compile_dynamic_sendnn=True,
+        stagger_update_lazyhandle=args.stagger_update_lazyhandle,
+        **extra_kwargs,
+    )

 if USE_DISTRIBUTED:
     # wait for rank0 to be finished as it is the only one generating the criteria json
@@ -659,7 +666,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
             **extra_kwargs,
         )
         # save the cpu validation info for later consumption
-        if save_validation_info_outputs:
+        if save_validation_info_outputs or args.only_save_validation_output:
             cpu_validation_info.save(
                 get_validation_info_path(
                     args.validation_info_outputs_dir,
@@ -674,7 +681,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
                 )
             )

-        if args.test_type == "metrics":
+        if args.test_type == "metrics" and not args.only_save_validation_output:
             aiu_validation_info = extract_validation_information(
                 model,
                 input_ids,
@@ -718,7 +725,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
             if failure_rate >= args.failure_rate_threshold:
                 failed_cases.append((program_id, valid_prompt, failure_rate))

-        elif args.test_type == "tokens":
+        elif args.test_type == "tokens" and not args.only_save_validation_output:
             aiu_validation_info = extract_validation_information(
                 model,
                 input_ids,
@@ -758,9 +765,11 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
                 dprint(f"AIU tokens:\n{aiu_tokens_generated}")
                 dprint(f"CPU output:\n{tokenizer.decode(cpu_tokens_generated)}")
                 dprint(f"AIU output:\n{tokenizer.decode(aiu_tokens_generated)}")
+        elif args.only_save_validation_output:
+            pass
         else:
             raise ValueError("test type must be one of metrics or tokens")
-    else:
+    elif not args.only_save_validation_output:
         aiu_validation_info = extract_validation_information(
             model,
             input_ids,
@@ -784,7 +793,11 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
         dprint(f"AIU tokens:\n{aiu_tokens_generated}")
         dprint(f"AIU output:\n{tokenizer.decode(aiu_tokens_generated)}")

-if not args.skip_validation and local_rank == 0:
+if (
+    not args.skip_validation
+    and local_rank == 0
+    and not args.only_save_validation_output
+):
     if len(failed_cases) != 0:
         dprint("the test failed with the following cases:")
         for failed_case in failed_cases:
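For orientation, here is a minimal, self-contained sketch of the control flow this commit introduces. The helpers below are hypothetical stand-ins (the real script's warmup_model, extract_validation_information, etc. take model inputs and many more arguments); only the --only_save_validation_output gating mirrors the diff above.

# Simplified sketch of the gating pattern added by this commit.
# The helper functions are stubs for illustration, not the script's real APIs.
import argparse


def warmup_model() -> None:
    print("warming up model on AIU")  # stub for the real AIU warmup


def run_cpu_validation() -> dict:
    print("running cpu validation")  # stub for CPU validation-info extraction
    return {"tokens": [1, 2, 3]}


def save_validation_output(info: dict) -> None:
    print(f"saving cpu validation output: {info}")  # stub for cpu_validation_info.save(...)


def run_aiu_validation() -> None:
    print("running aiu validation")  # stub for the AIU metrics/tokens pass


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--save_validation_info_outputs", action="store_true")
    parser.add_argument("--only_save_validation_output", action="store_true")
    args = parser.parse_args()

    # With --only_save_validation_output, the AIU warmup is skipped entirely.
    if not args.only_save_validation_output:
        warmup_model()

    # CPU validation always runs; its outputs are saved if either flag is set.
    cpu_info = run_cpu_validation()
    if args.save_validation_info_outputs or args.only_save_validation_output:
        save_validation_output(cpu_info)

    # The AIU pass (and the final pass/fail report) only happens when we are
    # not in "only save cpu results" mode.
    if not args.only_save_validation_output:
        run_aiu_validation()


if __name__ == "__main__":
    main()

Run with --only_save_validation_output and the sketch, like the script after this change, skips AIU warmup and AIU validation while still writing the cpu validation outputs.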

0 commit comments
