 
 import vllm.envs as envs
 from vllm.compilation.counter import compilation_counter
-from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
+from vllm.compilation.wrapper import TorchCompileGuardsStripWrapper
 from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
@@ -217,14 +217,14 @@ def _support_torch_compile(
217217 """
218218 A decorator to add support for compiling the forward method of a class.
219219 """
220- if TorchCompileWrapperWithCustomDispatcher in cls .__bases__ :
220+ if TorchCompileGuardsStripWrapper in cls .__bases__ :
221221 # support decorating multiple times
222222 return cls
223223
224224 # take care of method resolution order
225225 # make sure super().__init__ is called on the base class
226- # other than TorchCompileWrapperWithCustomDispatcher
227- cls .__bases__ = cls .__bases__ + (TorchCompileWrapperWithCustomDispatcher ,)
226+ # other than TorchCompileGuardsStripWrapper
227+ cls .__bases__ = cls .__bases__ + (TorchCompileGuardsStripWrapper ,)
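+    # Illustration (hypothetical class, not from this PR): after this line a
+    # decorated `class MyModel(nn.Module)` behaves as if it had been declared
+    # as `class MyModel(nn.Module, TorchCompileGuardsStripWrapper)`, so the
+    # wrapper's __init__ and __call__ become reachable through the MRO.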
 
     old_init = cls.__init__
 
@@ -247,19 +247,42 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
         return
 
         compilation_counter.num_models_seen += 1
-        TorchCompileWrapperWithCustomDispatcher.__init__(
-            self, compilation_mode=vllm_config.compilation_config.mode
-        )
+        self.compiled = False
+        TorchCompileGuardsStripWrapper.__init__(self)
 
     cls.__init__ = __init__
 
+    def _mark_dynamic_inputs(mod, *args, **kwargs):
+        sig = inspect.signature(mod.__class__.forward)
+        bound_args = sig.bind(mod, *args, **kwargs)
+        bound_args.apply_defaults()
+        for k, dims in dynamic_arg_dims.items():
+            arg = bound_args.arguments.get(k)
+            if arg is not None:
+                dims = [dims] if isinstance(dims, int) else dims
+                if isinstance(arg, torch.Tensor):
+                    # In case dims is specified with negative indexing
+                    dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
+                    torch._dynamo.mark_dynamic(arg, dims)
+                elif isinstance(arg, IntermediateTensors):
+                    for tensor in arg.tensors.values():
+                        # In case dims is specified with negative indexing
+                        dims = [tensor.ndim + dim if dim < 0 else dim for dim in dims]
+                        torch._dynamo.mark_dynamic(tensor, dims)
+                else:
+                    raise ValueError(
+                        "Unsupported dynamic dimensions"
+                        f" {dims} for argument {k} with type {type(arg)}."
+                    )
+
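+    # A minimal sketch of what the marking above does, using hypothetical
+    # shapes (illustration only):
+    #
+    #     x = torch.randn(8, 16)            # dim 0 varies from call to call
+    #     torch._dynamo.mark_dynamic(x, 0)  # trace dim 0 as symbolic
+    #
+    # The first trace then produces code valid for any size along dim 0
+    # instead of specializing on 8; negative dims are normalized to
+    # `arg.ndim + dim` before marking.
+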
     def __call__(self, *args, **kwargs):
         # torch.compiler.is_compiling() means we are inside the compilation
         # e.g. TPU has the compilation logic in model runner, so we don't
         # need to compile the model inside.
         if self.do_not_compile or torch.compiler.is_compiling():
             return self.forward(*args, **kwargs)
 
+        # if aot_compiled_fn is set, just call it.
         if getattr(self, "aot_compiled_fn", None) is not None:
             return self.aot_compiled_fn(self, *args, **kwargs)
 
@@ -318,102 +341,78 @@ def __call__(self, *args, **kwargs):
                 )
                 return self.aot_compiled_fn(self, *args, **kwargs)
 
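+        # Fast path once the first compilation is done: dispatch through the
+        # wrapper, which calls the compiled code without re-evaluating
+        # Dynamo guards on every call (hence "guards strip").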
+        if self.compiled:
+            assert not envs.VLLM_USE_AOT_COMPILE
+            return TorchCompileGuardsStripWrapper.__call__(self, *args, **kwargs)
+
+        # This is the path for the first compilation.
+
         # the first compilation needs to have dynamic shapes marked
-        if len(self.compiled_codes) < 1:
-            sig = inspect.signature(self.__class__.forward)
-            bound_args = sig.bind(self, *args, **kwargs)
-            bound_args.apply_defaults()
-            for k, dims in dynamic_arg_dims.items():
-                arg = bound_args.arguments.get(k)
-                if arg is not None:
-                    dims = [dims] if isinstance(dims, int) else dims
-                    if isinstance(arg, torch.Tensor):
-                        # In case dims is specified with negative indexing
-                        dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
-                        torch._dynamo.mark_dynamic(arg, dims)
-                    elif isinstance(arg, IntermediateTensors):
-                        for tensor in arg.tensors.values():
-                            # In case dims is specified with negative indexing
-                            dims = [
-                                tensor.ndim + dim if dim < 0 else dim for dim in dims
-                            ]
-                            torch._dynamo.mark_dynamic(tensor, dims)
-                    else:
-                        raise ValueError(
-                            "Unsupported dynamic dimensions"
-                            f" {dims} for argument {k} with type {type(arg)}."
-                        )
-        # here, it is the starting point of the `torch.compile` process
-        start_monitoring_torch_compile(self.vllm_config)
-        logger.debug("Start compiling function %s", self.original_code_object)
-
-        # if we don't use custom dispatcher, we can directly call the
-        # compiled function and let torch.compile handle the dispatching,
-        # with the overhead of guard evaluation and recompilation.
-        if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
-            # it seems Dynamo reuse the compilation across instances,
-            # while we need to make sure the compiled code is not reused.
-            # we need to control all the compilation of the model.
-            torch._dynamo.eval_frame.remove_from_cache(self.original_code_object)
-
-            # collect all relevant files traced by Dynamo,
-            # so that the compilation cache can trigger re-compilation
-            # properly when any of these files change.
-
-            # 1. the file containing the top-level forward function
-            self.vllm_config.compilation_config.traced_files.add(
-                self.original_code_object.co_filename
-            )
+        _mark_dynamic_inputs(self, *args, **kwargs)
 
-            # 2. every time Dynamo sees a function call, it will inline
-            # the function by calling InliningInstructionTranslator.inline_call_
-            # we hijack this function to know all the functions called
-            # during Dynamo tracing, and their corresponding files
-            inline_call = InliningInstructionTranslator.inline_call_
-
-            def patched_inline_call(self_):
-                code = self_.f_code
-                self.vllm_config.compilation_config.traced_files.add(code.co_filename)
-                return inline_call(self_)
-
-            # Disable the C++ compilation of symbolic shape guards. C++-fication
-            # of symbolic shape guards can improve guard overhead. But, since
-            # vllm skip guards anyways, setting this flag to False can improve
-            # compile time.
-            dynamo_config_patches = {}
-            try:
-                _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
-                dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False
-            except AttributeError:
-                # Note: this config is not available in torch 2.6, we can skip
-                # if the config doesn't exist
-                logger.debug("enable_cpp_symbolic_shape_guards config not available")
-
-            with (
-                patch.object(
-                    InliningInstructionTranslator, "inline_call_", patched_inline_call
-                ),
-                torch._dynamo.config.patch(**dynamo_config_patches),
-                maybe_use_cudagraph_partition_wrapper(self.vllm_config),
-                _torch27_patch_tensor_subclasses(),
-            ):
-                if envs.VLLM_USE_AOT_COMPILE:
-                    self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
-                    output = self.aot_compiled_fn(self, *args, **kwargs)
-                    assert aot_compilation_path is not None
-                    assert cache_dir is not None
-                    os.makedirs(cache_dir, exist_ok=True)
-                    self.aot_compiled_fn.save_compiled_function(aot_compilation_path)
-                else:
-                    output = self.compiled_callable(*args, **kwargs)
-                return output
-
-        # usually, capturing the model once is enough, and then we can
-        # dispatch to the compiled code directly, without going through
-        # the Dynamo guard mechanism.
-        with self.dispatch_to_code(0):
-            model_output = self.forward(*args, **kwargs)
-        return model_output
+        # this is the starting point of the `torch.compile` process
+        start_monitoring_torch_compile(self.vllm_config)
+        original_code_object = self.original_code_object()
+        logger.debug("Start compiling function %s", original_code_object)
+
+        # it seems Dynamo reuses the compilation across instances,
+        # while we need to make sure the compiled code is not reused;
+        # we need to control all the compilation of the model.
+        torch._dynamo.eval_frame.remove_from_cache(original_code_object)
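+        # (This should drop any cached Dynamo entry keyed on this code
+        # object, so this instance gets a fresh trace instead of a
+        # compilation captured by another instance.)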
+
+        # collect all relevant files traced by Dynamo,
+        # so that the compilation cache can trigger re-compilation
+        # properly when any of these files change.
+
+        # 1. the file containing the top-level forward function
+        self.vllm_config.compilation_config.traced_files.add(
+            original_code_object.co_filename
+        )
+
+        # 2. every time Dynamo sees a function call, it will inline
+        # the function by calling InliningInstructionTranslator.inline_call_;
+        # we hijack this function to record all the functions called
+        # during Dynamo tracing, and their corresponding files
+        inline_call = InliningInstructionTranslator.inline_call_
+
+        def patched_inline_call(self_):
+            code = self_.f_code
+            self.vllm_config.compilation_config.traced_files.add(code.co_filename)
+            return inline_call(self_)
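+        # Net effect: once tracing finishes, traced_files holds every file
+        # Dynamo inlined through, and the compilation cache can hash those
+        # files into its key to trigger recompilation when any of them
+        # changes.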
+
+        # Disable the C++ compilation of symbolic shape guards. C++-fication
+        # of symbolic shape guards can improve guard overhead. But, since
+        # vllm skips guards anyway, setting this flag to False can improve
+        # compile time.
+        dynamo_config_patches = {}
+        try:
+            _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
+            dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False
+        except AttributeError:
+            # Note: this config is not available in torch 2.6; we can skip
+            # it if the config doesn't exist
+            logger.debug("enable_cpp_symbolic_shape_guards config not available")
+
+        with (
+            patch.object(
+                InliningInstructionTranslator, "inline_call_", patched_inline_call
+            ),
+            torch._dynamo.config.patch(**dynamo_config_patches),
+            maybe_use_cudagraph_partition_wrapper(self.vllm_config),
+            _torch27_patch_tensor_subclasses(),
+        ):
+            if envs.VLLM_USE_AOT_COMPILE:
+                self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
+                output = self.aot_compiled_fn(self, *args, **kwargs)
+                assert aot_compilation_path is not None
+                assert cache_dir is not None
+                os.makedirs(cache_dir, exist_ok=True)
+                self.aot_compiled_fn.save_compiled_function(aot_compilation_path)
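+                # The artifact saved here is presumably what a later run loads
+                # back into aot_compiled_fn, which the getattr check at the
+                # top of __call__ then short-circuits to.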
+            else:
+                output = TorchCompileGuardsStripWrapper.__call__(self, *args, **kwargs)
+
+        self.compiled = True
+        return output
 
     cls.__call__ = __call__
     return cls