
Commit 55ca17e

fixed bug in test_decoders passing an extra kwarg for sdpa
Signed-off-by: Joshua Rosenkranz <jmrosenk@us.ibm.com>
1 parent: 1d533ac

File tree

1 file changed: +3 additions, -2 deletions


tests/models/test_decoders.py

Lines changed: 3 additions & 2 deletions
@@ -185,6 +185,7 @@
     ]
 )
 os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(max(max(common_batch_sizes), 2))
+fx_config.backed_size_oblivious = True
 
 # thresholds are chosen based on 1024 tokens per sequence
 # 1% error threshold rate between cpu fp32 and cuda fp16
@@ -402,7 +403,6 @@ def get_or_create(self, is_gptq, is_fp8, **kwargs):
         self.__maybe_prepare_fp8_weights(model, is_fp8)
 
         model.eval()
-        fx_config.backed_size_oblivious = compile_dynamic_sendnn
         model.compile(
             backend="sendnn", options={"sendnn.dynamic": compile_dynamic_sendnn}
         )
@@ -632,7 +632,8 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
     )
     extra_kwargs["attn_name"] = ATTN_NAME
     if (
-        "ibm-granite/granite-3.3-8b-instruct" in model_path
+        "paged" in ATTN_NAME
+        and "ibm-granite/granite-3.3-8b-instruct" in model_path
         and USE_DISTRIBUTED
         and dist.get_world_size() == 4
     ):
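For context on the last hunk: the "extra kwarg for sdpa" named in the commit message presumably refers to whatever model-specific entry is added to extra_kwargs inside this if-block, which is only meaningful for a paged attention backend; the added "paged" in ATTN_NAME check keeps it from being set when the test runs with sdpa. Below is a minimal sketch of the resulting gating pattern, not the actual test code: the helper name build_extra_kwargs and the kwarg granite_tp4_override are hypothetical, while the attention-name check, the model path, and the world-size condition mirror the diff.

import torch.distributed as dist


def build_extra_kwargs(model_path: str, attn_name: str, use_distributed: bool) -> dict:
    # Mirrors the gating in the last hunk; "granite_tp4_override" is a placeholder
    # for whatever kwarg the real test sets inside the if-block (not shown in the diff).
    extra_kwargs = {"attn_name": attn_name}
    if (
        "paged" in attn_name
        and "ibm-granite/granite-3.3-8b-instruct" in model_path
        and use_distributed
        and dist.get_world_size() == 4
    ):
        extra_kwargs["granite_tp4_override"] = True  # hypothetical kwarg
    return extra_kwargs


# With sdpa the override is no longer injected:
print(build_extra_kwargs("ibm-granite/granite-3.3-8b-instruct", "sdpa", False))
# -> {'attn_name': 'sdpa'}

The first two hunks make a related cleanup: fx_config.backed_size_oblivious is now set once to True during module-level setup instead of being assigned from compile_dynamic_sendnn just before model.compile.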
