1919from vllm .compilation .counter import compilation_counter
2020from vllm .compilation .wrapper import TorchCompileGuardsStripWrapper
2121from vllm .config import CompilationMode , VllmConfig , set_current_vllm_config
22+ from vllm .config .compilation import DynamicShapesType
2223from vllm .logger import init_logger
2324from vllm .sequence import IntermediateTensors
2425from vllm .utils import resolve_obj_by_qualname , supports_dynamo
@@ -82,6 +83,7 @@ def support_torch_compile(
8283 * ,
8384 dynamic_arg_dims : dict [str , int | list [int ]] | None = None ,
8485 enable_if : Callable [[VllmConfig ], bool ] | None = None ,
86+ shape_invariants : Callable [..., None ] = lambda * args , ** kwargs : None ,
8587) -> Callable [[_T ], _T ] | _T :
8688 """
8789 A decorator to add support for compiling the forward method of a class.
@@ -171,7 +173,9 @@ def cls_decorator_helper(cls: _T) -> _T:
171173 raise ValueError (
172174 f"Argument { k } not found in the forward method of { cls } "
173175 )
174- return _support_torch_compile (cls , inferred_dynamic_arg_dims , enable_if )
176+ return _support_torch_compile (
177+ cls , inferred_dynamic_arg_dims , enable_if , shape_invariants
178+ )
175179
176180 if cls is not None :
177181 # use `support_torch_compile` as a decorator without arguments
@@ -212,6 +216,7 @@ def _support_torch_compile(
212216 cls : _T ,
213217 dynamic_arg_dims : dict [str , int | list [int ]],
214218 enable_if : Callable [[VllmConfig ], bool ] | None = None ,
219+ shape_invariants : Callable [..., None ] = lambda * args , ** kwargs : None ,
215220) -> _T :
216221 """
217222 A decorator to add support for compiling the forward method of a class.
@@ -232,11 +237,12 @@ def _support_torch_compile(
232237 def __init__ (self , * , vllm_config : VllmConfig , prefix : str = "" , ** kwargs ):
233238 old_init (self , vllm_config = vllm_config , prefix = prefix , ** kwargs )
234239 self .vllm_config = vllm_config
240+ self .compilation_config = self .vllm_config .compilation_config
235241 enable_compile = enable_if is None or enable_if (vllm_config )
236242 # for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner
237243 # will handle the compilation, so we don't need to do anything here.
238244 self .do_not_compile = (
239- vllm_config .compilation_config .mode
245+ self .compilation_config .mode
240246 in [CompilationMode .NONE , CompilationMode .STOCK_TORCH_COMPILE ]
241247 or not supports_dynamo ()
242248 or _should_ignore_torch_compile (self .__class__ )
@@ -245,29 +251,38 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
245251 if self .do_not_compile :
246252 return
247253
254+ self ._check_shape_invariants = shape_invariants
255+
248256 compilation_counter .num_models_seen += 1
249257 self .compiled = False
250258 TorchCompileGuardsStripWrapper .__init__ (self )
251259
252260 cls .__init__ = __init__
253261
254- def _mark_dynamic_inputs (mod , * args , ** kwargs ):
262+ def _mark_dynamic_inputs (mod , dynamic_shapes_type , * args , ** kwargs ):
263+ def mark_dynamic (arg , dims ):
264+ if dynamic_shapes_type == DynamicShapesType .UNBACKED :
265+ torch ._dynamo .decorators .mark_unbacked (arg , dims )
266+ else :
267+ torch ._dynamo .mark_dynamic (arg , dims )
268+
255269 sig = inspect .signature (mod .__class__ .forward )
256270 bound_args = sig .bind (mod , * args , ** kwargs )
257271 bound_args .apply_defaults ()
258272 for k , dims in dynamic_arg_dims .items ():
259273 arg = bound_args .arguments .get (k )
274+
260275 if arg is not None :
261276 dims = [dims ] if isinstance (dims , int ) else dims
262277 if isinstance (arg , torch .Tensor ):
263278 # In case dims is specified with negative indexing
264279 dims = [arg .ndim + dim if dim < 0 else dim for dim in dims ]
265- torch . _dynamo . mark_dynamic (arg , dims )
280+ mark_dynamic (arg , dims )
266281 elif isinstance (arg , IntermediateTensors ):
267282 for tensor in arg .tensors .values ():
268283 # In case dims is specified with negative indexing
269284 dims = [tensor .ndim + dim if dim < 0 else dim for dim in dims ]
270- torch . _dynamo . mark_dynamic (tensor , dims )
285+ mark_dynamic (tensor , dims )
271286 else :
272287 raise ValueError (
273288 "Unsupported dynamic dimensions"
@@ -285,6 +300,7 @@ def __call__(self, *args, **kwargs):
285300 if getattr (self , "aot_compiled_fn" , None ) is not None :
286301 return self .aot_compiled_fn (self , * args , ** kwargs )
287302
303+ ds_type = self .compilation_config .dynamic_shapes_config .dynamic_shapes_type
288304 cache_dir = None
289305 aot_compilation_path = None
290306 if envs .VLLM_USE_AOT_COMPILE :
@@ -299,6 +315,14 @@ def __call__(self, *args, **kwargs):
299315 serialized backend artifacts), then we need to generate a new AOT
300316 compile artifact from scratch.
301317 """
318+ # Validate that AOT compile is not used with unbacked dynamic
319+ # shapes. aot_compile re-allocates backed symbols post dynamo!
320+ if ds_type == DynamicShapesType .UNBACKED :
321+ raise ValueError (
322+ "AOT compilation is not compatible with UNBACKED dynamic shapes. "
323+ "Please use BACKED or BACKED_SIZE_OBLIVIOUS dynamic shapes type "
324+ "when VLLM_USE_AOT_COMPILE is enabled."
325+ )
302326 from .caching import compilation_config_hash_factors
303327
304328 factors : list [str ] = compilation_config_hash_factors (self .vllm_config )
@@ -347,7 +371,12 @@ def __call__(self, *args, **kwargs):
347371 # This is the path for the first compilation.
348372
349373 # the first compilation needs to have dynamic shapes marked
350- _mark_dynamic_inputs (self , * args , ** kwargs )
374+ _mark_dynamic_inputs (
375+ self ,
376+ ds_type ,
377+ * args ,
378+ ** kwargs ,
379+ )
351380
352381 # here, it is the starting point of the `torch.compile` process
353382 start_monitoring_torch_compile (self .vllm_config )
@@ -364,9 +393,7 @@ def __call__(self, *args, **kwargs):
364393 # properly when any of these files change.
365394
366395 # 1. the file containing the top-level forward function
367- self .vllm_config .compilation_config .traced_files .add (
368- original_code_object .co_filename
369- )
396+ self .compilation_config .traced_files .add (original_code_object .co_filename )
370397
371398 # 2. every time Dynamo sees a function call, it will inline
372399 # the function by calling InliningInstructionTranslator.inline_call_
@@ -376,7 +403,7 @@ def __call__(self, *args, **kwargs):
376403
377404 def patched_inline_call (self_ ):
378405 code = self_ .f_code
379- self .vllm_config . compilation_config .traced_files .add (code .co_filename )
406+ self .compilation_config .traced_files .add (code .co_filename )
380407 return inline_call (self_ )
381408
382409 # Disable the C++ compilation of symbolic shape guards. C++-fication
@@ -392,12 +419,18 @@ def patched_inline_call(self_):
392419 # if the config doesn't exist
393420 logger .debug ("enable_cpp_symbolic_shape_guards config not available" )
394421
422+ # Prepare backed_size_oblivious config patch if needed
423+ fx_config_patches = {}
424+ if ds_type == DynamicShapesType .BACKED_SIZE_OBLIVIOUS :
425+ fx_config_patches ["backed_size_oblivious" ] = True
426+
395427 with (
396428 patch .object (
397429 InliningInstructionTranslator , "inline_call_" , patched_inline_call
398430 ),
399431 torch ._dynamo .config .patch (** dynamo_config_patches ),
400432 maybe_use_cudagraph_partition_wrapper (self .vllm_config ),
433+ torch .fx .experimental ._config .patch (** fx_config_patches ),
401434 _torch27_patch_tensor_subclasses (),
402435 ):
403436 if envs .VLLM_USE_AOT_COMPILE :
0 commit comments