@@ -2,12 +2,20 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import gc
+from contextlib import contextmanager
 
 import pytest
 import torch
 
 from vllm import LLM, SamplingParams
-from vllm.config.compilation import CompilationMode, DynamicShapesType
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
+from vllm.config.compilation import (
+    CompilationMode,
+    DynamicShapesConfig,
+    DynamicShapesType,
+)
+from vllm.forward_context import set_forward_context
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -29,18 +37,19 @@ def get_test_models():
 )
 @pytest.mark.parametrize("use_aot_compile", ["0"])
 @pytest.mark.parametrize("use_bytecode_hook", [True, False])
+@pytest.mark.parametrize("evaluate_guards", [False, True])
 @pytest.mark.skipif(
     not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
 )
 def test_dynamic_shapes_compilation(
-    monkeypatch, model_name, shapes_type, use_aot_compile, use_bytecode_hook
+    monkeypatch,
+    model_name,
+    shapes_type,
+    use_aot_compile,
+    use_bytecode_hook,
+    evaluate_guards,
 ):
     """Test that all dynamic shapes types compile successfully"""
-    print(
-        f"\nTesting model: {model_name} with {shapes_type.name}, "
-        f"AOT compile: {use_aot_compile}, "
-        f"Bytecode hook: {use_bytecode_hook}"
-    )
     if use_bytecode_hook and shapes_type == DynamicShapesType.UNBACKED:
         pytest.skip("UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0")
 
@@ -58,6 +67,7 @@ def test_dynamic_shapes_compilation(
             "mode": CompilationMode.VLLM_COMPILE,
             "dynamic_shapes_config": {
                 "type": shapes_type.value,
+                "evaluate_guards": evaluate_guards,
             },
         },
     )
@@ -86,3 +96,100 @@ def test_dynamic_shapes_compilation(
     torch.cuda.empty_cache()
     torch.cuda.synchronize()
     print("GPU memory cleared")
+
+
+@pytest.mark.parametrize("use_aot_compile", ["0", "1"])
+@pytest.mark.parametrize(
+    "dynamic_shapes_type",
+    [
+        DynamicShapesType.BACKED,
+        DynamicShapesType.BACKED_SIZE_OBLIVIOUS,
+        DynamicShapesType.UNBACKED,
+    ],
+)
+@pytest.mark.parametrize("evaluate_guards", [False, True])
+def test_model_specialization_with_evaluate_guards(
+    monkeypatch, use_aot_compile, dynamic_shapes_type, evaluate_guards
+):
+    """Test that evaluate_guards correctly detects shape specialization
+    violations.
+    """
+    if use_aot_compile == "1" and dynamic_shapes_type == DynamicShapesType.UNBACKED:
+        pytest.skip("UNBACKED dynamic shapes require use_aot_compile=0")
+
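+    # Toy module used to trigger a size-dependent guard; both branches return x,
+    # so only the recorded guard (not the output) depends on the input shape.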
+    @support_torch_compile
+    class ModelWithSizeCheck(torch.nn.Module):
+        def __init__(self, **kwargs):
+            super().__init__()
+            self.linear = torch.nn.Linear(10, 10)
+
+        def forward(self, x: torch.Tensor):
+            x = self.linear(x)
+            # This will cause specialization - torch.compile will guard on x.shape[0]
+            if x.shape[0] >= 10:
+                return x
+            else:
+                return x
+
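+    # Modules decorated with @support_torch_compile need the current vLLM config
+    # set and an active forward context while they are built and run.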
+    @contextmanager
+    def use_vllm_config(vllm_config: VllmConfig):
+        with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config):
+            yield
+
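+    # Pin the compile path under test: AOT compile per the parametrization, with
+    # the bytecode hook disabled (UNBACKED dynamic shapes require it off).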
+    monkeypatch.setenv("TOKENIZERS_PARALLELISM", "true")
+    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", use_aot_compile)
+    monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "0")
+
+    # Create vllm config with the desired settings
+    vllm_config = VllmConfig(
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+            dynamic_shapes_config=DynamicShapesConfig(
+                type=dynamic_shapes_type,
+                evaluate_guards=evaluate_guards,
+            ),
+        )
+    )
+
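+    # Helper: compile/run the model with input1, then feed input2. With
+    # evaluate_guards enabled, the second call should trip the recorded shape
+    # guard unless the only difference is a 0/1 specialization.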
+    def test(model_class, input1, input2, is_01_specialization=False):
+        with torch.no_grad(), use_vllm_config(vllm_config):
+            model = model_class(vllm_config=vllm_config).cuda()
+
+            model(input1)
+
+            if evaluate_guards and not is_01_specialization:
+                # The second call should fail because a shape guard was added.
+                try:
+                    model(input2)
+                except RuntimeError as e:
+                    # Expected failure: the recorded guard was violated.
+                    error_msg = str(e).lower()
+                    if "guard" not in error_msg and "recompile" not in error_msg:
+                        raise
+                else:
+                    raise AssertionError("expected guard violation to occur")
+            else:
+                model(input2)
+
+    test(ModelWithSizeCheck, torch.randn(20, 10).cuda(), torch.randn(5, 10).cuda())
+    test(ModelWithSizeCheck, torch.randn(5, 10).cuda(), torch.randn(20, 10).cuda())
+
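+    # Variant whose branch only separates sizes 0/1 from >= 2, i.e. the classic
+    # 0/1-specialization case.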
+    @support_torch_compile
+    class ModelWithOneSizeCheck(torch.nn.Module):
+        def __init__(self, **kwargs):
+            super().__init__()
+            self.linear = torch.nn.Linear(10, 10)
+
+        def forward(self, x: torch.Tensor):
+            x = self.linear(x)
+            # This will cause 0/1 specializations.
+            if x.shape[0] >= 2:
+                return x
+            else:
+                return x
+
+    test(ModelWithOneSizeCheck, torch.randn(20, 10).cuda(), torch.randn(1, 10).cuda())