88from torch .torch_version import TorchVersion
99
1010from vllm import LLM , SamplingParams
11+ from vllm .config import set_current_vllm_config
1112from vllm .config .compilation import DynamicShapesType
1213
1314
@@ -35,9 +36,10 @@ def get_test_models():
3536
3637
3738@pytest .mark .parametrize ("model_name" , get_test_models ())
38- def test_dynamic_shapes_compilation (monkeypatch , model_name ):
39+ @pytest .mark .parametrize ("evaluate_guards" , [False , True ])
40+ def test_dynamic_shapes_compilation (monkeypatch , model_name , evaluate_guards ):
3941 """Test that all dynamic shapes types produce compiles"""
40- print (f"\n Testing model: { model_name } " )
42+ print (f"\n Testing model: { model_name } with evaluate_guards= { evaluate_guards } " )
4143
4244 monkeypatch .setenv ("TOKENIZERS_PARALLELISM" , "true" )
4345 # Note USE_AOT_COMPILE fails https://github.com/vllm-project/vllm/issues/27040.
@@ -76,7 +78,10 @@ def test_dynamic_shapes_compilation(monkeypatch, model_name):
7678 DynamicShapesType .UNBACKED ,
7779 DynamicShapesType .BACKED_SIZE_OBLIVIOUS ,
7880 ]:
79- print (f"Testing { shapes_type .name } dynamic shapes..." )
81+ print (
82+ f"Testing { shapes_type .name } dynamic shapes with "
83+ f"evaluate_guards={ evaluate_guards } ..."
84+ )
8085
8186 # Initialize the model with specific dynamic shapes configuration
8287 model = LLM (
@@ -85,7 +90,7 @@ def test_dynamic_shapes_compilation(monkeypatch, model_name):
8590 "level" : 3 , # PIECEWISE compilation
8691 "dynamic_shapes_config" : {
8792 "dynamic_shapes_type" : shapes_type .value ,
88- "eval_dynamo_ds_guards" : False ,
93+ "eval_dynamo_ds_guards" : evaluate_guards ,
8994 },
9095 },
9196 # gpu_memory_utilization=0.2,
@@ -110,36 +115,136 @@ def test_dynamic_shapes_compilation(monkeypatch, model_name):
110115 print (f"{ shape_type } : '{ result } '" )
111116
112117
113- if __name__ == "__main__" :
114- """Run the test directly as a Python script"""
115- import os
116-
117- print ("Running dynamic shapes compilation test..." )
118-
119- # Get test models based on PyTorch version
120- test_models = get_test_models ()
121- print (f"Testing { len (test_models )} models: { test_models } " )
122-
123- # Create a mock monkeypatch object for environment variables
124- class MockMonkeypatch :
125- def setenv (self , key , value ):
126- os .environ [key ] = value
127-
128- monkeypatch = MockMonkeypatch ()
@pytest.mark.parametrize("use_aot_compile", ["0", "1"])
@pytest.mark.parametrize(
    "dynamic_shapes_type",
    [
        DynamicShapesType.BACKED,
        DynamicShapesType.BACKED_SIZE_OBLIVIOUS,
    ],
)
@pytest.mark.parametrize("evaluate_guards", [False, True])
def test_model_specialization_with_evaluate_guards(
    monkeypatch, use_aot_compile, dynamic_shapes_type, evaluate_guards
):
    """Test that evaluate_guards correctly detects shape specialization violations.

    A tiny compiled module branches on ``x.shape[0] >= 10``. It is first called
    with a 20-row input and then with a 5-row input:

    * ``evaluate_guards=True``: the second call is expected to raise an error
      whose message mentions "guard" or "recompile"; any other error is
      re-raised, and a silent success fails the test.
    * ``evaluate_guards=False``: the second call is expected to succeed and
      return a ``(5, 10)`` output; a ``RuntimeError`` mentioning both
      "recompile" and "fail_on_recompile" is tolerated.

    Args:
        monkeypatch: pytest fixture used to set environment variables.
        use_aot_compile: value exported as ``VLLM_USE_AOT_COMPILE`` ("0"/"1").
        dynamic_shapes_type: ``DynamicShapesType`` member under test.
        evaluate_guards: value forwarded to ``DynamicShapesConfig``.
    """
    from contextlib import contextmanager

    from vllm.compilation.decorators import support_torch_compile
    from vllm.config import CompilationConfig, CompilationMode, VllmConfig
    from vllm.config.compilation import DynamicShapesConfig
    from vllm.forward_context import set_forward_context

    @support_torch_compile
    class ModelWithSizeCheck(torch.nn.Module):
        def __init__(self, **kwargs):
            super().__init__()
            self.linear = torch.nn.Linear(10, 10)

        def forward(self, x: torch.Tensor):
            x = self.linear(x)
            # This will cause specialization - torch.compile will guard on
            # x.shape[0]. Both branches return the same value on purpose:
            # only the shape-dependent condition matters.
            if x.shape[0] >= 10:
                return x
            else:
                return x

    @contextmanager
    def use_vllm_config(vllm_config: VllmConfig):
        # Install the config both as the forward context and as the
        # "current" vLLM config for the duration of the block.
        with (
            set_forward_context({}, vllm_config),
            set_current_vllm_config(vllm_config),
        ):
            yield

    monkeypatch.setenv("TOKENIZERS_PARALLELISM", "true")
    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", use_aot_compile)

    # Reset torch dynamo to clear any cached compilation state
    torch._dynamo.reset()

    config_desc = (
        f"AOT={use_aot_compile}, shapes={dynamic_shapes_type.name}, "
        f"eval_guards={evaluate_guards}"
    )
    print(f"\n{'=' * 60}")
    print(f"Testing: {config_desc}")
    print(f"{'=' * 60}")

    # Create vllm config with the desired settings
    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            dynamic_shapes_config=DynamicShapesConfig(
                dynamic_shapes_type=dynamic_shapes_type,
                evaluate_guards=evaluate_guards,
            ),
        )
    )

    assert (
        vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards
        == evaluate_guards
    )

    try:
        with torch.no_grad(), use_vllm_config(vllm_config):
            model = ModelWithSizeCheck(vllm_config=vllm_config).cuda()

            # First call with 20 rows (takes the `>= 10` branch) - should
            # always work.
            input_20 = torch.randn(20, 10).cuda()
            model(input_20)

            # Second call with 5 rows (the other branch) - behavior depends
            # on evaluate_guards.
            input_5 = torch.randn(5, 10).cuda()

            if evaluate_guards:
                # With evaluate_guards=True the shape guard recorded during
                # the first call is expected to be enforced, so 5 rows must
                # be rejected.
                try:
                    model(input_5)
                except Exception as e:
                    error_msg = str(e)
                    lowered = error_msg.lower()
                    if "guard" not in lowered and "recompile" not in lowered:
                        # Unexpected error type
                        print(f"❌ {config_desc}: Unexpected error type")
                        print(f"   Error: {e}")
                        raise
                    # Expected failure - guard was violated
                    print(f"✅ {config_desc}: Expected failure due to guard violation")
                    print(f"   Error (truncated): {error_msg[:150]}")
                else:
                    # No guard violation occurred: this is a TEST FAILURE.
                    # Raised from `else` so it never passes through the
                    # broad `except` above.
                    pytest.fail(
                        f"{config_desc}: Expected guard violation did "
                        f"not occur! evaluate_guards=True should fail "
                        f"when shape changes from 20 to 5, but the "
                        f"model ran successfully without error."
                    )
            else:
                # With evaluate_guards=False, guards are dropped, so this
                # should work. However, recompilation may still occur, which
                # is expected.
                try:
                    output_5 = model(input_5)
                except RuntimeError as e:
                    lowered = str(e).lower()
                    if "recompile" in lowered and "fail_on_recompile" in lowered:
                        # The model is allowed to recompile with different
                        # shapes when evaluate_guards=False.
                        print(f"✅ {config_desc}: Recompile occurred (expected behavior)")
                        print("   Recompiles are allowed when evaluate_guards=False")
                    else:
                        print(f"❌ {config_desc}: Unexpected failure")
                        print(f"   Error: {e}")
                        raise
                else:
                    assert output_5.shape == (5, 10), "Output shape should match input"
                    print(f"✅ {config_desc}: Passed without guard violations")
                    print("   Second call (size 5): Success")
    finally:
        # Always release GPU memory, including when the test fails, so a
        # failing case does not affect the remaining parametrizations.
        cleanup_gpu_memory()
0 commit comments