@@ -441,6 +441,35 @@ def set_model_tag(tag: str):
441441 model_tag = old_tag
442442
443443
try:
    # Newer torch builds ship torch._dynamo.aot_compile; older ones do not.
    # Fall back to a plain `object` base so this module still imports there.
    from torch._dynamo.aot_compile import SerializableCallable
except ImportError:
    SerializableCallable = object

# Either branch must leave us with something usable as a base class.
assert isinstance(SerializableCallable, type)


class VllmCompiledFunction(SerializableCallable):
    """Callable wrapper bundling a compiled function with its compile context.

    Stores the FX graph module, the example inputs seen at compile time, and
    the vLLM config next to the optimized callable, so the compiled artifact
    can later be introspected or (once implemented) serialized.
    """

    def __init__(self, graph_module, example_inputs, vllm_config,
                 optimized_call):
        # Retain everything needed to reconstruct or inspect the artifact.
        self.graph_module = graph_module
        self.example_inputs = example_inputs
        self.vllm_config = vllm_config
        self.optimized_call = optimized_call

    def __call__(self, *args, **kwargs):
        # Transparent delegation: calling the wrapper is calling the
        # optimized function.
        return self.optimized_call(*args, **kwargs)

    @classmethod
    def serialize_compile_artifacts(cls, compiled_fn):
        # Interface placeholder required by SerializableCallable;
        # serialization support has not been written yet.
        raise NotImplementedError("serialization not implemented")

    @classmethod
    def deserialize_compile_artifacts(cls, data):
        # Interface placeholder required by SerializableCallable;
        # deserialization support has not been written yet.
        raise NotImplementedError("deserialization not implemented")
471+
472+
444473class VllmBackend :
445474 """The compilation backend for `torch.compile` with vLLM.
446475 It is used for compilation level of `CompilationLevel.PIECEWISE`,
@@ -659,7 +688,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
659688 self .compilation_config .cudagraph_mode == CUDAGraphMode .NONE
660689 or not self .compilation_config .cudagraph_copy_inputs
661690 ):
662- return self .split_gm
691+ return VllmCompiledFunction (graph , example_inputs , vllm_config ,
692+ self .split_gm )
663693
664694 # if we need to copy input buffers for cudagraph
665695 from torch ._guards import detect_fake_mode
@@ -704,4 +734,5 @@ def copy_and_call(*args):
704734 list_args [index ] = static_tensor
705735 return self .split_gm (* list_args )
706736
707- return copy_and_call
737+ return VllmCompiledFunction (graph , example_inputs , vllm_config ,
738+ copy_and_call )
0 commit comments