@@ -19,6 +19,7 @@
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.wrapper import TorchCompileGuardsStripWrapper
 from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config
+from vllm.config.compilation import DynamicShapesType
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
 from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -83,6 +84,7 @@ def support_torch_compile(
     *,
     dynamic_arg_dims: dict[str, int | list[int]] | None = None,
     enable_if: Callable[[VllmConfig], bool] | None = None,
+    shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
 ) -> Callable[[_T], _T] | _T:
     """
     A decorator to add support for compiling the forward method of a class.
@@ -172,7 +174,9 @@ def cls_decorator_helper(cls: _T) -> _T:
                 raise ValueError(
                     f"Argument {k} not found in the forward method of {cls}"
                 )
-        return _support_torch_compile(cls, inferred_dynamic_arg_dims, enable_if)
+        return _support_torch_compile(
+            cls, inferred_dynamic_arg_dims, enable_if, shape_invariants
+        )
 
     if cls is not None:
         # use `support_torch_compile` as a decorator without arguments
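Note: for orientation, a hedged sketch of how the new `shape_invariants` hook might be passed at a call site. The model class and the invariant body below are illustrative assumptions; only `support_torch_compile`, `dynamic_arg_dims`, and `shape_invariants` come from this diff, which stores the callable as `self._check_shape_invariants` without showing where it is invoked.

```python
# Hypothetical usage sketch: MyModel and the invariant are illustrative,
# not part of this diff. The callable receives the same arguments as
# forward() and is expected to assert relationships between shapes.
import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile


def _token_dim_invariants(self, input_ids: torch.Tensor, positions: torch.Tensor):
    # Assumed invariant: both dynamic inputs share the token dimension.
    assert input_ids.shape[0] == positions.shape[0]


@support_torch_compile(
    dynamic_arg_dims={"input_ids": 0, "positions": 0},
    shape_invariants=_token_dim_invariants,
)
class MyModel(nn.Module):
    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor):
        return input_ids + positions
```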
@@ -213,6 +217,7 @@ def _support_torch_compile(
     cls: _T,
     dynamic_arg_dims: dict[str, int | list[int]],
     enable_if: Callable[[VllmConfig], bool] | None = None,
+    shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
 ) -> _T:
     """
     A decorator to add support for compiling the forward method of a class.
@@ -233,11 +238,12 @@
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
         old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
         self.vllm_config = vllm_config
+        self.compilation_config = self.vllm_config.compilation_config
         enable_compile = enable_if is None or enable_if(vllm_config)
         # for CompilationMode.STOCK_TORCH_COMPILE, the upper level model runner
         # will handle the compilation, so we don't need to do anything here.
         self.do_not_compile = (
-            vllm_config.compilation_config.mode
+            self.compilation_config.mode
             in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE]
             or not supports_dynamo()
             or _should_ignore_torch_compile(self.__class__)
@@ -246,29 +252,38 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
         if self.do_not_compile:
             return
 
+        self._check_shape_invariants = shape_invariants
+
         compilation_counter.num_models_seen += 1
         self.compiled = False
         TorchCompileGuardsStripWrapper.__init__(self)
 
     cls.__init__ = __init__
 
-    def _mark_dynamic_inputs(mod, *args, **kwargs):
+    def _mark_dynamic_inputs(mod, dynamic_shapes_type, *args, **kwargs):
+        def mark_dynamic(arg, dims):
+            if dynamic_shapes_type == DynamicShapesType.UNBACKED:
+                torch._dynamo.decorators.mark_unbacked(arg, dims)
+            else:
+                torch._dynamo.mark_dynamic(arg, dims)
+
         sig = inspect.signature(mod.__class__.forward)
         bound_args = sig.bind(mod, *args, **kwargs)
         bound_args.apply_defaults()
         for k, dims in dynamic_arg_dims.items():
             arg = bound_args.arguments.get(k)
+
             if arg is not None:
                 dims = [dims] if isinstance(dims, int) else dims
                 if isinstance(arg, torch.Tensor):
                     # In case dims is specified with negative indexing
                     dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
-                    torch._dynamo.mark_dynamic(arg, dims)
+                    mark_dynamic(arg, dims)
                 elif isinstance(arg, IntermediateTensors):
                     for tensor in arg.tensors.values():
                         # In case dims is specified with negative indexing
                         dims = [tensor.ndim + dim if dim < 0 else dim for dim in dims]
-                        torch._dynamo.mark_dynamic(tensor, dims)
+                        mark_dynamic(tensor, dims)
                 else:
                     raise ValueError(
                         "Unsupported dynamic dimensions"
@@ -286,6 +301,7 @@ def __call__(self, *args, **kwargs):
         if getattr(self, "aot_compiled_fn", None) is not None:
             return self.aot_compiled_fn(self, *args, **kwargs)
 
+        ds_type = self.compilation_config.dynamic_shapes_config.dynamic_shapes_type
         cache_dir = None
         aot_compilation_path = None
         if envs.VLLM_USE_AOT_COMPILE:
@@ -300,6 +316,14 @@
             serialized backend artifacts), then we need to generate a new AOT
             compile artifact from scratch.
             """
+            # Validate that AOT compile is not used with unbacked dynamic
+            # shapes: aot_compile re-allocates backed symbols after Dynamo.
+            if ds_type == DynamicShapesType.UNBACKED:
+                raise ValueError(
+                    "AOT compilation is not compatible with UNBACKED dynamic shapes. "
+                    "Please use BACKED or BACKED_SIZE_OBLIVIOUS dynamic shapes type "
+                    "when VLLM_USE_AOT_COMPILE is enabled."
+                )
             from .caching import compilation_config_hash_factors
 
             factors: list[str] = compilation_config_hash_factors(self.vllm_config)
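Note: a hedged sketch of the combination this check rejects. The config field names follow this diff, but constructing and mutating the config this way is an assumption, not something the diff shows.

```python
import os

from vllm.config import VllmConfig
from vllm.config.compilation import DynamicShapesType

os.environ["VLLM_USE_AOT_COMPILE"] = "1"

vllm_config = VllmConfig()
vllm_config.compilation_config.dynamic_shapes_config.dynamic_shapes_type = (
    DynamicShapesType.UNBACKED
)
# The first forward() through the compiled wrapper now raises ValueError
# instead of emitting an AOT artifact whose backed symbols would be
# re-allocated after Dynamo.
```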
@@ -348,7 +372,12 @@
         # This is the path for the first compilation.
 
         # the first compilation needs to have dynamic shapes marked
-        _mark_dynamic_inputs(self, *args, **kwargs)
+        _mark_dynamic_inputs(
+            self,
+            ds_type,
+            *args,
+            **kwargs,
+        )
 
         # here, it is the starting point of the `torch.compile` process
         start_monitoring_torch_compile(self.vllm_config)
@@ -365,9 +394,7 @@
         # properly when any of these files change.
 
         # 1. the file containing the top-level forward function
-        self.vllm_config.compilation_config.traced_files.add(
-            original_code_object.co_filename
-        )
+        self.compilation_config.traced_files.add(original_code_object.co_filename)
 
         # 2. every time Dynamo sees a function call, it will inline
         # the function by calling InliningInstructionTranslator.inline_call_
@@ -377,7 +404,7 @@
 
         def patched_inline_call(self_):
             code = self_.f_code
-            self.vllm_config.compilation_config.traced_files.add(code.co_filename)
+            self.compilation_config.traced_files.add(code.co_filename)
             return inline_call(self_)
 
         # Disable the C++ compilation of symbolic shape guards. C++-fication
@@ -393,12 +420,18 @@ def patched_inline_call(self_):
         # if the config doesn't exist
         logger.debug("enable_cpp_symbolic_shape_guards config not available")
 
+        # Prepare backed_size_oblivious config patch if needed
+        fx_config_patches = {}
+        if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
+            fx_config_patches["backed_size_oblivious"] = True
+
         with (
             patch.object(
                 InliningInstructionTranslator, "inline_call_", patched_inline_call
             ),
             torch._dynamo.config.patch(**dynamo_config_patches),
             maybe_use_cudagraph_partition_wrapper(self.vllm_config),
+            torch.fx.experimental._config.patch(**fx_config_patches),
             _torch27_patch_tensor_subclasses(),
         ):
             if envs.VLLM_USE_AOT_COMPILE:
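Note: the new `torch.fx.experimental._config.patch(...)` entry stacks with the existing context managers in the `with` block. Below is a minimal standalone sketch of the BACKED_SIZE_OBLIVIOUS path, independent of vLLM; the gloss on the flag's effect is an interpretation, not something stated in the diff, and assumes a PyTorch recent enough to ship `backed_size_oblivious`.

```python
import torch
import torch.fx.experimental._config as fx_config

# backed_size_oblivious asks the symbolic-shape machinery to reason about
# backed symbols size-obliviously (without specializing on sizes 0 and 1),
# which is how this diff realizes DynamicShapesType.BACKED_SIZE_OBLIVIOUS.
fx_config_patches = {"backed_size_oblivious": True}

with fx_config.patch(**fx_config_patches):
    compiled = torch.compile(lambda t: t * 2, dynamic=True)
    out = compiled(torch.ones(4))
```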