
Commit ab4f7a2

Merge pull request #135 from foundation-model-stack/fix_inference_script_static_graphs_show_more_symbols_than_expected
fixed inference.py for batch size 1 symbolic sdpa
2 parents ea529c5 + 9f6bd71

File tree

2 files changed: +39 -12 lines changed

scripts/inference.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -588,7 +588,7 @@ def select_int8_module(
 
 if args.compile:
     dprint("compiling model")
-    fx_config.backed_size_oblivious = True
+    fx_config.backed_size_oblivious = "paged" in attn_name
     if is_aiu_backend:
         model.compile(
             backend="sendnn", options={"sendnn.dynamic": args.compile_dynamic_sendnn}
```

tests/models/test_scripts.py

Lines changed: 38 additions & 11 deletions
```diff
@@ -36,11 +36,6 @@
 
 
 def execute_script(execute_cmd):
-    # using these options temporarily
-    current_env["VLLM_DT_MAX_BATCH_TKV_LIMIT"] = "16384"
-    current_env["VLLM_DT_MAX_BATCH_SIZE"] = "4"
-    current_env["VLLM_DT_MAX_CONTEXT_LEN"] = "4096"
-
     with Popen(
         execute_cmd,
         stdin=PIPE,
@@ -56,11 +51,25 @@ def execute_script(execute_cmd):
         raise Exception(error)
 
 
-def execute_inference(model_path, batch_size, seq_length, max_new_tokens, attn_type):
+def execute_inference(
+    model_path, batch_size, seq_length, max_new_tokens, attn_type, allow_symbolic_shapes
+):
     extra_args = []
     if attn_type == "paged":
-        extra_args.append("--compile_dynamic_sendnn")
+        # paged needs symbolic shapes
         extra_args.append("--attention_type=paged")
+        # using these options temporarily
+        current_env.setdefault("VLLM_DT_MAX_BATCH_TKV_LIMIT", "16384")
+        current_env.setdefault("VLLM_DT_MAX_BATCH_SIZE", "4")
+        current_env.setdefault("VLLM_DT_MAX_CONTEXT_LEN", "4096")
+    else:
+        # added in case symbolic shapes used with sdpa
+        current_env.setdefault("_PROMPT_LEN", "64")
+        current_env.setdefault("_MAX_DECODE_TOKENS", "8")
+        current_env.setdefault("_MAX_CONTEXT_LEN", "71")
+
+    if allow_symbolic_shapes is not None and allow_symbolic_shapes:
+        extra_args.append("--compile_dynamic_sendnn")
 
     execute_cmd = [
         "python3",
```
```diff
@@ -97,20 +106,38 @@ def __repeat_batch_asserts(bs: int) -> list[str]:
 # add the asserts based on batch size
 # for batches greater than common_asserts, repeat common_asserts since this follows inference behavior
 common_inference_params = [
-    common_param + (__repeat_batch_asserts(common_param[1]),)
+    common_param + (__repeat_batch_asserts(common_param[1]), None)
     for common_param in common_params
 ]
+# adding special case where we allow symbolic shapes for batch size 1 using sdpa
+common_inference_params.append(
+    (common_model_paths[0], 1, 64, 8, "sdpa", [common_asserts[0]], True)
+)
 
 
 @pytest.mark.parametrize(
-    "model_path,batch_size,seq_length,max_new_tokens,attn_type,asserts",
+    "model_path,batch_size,seq_length,max_new_tokens,attn_type,asserts,allow_symbolic_shapes",
     common_inference_params,
 )
 def test_inference_script(
-    model_path, batch_size, seq_length, max_new_tokens, attn_type, asserts
+    model_path,
+    batch_size,
+    seq_length,
+    max_new_tokens,
+    attn_type,
+    asserts,
+    allow_symbolic_shapes,
 ):
+    # force symbolic shapes if paged
+    if "paged" in attn_type:
+        allow_symbolic_shapes = True
     result_text = execute_inference(
-        model_path, batch_size, seq_length, max_new_tokens, attn_type
+        model_path,
+        batch_size,
+        seq_length,
+        max_new_tokens,
+        attn_type,
+        allow_symbolic_shapes,
     )
 
     for common_assert in asserts:
```
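Putting the two pieces together: the test forces `allow_symbolic_shapes` to `True` whenever the attention type is paged, and `execute_inference` turns that into the `--compile_dynamic_sendnn` flag, while the `None` appended to the common parameter tuples leaves sdpa static by default; only the new special-case tuple opts a batch-size-1 sdpa run into symbolic shapes. A condensed sketch of the combined decision (the helper name is hypothetical; the flag strings come from the diff):

```python
# Hypothetical helper condensing the flag logic that the diff splits
# between test_inference_script and execute_inference.
def build_extra_args(attn_type: str, allow_symbolic_shapes) -> list[str]:
    extra_args = []
    if attn_type == "paged":
        extra_args.append("--attention_type=paged")
        allow_symbolic_shapes = True  # paged always needs symbolic shapes
    if allow_symbolic_shapes is not None and allow_symbolic_shapes:
        extra_args.append("--compile_dynamic_sendnn")
    return extra_args


assert build_extra_args("sdpa", None) == []  # default: static sdpa graph
assert build_extra_args("sdpa", True) == ["--compile_dynamic_sendnn"]  # new special case
assert build_extra_args("paged", None) == [
    "--attention_type=paged",
    "--compile_dynamic_sendnn",
]
```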
