3636
3737
3838def execute_script (execute_cmd ):
39- # using these options temporarily
40- current_env ["VLLM_DT_MAX_BATCH_TKV_LIMIT" ] = "16384"
41- current_env ["VLLM_DT_MAX_BATCH_SIZE" ] = "4"
42- current_env ["VLLM_DT_MAX_CONTEXT_LEN" ] = "4096"
43-
4439 with Popen (
4540 execute_cmd ,
4641 stdin = PIPE ,
@@ -56,11 +51,25 @@ def execute_script(execute_cmd):
5651 raise Exception (error )
5752
5853
59- def execute_inference (model_path , batch_size , seq_length , max_new_tokens , attn_type ):
54+ def execute_inference (
55+ model_path , batch_size , seq_length , max_new_tokens , attn_type , allow_symbolic_shapes
56+ ):
6057 extra_args = []
6158 if attn_type == "paged" :
62- extra_args . append ( "--compile_dynamic_sendnn" )
59+ # paged needs symbolic shapes
6360 extra_args .append ("--attention_type=paged" )
61+ # using these options temporarily
62+ current_env .setdefault ("VLLM_DT_MAX_BATCH_TKV_LIMIT" , "16384" )
63+ current_env .setdefault ("VLLM_DT_MAX_BATCH_SIZE" , "4" )
64+ current_env .setdefault ("VLLM_DT_MAX_CONTEXT_LEN" , "4096" )
65+ else :
66+ # added in case symbolic shapes used with sdpa
67+ current_env .setdefault ("_PROMPT_LEN" , "64" )
68+ current_env .setdefault ("_MAX_DECODE_TOKENS" , "8" )
69+ current_env .setdefault ("_MAX_CONTEXT_LEN" , "71" )
70+
71+ if allow_symbolic_shapes is not None and allow_symbolic_shapes :
72+ extra_args .append ("--compile_dynamic_sendnn" )
6473
6574 execute_cmd = [
6675 "python3" ,
@@ -97,20 +106,38 @@ def __repeat_batch_asserts(bs: int) -> list[str]:
97106# extend every entry of common_params with the asserts expected for its batch size (common_param[1])
98107# for batch sizes that exceed common_asserts, the common_asserts are repeated (see __repeat_batch_asserts), since this follows inference behavior
99108common_inference_params = [
100- common_param + (__repeat_batch_asserts (common_param [1 ]),)
109+ common_param + (__repeat_batch_asserts (common_param [1 ]), None )
101110 for common_param in common_params
102111]
112+ # special case: one extra sdpa entry with batch size 1 whose last field (allow_symbolic_shapes) is True, to allow symbolic shapes; the generated entries above leave that field as None
113+ common_inference_params .append (
114+ (common_model_paths [0 ], 1 , 64 , 8 , "sdpa" , [common_asserts [0 ]], True )
115+ )
103116
104117
105118@pytest .mark .parametrize (
106- "model_path,batch_size,seq_length,max_new_tokens,attn_type,asserts" ,
119+ "model_path,batch_size,seq_length,max_new_tokens,attn_type,asserts,allow_symbolic_shapes " ,
107120 common_inference_params ,
108121)
109122def test_inference_script (
110- model_path , batch_size , seq_length , max_new_tokens , attn_type , asserts
123+ model_path ,
124+ batch_size ,
125+ seq_length ,
126+ max_new_tokens ,
127+ attn_type ,
128+ asserts ,
129+ allow_symbolic_shapes ,
111130):
131+ # force symbolic shapes if paged
132+ if "paged" in attn_type :
133+ allow_symbolic_shapes = True
112134 result_text = execute_inference (
113- model_path , batch_size , seq_length , max_new_tokens , attn_type
135+ model_path ,
136+ batch_size ,
137+ seq_length ,
138+ max_new_tokens ,
139+ attn_type ,
140+ allow_symbolic_shapes ,
114141 )
115142
116143 for common_assert in asserts :
0 commit comments