33
44import ast
55import dataclasses
6+ import inspect
67import os
8+ import pickle
79import pprint
810import time
911from collections .abc import Sequence
1012from contextlib import contextmanager
1113from typing import Any , Callable , Optional
14+ from unittest .mock import patch
1215
1316import torch
1417import torch .fx as fx
1518from torch ._dispatch .python import enable_python_dispatcher
19+ from torch .utils import _pytree as pytree
1620
1721import vllm .envs as envs
1822from vllm .config import CompilationConfig , CUDAGraphMode , VllmConfig
@@ -408,23 +412,94 @@ def set_model_tag(tag: str):
408412
409413class VllmCompiledFunction (SerializableCallable ):
410414
411- def __init__ (self , graph_module , example_inputs , vllm_config ,
415+ def __init__ (self , graph_module , example_inputs , vllm_config , prefix ,
412416 optimized_call ):
417+ assert isinstance (graph_module , torch .fx .GraphModule )
413418 self .graph_module = graph_module
414419 self .example_inputs = example_inputs
415420 self .vllm_config = vllm_config
421+ self .prefix = prefix
416422 self .optimized_call = optimized_call
417423
418424 def __call__ (self , * args , ** kwargs ):
419425 return self .optimized_call (* args , ** kwargs )
420426
421427 @classmethod
422- def serialize_compile_artifacts (cls , compiled_fn ):
423- raise NotImplementedError ("serialization not implemented" )
428+ def serialize_compile_artifacts (
429+ cls , compiled_fn : "VllmCompiledFunction" ) -> bytes :
430+ import sympy
431+ from torch ._subclasses import FakeTensorMode
432+ from torch .fx ._graph_pickler import GraphPickler , Options
433+ state = compiled_fn .__dict__ .copy ()
434+ state .pop ("optimized_call" )
435+ for node in state ["graph_module" ].graph .nodes :
436+ node .meta .pop ("source_fn_stack" , None )
437+ node .meta .pop ("nn_module_stack" , None )
438+
439+ graph_reducer_override = GraphPickler .reducer_override
440+
441+ def _graph_reducer_override (self , obj ):
442+ if (inspect .isclass (obj ) and issubclass (obj , sympy .Function )
443+ and hasattr (obj , "_torch_unpickler" )):
444+ return obj ._torch_unpickler , (obj ._torch_handler_name , )
445+ if isinstance (obj , FakeTensorMode ):
446+ return type (None ), ()
447+ return graph_reducer_override (self , obj )
448+
449+ # Mask off tensor inputs since they are large and not needed.
450+ state ["example_inputs" ] = pytree .tree_map_only (torch .Tensor ,
451+ lambda _ : None ,
452+ state ["example_inputs" ])
453+ with patch .object (GraphPickler , 'reducer_override' ,
454+ _graph_reducer_override ):
455+ state ["graph_module" ] = GraphPickler .dumps (
456+ state ["graph_module" ], Options (ops_filter = None ))
457+ state ["example_inputs" ] = GraphPickler .dumps (
458+ state ["example_inputs" ])
459+ return pickle .dumps (state )
424460
425461 @classmethod
426- def deserialize_compile_artifacts (cls , data ):
427- raise NotImplementedError ("deserialization not implemented" )
462+ def deserialize_compile_artifacts (cls ,
463+ data : bytes ) -> "VllmCompiledFunction" :
464+ from torch ._guards import TracingContext , tracing
465+ from torch ._subclasses import FakeTensorMode
466+ from torch .fx ._graph_pickler import GraphPickler
467+ from torch .fx .experimental .symbolic_shapes import ShapeEnv
468+
469+ state = pickle .loads (data )
470+ fake_mode = FakeTensorMode (shape_env = ShapeEnv ())
471+ state ["graph_module" ] = GraphPickler .loads (state ["graph_module" ],
472+ fake_mode )
473+ state ["example_inputs" ] = GraphPickler .loads (state ["example_inputs" ],
474+ fake_mode )
475+ vllm_backend = VllmBackend (state ["vllm_config" ], state ["prefix" ])
476+
477+ def optimized_call (* example_inputs ):
478+ compile_inputs = [
479+ inp or example_inputs [i ]
480+ for i , inp in enumerate (fn .example_inputs )
481+ ]
482+ with tracing (TracingContext (fake_mode )):
483+ fn .optimized_call = vllm_backend (state ["graph_module" ],
484+ compile_inputs ).optimized_call
485+ return fn .optimized_call (* example_inputs )
486+
487+ fn = cls (** state , optimized_call = optimized_call )
488+ return fn
489+
490+
491+ def compilation_config_hash_factors (vllm_config : VllmConfig ) -> list [str ]:
492+ factors = []
493+ # 0. factors come from the env, for example, The values of
494+ # VLLM_PP_LAYER_PARTITION will affect the computation graph.
495+ env_hash = envs .compute_hash ()
496+ factors .append (env_hash )
497+
498+ # 1. factors come from the vllm_config (it mainly summarizes how the
499+ # model is created)
500+ config_hash = vllm_config .compute_hash ()
501+ factors .append (config_hash )
502+ return factors
428503
429504
430505class VllmBackend :
@@ -502,7 +577,8 @@ def configure_post_pass(self):
502577 self .post_grad_pass_manager .add (inductor_config [PASS_KEY ])
503578 inductor_config [PASS_KEY ] = self .post_grad_pass_manager
504579
505- def __call__ (self , graph : fx .GraphModule , example_inputs ) -> Callable :
580+ def __call__ (self , graph : fx .GraphModule ,
581+ example_inputs ) -> VllmCompiledFunction :
506582
507583 vllm_config = self .vllm_config
508584 if not self .compilation_config .cache_dir :
@@ -511,17 +587,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
511587 # the cache dir will be the same so that we can reuse the compiled
512588 # graph.
513589
514- factors = []
515- # 0. factors come from the env, for example, The values of
516- # VLLM_PP_LAYER_PARTITION will affect the computation graph.
517- env_hash = envs .compute_hash ()
518- factors .append (env_hash )
519-
520- # 1. factors come from the vllm_config (it mainly summarizes how the
521- # model is created)
522- config_hash = vllm_config .compute_hash ()
523- factors .append (config_hash )
524-
590+ factors = compilation_config_hash_factors (vllm_config )
525591 # 2. factors come from the code files that are traced by Dynamo (
526592 # it mainly summarizes how the model is used in forward pass)
527593 forward_code_files = list (
@@ -635,7 +701,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
635701 if self .compilation_config .cudagraph_mode == CUDAGraphMode .NONE or \
636702 not self .compilation_config .cudagraph_copy_inputs :
637703 return VllmCompiledFunction (graph , example_inputs , vllm_config ,
638- self .split_gm )
704+ self .prefix , self . split_gm )
639705
640706 # if we need to copy input buffers for cudagraph
641707 from torch ._guards import detect_fake_mode
@@ -678,4 +744,4 @@ def copy_and_call(*args):
678744 return self .split_gm (* list_args )
679745
680746 return VllmCompiledFunction (graph , example_inputs , vllm_config ,
681- copy_and_call )
747+ self . prefix , copy_and_call )
0 commit comments