Skip to content

Commit 2e9d6d8

Browse files
committed
addressed pr comments; added constant for kv-cache hint
Signed-off-by: Joshua Rosenkranz <jmrosenk@us.ibm.com>
1 parent dc95321 commit 2e9d6d8

File tree

4 files changed

+22
-10
lines changed

4 files changed

+22
-10
lines changed

aiu_fms_testing_utils/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ def __sample_requests(
468468
return filtered_dataset
469469

470470

471-
def sample_granite_3_3_long_answerable_requests(
471+
def sample_rag_factoid_requests(
472472
dataset_path: str,
473473
num_requests: int,
474474
tokenizer: PreTrainedTokenizerBase,

aiu_fms_testing_utils/utils/paged.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ def generate(
416416
if post_iteration_hook is not None:
417417
_logits = logits
418418
_next_val = next_val
419-
# since we cannot handle batch size 1 and mimic with batch size 2, we need to only pass in the first logits/next_val
419+
# since we cannot handle batch size 1 for fp8 and mimic with batch size 2, we need to only pass in the first logits/next_val
420420
if is_fp8 and not is_batch:
421421
_logits = logits[0].unsqueeze(0)
422422
_next_val = _next_val[0].unsqueeze(0)
@@ -464,6 +464,11 @@ def generate(
464464
return result
465465

466466

467+
# this value defaults to 2080 to be consistent with vLLM for granite 3.3 8b instruct
468+
KVCACHE_NUM_BLOCKS_HINT = int(
469+
os.environ.get("AFTU_PAGED_KVCACHE_NUM_BLOCKS_HINT", 2080)
470+
)
471+
467472
VLLM_DT_MAX_BATCH_TKV_LIMIT = int(os.environ.get("VLLM_DT_MAX_BATCH_TKV_LIMIT", 131072))
468473

469474

scripts/drive_paged_programs.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,17 @@
2424
top_k_loss_calculator,
2525
)
2626
from aiu_fms_testing_utils.utils import (
27-
sample_granite_3_3_long_answerable_requests,
27+
sample_rag_factoid_requests,
2828
sample_sharegpt_requests,
2929
stagger_region,
3030
warmup_model,
3131
)
3232
from aiu_fms_testing_utils.utils.aiu_setup import aiu_dist_setup, dprint, local_rank
33-
from aiu_fms_testing_utils.utils.paged import ProgramCriteria, get_programs_prompts
33+
from aiu_fms_testing_utils.utils.paged import (
34+
ProgramCriteria,
35+
get_programs_prompts,
36+
KVCACHE_NUM_BLOCKS_HINT,
37+
)
3438

3539
parser = argparse.ArgumentParser(
3640
description="Script which will drive paged programs for debugging"
@@ -167,7 +171,7 @@
167171
save_validation_info_outputs = args.save_validation_info_outputs
168172

169173
if args.dataset_type == "rag_factoid":
170-
sampler = sample_granite_3_3_long_answerable_requests
174+
sampler = sample_rag_factoid_requests
171175
allow_truncation = False
172176
elif args.dataset_type == "sharegpt":
173177
sampler = sample_sharegpt_requests
@@ -335,7 +339,7 @@ def __load_validation_info(
335339
and USE_DISTRIBUTED
336340
and dist.get_world_size() == 4
337341
):
338-
extra_kwargs["_kvcache_num_blocks_hint"] = 2080
342+
extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
339343
warmup_model(
340344
model,
341345
input_ids,
@@ -347,6 +351,7 @@ def __load_validation_info(
347351

348352
if USE_DISTRIBUTED:
349353
# wait for rank0 to be finished as it is the only one generating the criteria json
354+
# this is needed since otherwise we may run into a race condition
350355
torch.distributed.barrier()
351356

352357
with open(args.program_criteria_json_path, "r") as f:
@@ -434,7 +439,8 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
434439

435440

436441
failed_cases = []
437-
for program_id, valid_prompt in valid_prompts: # for each program
442+
# for each program and valid prompt (batch size, sequence length)
443+
for program_id, valid_prompt in valid_prompts:
438444
input_ids, extra_kwargs = __prepare_inputs(
439445
valid_prompt[0], valid_prompt[1], tokenizer, enforce_sizes=[valid_prompt[1]]
440446
)
@@ -444,7 +450,7 @@ def __metric_calculator(r: torch.Tensor, t: torch.Tensor):
444450
and USE_DISTRIBUTED
445451
and dist.get_world_size() == 4
446452
):
447-
extra_kwargs["_kvcache_num_blocks_hint"] = 2080
453+
extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
448454

449455
if local_rank == 0:
450456
dprint(f"*** testing program {program_id} ***")

tests/models/test_decoders.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
warmup_model,
2525
sample_sharegpt_requests,
2626
)
27+
from aiu_fms_testing_utils.utils.paged import KVCACHE_NUM_BLOCKS_HINT
2728
import json
2829
from transformers import AutoTokenizer
2930

@@ -538,7 +539,7 @@ def test_common_shapes(
538539
and USE_DISTRIBUTED
539540
and dist.get_world_size() == 4
540541
):
541-
extra_kwargs["_kvcache_num_blocks_hint"] = 2080
542+
extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
542543

543544
# warmup aiu model
544545
warmup_model(
@@ -637,7 +638,7 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
637638
and USE_DISTRIBUTED
638639
and dist.get_world_size() == 4
639640
):
640-
extra_kwargs["_kvcache_num_blocks_hint"] = 2080
641+
extra_kwargs["_kvcache_num_blocks_hint"] = KVCACHE_NUM_BLOCKS_HINT
641642

642643
cpu_validation_info = __load_validation_info(
643644
model_path,

0 commit comments

Comments (0)