 from torch._dynamo.symbolic_convert import InliningInstructionTranslator
 
 from vllm.compilation.counter import compilation_counter
-from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
+from vllm.compilation.wrapper import TorchCompileGuardsStripWrapper
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
@@ -32,11 +32,11 @@ def ignore_torch_compile(cls: _T) -> _T:
     a support_torch_compile decorator, but we don't want to
     compile the class `cls` that inherits the parent class.
     This only ignores compiling the forward of the class the
-    decorator is applied to. 
+    decorator is applied to.
 
     If the parent has ignore_torch_compile but the child has
     support_torch_compile, the child will still be compiled.
-    
+
     If the class has one or more submodules
     that have support_torch_compile decorator applied, compile will
     not be ignored for those submodules.
@@ -182,14 +182,14 @@ def _support_torch_compile(
182182 """
183183 A decorator to add support for compiling the forward method of a class.
184184 """
185- if TorchCompileWrapperWithCustomDispatcher in cls .__bases__ :
185+ if TorchCompileGuardsStripWrapper in cls .__bases__ :
186186 # support decorating multiple times
187187 return cls
188188
189189 # take care of method resolution order
190190 # make sure super().__init__ is called on the base class
191- # other than TorchCompileWrapperWithCustomDispatcher
192- cls .__bases__ = cls .__bases__ + (TorchCompileWrapperWithCustomDispatcher , )
191+ # other than TorchCompileGuardsStripWrapper
192+ cls .__bases__ = cls .__bases__ + (TorchCompileGuardsStripWrapper , )
193193
194194 old_init = cls .__init__
195195
@@ -210,107 +210,83 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
             return
 
         compilation_counter.num_models_seen += 1
-        TorchCompileWrapperWithCustomDispatcher.__init__(
-            self, compilation_level=vllm_config.compilation_config.level)
+        TorchCompileGuardsStripWrapper.__init__(self)
 
     cls.__init__ = __init__
 
+    def _mark_dynamic_inputs(mod, *args, **kwargs):
+        sig = inspect.signature(mod.__class__.forward)
+        bound_args = sig.bind(mod, *args, **kwargs)
+        bound_args.apply_defaults()
+        for k, dims in dynamic_arg_dims.items():
+            arg = bound_args.arguments.get(k)
+            if arg is not None:
+                dims = [dims] if isinstance(dims, int) else dims
+                if isinstance(arg, torch.Tensor):
+                    # In case dims is specified with negative indexing
+                    dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
+                    torch._dynamo.mark_dynamic(arg, dims)
+                elif isinstance(arg, IntermediateTensors):
+                    for tensor in arg.tensors.values():
+                        # In case dims is specified with negative indexing
+                        dims = [
+                            tensor.ndim + dim if dim < 0 else dim
+                            for dim in dims
+                        ]
+                        torch._dynamo.mark_dynamic(tensor, dims)
+                else:
+                    raise ValueError(
+                        "Unsupported dynamic dimensions"
+                        f" {dims} for argument {k} with type {type(arg)}.")
+
     def __call__(self, *args, **kwargs):
         # torch.compiler.is_compiling() means we are inside the compilation
         # e.g. TPU has the compilation logic in model runner, so we don't
         # need to compile the model inside.
         if self.do_not_compile or torch.compiler.is_compiling():
             return self.forward(*args, **kwargs)
 
+        # This attribute is added by TorchCompileGuardsStripWrapper
+        if self.compiled:
+            return TorchCompileGuardsStripWrapper.__call__(
+                self, *args, **kwargs)
+
+        # This is the path for the first compilation.
+        _mark_dynamic_inputs(self, *args, **kwargs)
+
         # the first compilation needs to have dynamic shapes marked
-        if len(self.compiled_codes) < 1:
-            sig = inspect.signature(self.__class__.forward)
-            bound_args = sig.bind(self, *args, **kwargs)
-            bound_args.apply_defaults()
-            for k, dims in dynamic_arg_dims.items():
-                arg = bound_args.arguments.get(k)
-                if arg is not None:
-                    dims = [dims] if isinstance(dims, int) else dims
-                    if isinstance(arg, torch.Tensor):
-                        # In case dims is specified with negative indexing
-                        dims = [
-                            arg.ndim + dim if dim < 0 else dim for dim in dims
-                        ]
-                        torch._dynamo.mark_dynamic(arg, dims)
-                    elif isinstance(arg, IntermediateTensors):
-                        for tensor in arg.tensors.values():
-                            # In case dims is specified with negative indexing
-                            dims = [
-                                tensor.ndim + dim if dim < 0 else dim
-                                for dim in dims
-                            ]
-                            torch._dynamo.mark_dynamic(tensor, dims)
-                    else:
-                        raise ValueError(
-                            "Unsupported dynamic dimensions"
-                            f" {dims} for argument {k} with type {type(arg)}.")
-            # here, it is the starting point of the `torch.compile` process
-            start_monitoring_torch_compile(self.vllm_config)
-            logger.debug("Start compiling function %s",
-                         self.original_code_object)
+        start_monitoring_torch_compile(self.vllm_config)
+        logger.debug("Start compiling function %s",
+                     self.original_code_object())
 
         # if we don't use custom dispatcher, we can directly call the
         # compiled function and let torch.compile handle the dispatching,
         # with the overhead of guard evaluation and recompilation.
-        if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
-            # it seems Dynamo reuse the compilation across instances,
-            # while we need to make sure the compiled code is not reused.
-            # we need to control all the compilation of the model.
-            torch._dynamo.eval_frame.remove_from_cache(
-                self.original_code_object)
-
-            # collect all relevant files traced by Dynamo,
-            # so that the compilation cache can trigger re-compilation
-            # properly when any of these files change.
-
-            # 1. the file containing the top-level forward function
+
+        # collect all relevant files traced by Dynamo,
+        # so that the compilation cache can trigger re-compilation
+        # properly when any of these files change.
+
+        # 1. the file containing the top-level forward function
+        self.vllm_config.compilation_config.traced_files.add(
+            self.original_code_object().co_filename)
+
+        # 2. every time Dynamo sees a function call, it will inline
+        # the function by calling InliningInstructionTranslator.inline_call
+        # we hijack this function to know all the functions called
+        # during Dynamo tracing, and their corresponding files
+        inline_call = InliningInstructionTranslator.inline_call
+
+        def patched_inline_call(parent, func, args, kwargs):
+            code = func.get_code()
             self.vllm_config.compilation_config.traced_files.add(
-                self.original_code_object.co_filename)
-
-            # 2. every time Dynamo sees a function call, it will inline
-            # the function by calling InliningInstructionTranslator.inline_call
-            # we hijack this function to know all the functions called
-            # during Dynamo tracing, and their corresponding files
-            inline_call = InliningInstructionTranslator.inline_call
-
-            def patched_inline_call(parent, func, args, kwargs):
-                code = func.get_code()
-                self.vllm_config.compilation_config.traced_files.add(
-                    code.co_filename)
-                return inline_call(parent, func, args, kwargs)
-
-            # Disable the C++ compilation of symbolic shape guards. C++-fication
-            # of symbolic shape guards can improve guard overhead. But, since
-            # vllm skip guards anyways, setting this flag to False can improve
-            # compile time.
-            dynamo_config_patches = {}
-            try:
-                _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
-                dynamo_config_patches[
-                    "enable_cpp_symbolic_shape_guards"] = False
-            except AttributeError:
-                # Note: this config is not available in torch 2.6, we can skip
-                # if the config doesn't exist
-                logger.debug(
-                    "enable_cpp_symbolic_shape_guards config not available")
-
-            with patch.object(InliningInstructionTranslator, 'inline_call',
-                              patched_inline_call), torch._dynamo.config.patch(
-                                  **dynamo_config_patches):
-                output = self.compiled_callable(*args, **kwargs)
-            return output
-
-        # usually, capturing the model once is enough, and then we can
-        # dispatch to the compiled code directly, without going through
-        # the Dynamo guard mechanism.
-        with self.dispatch_to_code(0):
-            model_output = self.forward(*args, **kwargs)
-        return model_output
+                code.co_filename)
+            return inline_call(parent, func, args, kwargs)
+
+        with patch.object(InliningInstructionTranslator, "inline_call",
+                          patched_inline_call):
+            return TorchCompileGuardsStripWrapper.__call__(
+                self, *args, **kwargs)
 
     cls.__call__ = __call__
     return cls
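
For context, a minimal usage sketch (not part of this diff) of how a model class opts into compilation through the decorator. The class name, layer sizes, and the dynamic dimensions chosen here are illustrative assumptions, not taken from the vLLM sources; only `support_torch_compile`, `dynamic_arg_dims`, and the `vllm_config`/`prefix` keyword arguments come from the code above.

# Hypothetical usage sketch: the decorator marks dim 0 of both tensor
# arguments dynamic before the first compile, so varying token counts
# do not force recompilation.
import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


@support_torch_compile(dynamic_arg_dims={"input_ids": 0, "positions": 0})
class ToyModel(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.embed = nn.Embedding(32_000, 64)

    def forward(self, input_ids: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        # [num_tokens, 64] + [num_tokens, 1] broadcasts over the hidden dim
        return self.embed(input_ids) + positions.unsqueeze(-1).float()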