From 208af7249da0b9cbf23ca9c051a821e836db26e8 Mon Sep 17 00:00:00 2001 From: zhxchen17 Date: Wed, 10 Sep 2025 08:44:16 -0700 Subject: [PATCH 1/2] AOT compilation workflow [1/n] Signed-off-by: zhxchen17 --- .buildkite/test-pipeline.yaml | 1 + tests/compile/test_aot_compile.py | 61 +++++++++++++++++++++++++++++++ vllm/compilation/backends.py | 35 +++++++++++++++++- vllm/compilation/decorators.py | 10 ++++- vllm/compilation/wrapper.py | 26 +++++++++++++ vllm/envs.py | 7 ++++ 6 files changed, 137 insertions(+), 3 deletions(-) create mode 100644 tests/compile/test_aot_compile.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 681503e98372..ebe0602a1b5d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -403,6 +403,7 @@ steps: - pytest -v -s compile/test_fusion_all_reduce.py - pytest -v -s compile/test_decorator.py - pytest -v -s compile/test_noop_elimination.py + - pytest -v -s compile/test_aot_compile.py - label: PyTorch Fullgraph Smoke Test # 15min timeout_in_minutes: 30 diff --git a/tests/compile/test_aot_compile.py b/tests/compile/test_aot_compile.py new file mode 100644 index 000000000000..df3b1ffacb8a --- /dev/null +++ b/tests/compile/test_aot_compile.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from contextlib import contextmanager + +import pytest +import torch + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, + set_current_vllm_config) +from vllm.forward_context import set_forward_context + + +class MyMod(torch.nn.Module): + + def __init__(self, **kwargs): + super().__init__() + + def forward(self, x: torch.Tensor): + for _ in range(3000): + x = x + x.shape[0] + return x + + +def make_vllm_config() -> VllmConfig: + return VllmConfig(compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, )) + + +@contextmanager +def use_vllm_config(vllm_config: VllmConfig): + with set_forward_context( + {}, vllm_config), set_current_vllm_config(vllm_config): + yield + + +def test_no_eval_frame(monkeypatch: pytest.MonkeyPatch): + with monkeypatch.context() as m: + mod = MyMod() + args = (torch.randn(10, 10), ) + expected = mod(*args) + CompiledMod = support_torch_compile(MyMod) + + vllm_config = make_vllm_config() + m.setenv("VLLM_USE_AOT_COMPILE", "0") + try: + with use_vllm_config(vllm_config), torch.compiler.set_stance( + "fail_on_recompile"): + CompiledMod(vllm_config=vllm_config)(*args) + except RuntimeError as e: + assert "Detected recompile" in str(e) + else: + raise AssertionError("Expected exception to be raised") + + m.setenv("VLLM_USE_AOT_COMPILE", "1") + torch._dynamo.reset() + with use_vllm_config(vllm_config), torch.compiler.set_stance( + "fail_on_recompile"): + ret = CompiledMod(vllm_config=vllm_config)(*args) + assert torch.allclose(ret, expected) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index c35d77d4668c..831755b834ed 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -472,6 +472,35 @@ def set_model_tag(tag: str): model_tag = old_tag +try: + from torch._dynamo.aot_compile import SerializableCallable +except ImportError: + SerializableCallable = object + +assert isinstance(SerializableCallable, type) + + +class VllmCompiledFunction(SerializableCallable): + + def __init__(self, graph_module, example_inputs, vllm_config, + optimized_call): + self.graph_module = graph_module 
+ self.example_inputs = example_inputs + self.vllm_config = vllm_config + self.optimized_call = optimized_call + + def __call__(self, *args, **kwargs): + return self.optimized_call(*args, **kwargs) + + @classmethod + def serialize_compile_artifacts(cls, compiled_fn): + raise NotImplementedError("serialization not implemented") + + @classmethod + def deserialize_compile_artifacts(cls, data): + raise NotImplementedError("deserialization not implemented") + + class VllmBackend: """The compilation backend for `torch.compile` with vLLM. It is used for compilation level of `CompilationLevel.PIECEWISE`, @@ -695,7 +724,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or not self.compilation_config.cudagraph_copy_inputs ): - return self.split_gm + return VllmCompiledFunction(graph, example_inputs, vllm_config, + self.split_gm) # if we need to copy input buffers for cudagraph from torch._guards import detect_fake_mode @@ -740,4 +770,5 @@ def copy_and_call(*args): list_args[index] = static_tensor return self.split_gm(*list_args) - return copy_and_call + return VllmCompiledFunction(graph, example_inputs, vllm_config, + copy_and_call) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 4f5648d3000a..74090a00c7e2 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -11,6 +11,7 @@ from packaging import version from torch._dynamo.symbolic_convert import InliningInstructionTranslator +import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import CompilationLevel, VllmConfig @@ -227,6 +228,9 @@ def __call__(self, *args, **kwargs): if self.do_not_compile or torch.compiler.is_compiling(): return self.forward(*args, **kwargs) + if getattr(self, "aot_compiled_fn", None) is not None: + return self.aot_compiled_fn(self, *args, **kwargs) + # the first compilation needs to have dynamic shapes marked if len(self.compiled_codes) < 1: sig = inspect.signature(self.__class__.forward) @@ -306,7 +310,11 @@ def patched_inline_call(parent, func, args, kwargs): maybe_use_cudagraph_partition_wrapper(self.vllm_config), _torch27_patch_tensor_subclasses(), ): - output = self.compiled_callable(*args, **kwargs) + if envs.VLLM_USE_AOT_COMPILE: + self.aot_compiled_fn = self.aot_compile(*args, **kwargs) + output = self.aot_compiled_fn(self, *args, **kwargs) + else: + output = self.compiled_callable(*args, **kwargs) return output # usually, capturing the model once is enough, and then we can diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 71a4e1745d4e..4202a78b3105 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -10,6 +10,7 @@ import torch +import vllm.envs as envs from vllm.config import CompilationLevel, CUDAGraphMode, get_current_vllm_config from vllm.logger import init_logger @@ -41,9 +42,26 @@ def __init__( backend = vllm_config.compilation_config.init_backend(vllm_config) options = None if isinstance(backend, str) and backend == "inductor": +<<<<<<< HEAD options = ( get_current_vllm_config().compilation_config.inductor_compile_config ) +======= + options = get_current_vllm_config( + ).compilation_config.inductor_compile_config + if envs.VLLM_USE_AOT_COMPILE: + options = options or {} + options["guard_filter_fn"] = lambda guards: [ + False for _ in guards + ] + if hasattr(torch._dynamo.config, 
"enable_aot_compile"): + torch._dynamo.config.enable_aot_compile = True + else: + msg = "torch._dynamo.config.enable_aot_compile is not " + msg += "available. AOT compile is disabled and please " + msg += "upgrade PyTorch version to use AOT compile." + logger.warning(msg) +>>>>>>> 6fc29676a (AOT compilation workflow [1/n]) compiled_callable = torch.compile( self.forward, fullgraph=True, backend=backend, options=options @@ -61,6 +79,14 @@ def __init__( compilation_level >= CompilationLevel.DYNAMO_ONCE ) + def aot_compile(self, *args, **kwargs): + if not hasattr(self.compiled_callable, "aot_compile"): + raise RuntimeError( + "aot_compile is not supported by the current configuration. " + + "Please make sure torch.compile is enabled with the latest " + + "version of PyTorch") + return self.compiled_callable.aot_compile((args, kwargs)) + def __call__(self, *args, **kwargs): """Implement the dispatch logic here, beyond the torch.compile level. NOTE: this function can have additional arguments beyond the forward diff --git a/vllm/envs.py b/vllm/envs.py index 9485aeeb8a82..afcebef8b12a 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -494,6 +494,13 @@ def get_vllm_port() -> Optional[int]: # Dump fx graphs to the given directory. # It will override CompilationConfig.debug_dump_path if set. "VLLM_DEBUG_DUMP_PATH": lambda: os.environ.get("VLLM_DEBUG_DUMP_PATH", None), + + # Feature flag to enable/disable AOT compilation. This will ensure + # compilation is done in warmup phase and the compilation will be + # reused in subsequent calls. + "VLLM_USE_AOT_COMPILE": + lambda: os.environ.get("VLLM_USE_AOT_COMPILE", "0") == "1", + # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")), From aba7a855d81102b88c8178eb2f456acf56abd876 Mon Sep 17 00:00:00 2001 From: zhxchen17 Date: Wed, 10 Sep 2025 08:44:16 -0700 Subject: [PATCH 2/2] AOT compilation workflow [2/n] Signed-off-by: zhxchen17 --- tests/compile/test_aot_compile.py | 134 +++++++++++++---- tools/pre_commit/check_pickle_imports.py | 1 + vllm/compilation/backends.py | 80 +++-------- vllm/compilation/caching.py | 176 +++++++++++++++++++++++ vllm/compilation/compiler_interface.py | 6 + vllm/compilation/decorators.py | 103 ++++++++++++- vllm/compilation/wrapper.py | 19 ++- vllm/envs.py | 18 ++- 8 files changed, 423 insertions(+), 114 deletions(-) create mode 100644 vllm/compilation/caching.py diff --git a/tests/compile/test_aot_compile.py b/tests/compile/test_aot_compile.py index df3b1ffacb8a..08f79d90cd36 100644 --- a/tests/compile/test_aot_compile.py +++ b/tests/compile/test_aot_compile.py @@ -1,61 +1,139 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import tempfile from contextlib import contextmanager import pytest import torch from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, - set_current_vllm_config) +from vllm.config import ( + CompilationConfig, + CompilationLevel, + VllmConfig, + set_current_vllm_config, +) from vllm.forward_context import set_forward_context +from vllm.utils import is_torch_equal_or_newer -class MyMod(torch.nn.Module): +def reference_fn(x: torch.Tensor): + assert x.shape[0] <= 42 + assert x.shape[0] % 2 == 0 + for _ in range(3000): + x = x + x.shape[0] + return x + +@support_torch_compile +class CompiledMod(torch.nn.Module): def __init__(self, **kwargs): super().__init__() def 
forward(self, x: torch.Tensor): - for _ in range(3000): - x = x + x.shape[0] - return x + return reference_fn(x) def make_vllm_config() -> VllmConfig: - return VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, )) + return VllmConfig( + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + ) + ) @contextmanager def use_vllm_config(vllm_config: VllmConfig): - with set_forward_context( - {}, vllm_config), set_current_vllm_config(vllm_config): + with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config): yield -def test_no_eval_frame(monkeypatch: pytest.MonkeyPatch): +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" +) +def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch): with monkeypatch.context() as m: - mod = MyMod() - args = (torch.randn(10, 10), ) - expected = mod(*args) - CompiledMod = support_torch_compile(MyMod) - vllm_config = make_vllm_config() - m.setenv("VLLM_USE_AOT_COMPILE", "0") - try: - with use_vllm_config(vllm_config), torch.compiler.set_stance( - "fail_on_recompile"): + args = (torch.randn(10, 10),) + expected = reference_fn(*args) + with use_vllm_config(vllm_config): + m.setenv("VLLM_USE_AOT_COMPILE", "0") + with ( + pytest.raises(RuntimeError, match="Detected recompile"), + torch.compiler.set_stance("fail_on_recompile"), + ): CompiledMod(vllm_config=vllm_config)(*args) - except RuntimeError as e: - assert "Detected recompile" in str(e) - else: - raise AssertionError("Expected exception to be raised") + m.setenv("VLLM_USE_AOT_COMPILE", "1") + torch._dynamo.reset() + with torch.compiler.set_stance("fail_on_recompile"): + actual = CompiledMod(vllm_config=vllm_config)(*args) + assert torch.allclose(actual, expected) + + +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" +) +def test_force_aot_load(monkeypatch: pytest.MonkeyPatch): + with tempfile.TemporaryDirectory() as tmpdirname, monkeypatch.context() as m: + args = (torch.randn(10, 10),) m.setenv("VLLM_USE_AOT_COMPILE", "1") - torch._dynamo.reset() - with use_vllm_config(vllm_config), torch.compiler.set_stance( - "fail_on_recompile"): - ret = CompiledMod(vllm_config=vllm_config)(*args) + m.setenv("VLLM_FORCE_AOT_LOAD", "1") + m.setenv("VLLM_CACHE_ROOT", tmpdirname) + vllm_config = make_vllm_config() + with use_vllm_config(vllm_config), pytest.raises(FileNotFoundError): + CompiledMod(vllm_config=vllm_config)(*args) + + +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" +) +def test_save_and_load(monkeypatch: pytest.MonkeyPatch): + with monkeypatch.context() as m: + args = (torch.randn(10, 10),) + + with tempfile.TemporaryDirectory() as tmpdirname: + m.setenv("VLLM_CACHE_ROOT", tmpdirname) + m.setenv("VLLM_USE_AOT_COMPILE", "1") + vllm_config = make_vllm_config() + with use_vllm_config(vllm_config): + expected = CompiledMod(vllm_config=vllm_config)(*args) + + m.setenv("VLLM_FORCE_AOT_LOAD", "1") + vllm_config = make_vllm_config() + with use_vllm_config(vllm_config): + ret = CompiledMod(vllm_config=vllm_config)(*args) assert torch.allclose(ret, expected) + + +@pytest.mark.skipif( + not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" +) +def test_shape_env(monkeypatch: pytest.MonkeyPatch): + """ + Test that the shape environment is correctly serialized and preserved + when loading from cache. 
+ """ + with monkeypatch.context() as m: + args = (torch.randn(10, 10),) + + with tempfile.TemporaryDirectory() as tmpdirname: + m.setenv("VLLM_CACHE_ROOT", tmpdirname) + m.setenv("VLLM_USE_AOT_COMPILE", "1") + vllm_config = make_vllm_config() + with use_vllm_config(vllm_config): + compiled_mod = CompiledMod(vllm_config=vllm_config) + compiled_mod(*args) + artifacts = compiled_mod.aot_compiled_fn._artifacts + guards_string = artifacts.compiled_fn.shape_env.format_guards() + assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)" + + m.setenv("VLLM_FORCE_AOT_LOAD", "1") + vllm_config = make_vllm_config() + with use_vllm_config(vllm_config): + compiled_mod = CompiledMod(vllm_config=vllm_config) + compiled_mod(*args) + artifacts = compiled_mod.aot_compiled_fn._artifacts + guards_string = artifacts.compiled_fn.shape_env.format_guards() + assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)" diff --git a/tools/pre_commit/check_pickle_imports.py b/tools/pre_commit/check_pickle_imports.py index bceb894a7a5f..7944b7c9b275 100644 --- a/tools/pre_commit/check_pickle_imports.py +++ b/tools/pre_commit/check_pickle_imports.py @@ -22,6 +22,7 @@ "vllm/multimodal/hasher.py", "vllm/transformers_utils/config.py", "vllm/model_executor/models/registry.py", + "vllm/compilation/caching.py", "tests/utils_/test_utils.py", "tests/tokenization/test_cached_tokenizer.py", "vllm/distributed/utils.py", diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 831755b834ed..826ab42462c3 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -3,6 +3,7 @@ import ast import dataclasses +import hashlib import os import pprint import time @@ -25,6 +26,7 @@ from vllm.platforms import current_platform from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname +from .caching import VllmSerializableFunction from .compiler_interface import ( CompilerInterface, EagerAdaptor, @@ -195,6 +197,7 @@ def compile( # there can be multiple graphs due to piecewise compilation. now = time.time() elapsed = now - compilation_start_time + compilation_config.compilation_time += elapsed if runtime_shape is None: logger.info( "Directly load the compiled graph(s) for dynamic shape " @@ -472,35 +475,6 @@ def set_model_tag(tag: str): model_tag = old_tag -try: - from torch._dynamo.aot_compile import SerializableCallable -except ImportError: - SerializableCallable = object - -assert isinstance(SerializableCallable, type) - - -class VllmCompiledFunction(SerializableCallable): - - def __init__(self, graph_module, example_inputs, vllm_config, - optimized_call): - self.graph_module = graph_module - self.example_inputs = example_inputs - self.vllm_config = vllm_config - self.optimized_call = optimized_call - - def __call__(self, *args, **kwargs): - return self.optimized_call(*args, **kwargs) - - @classmethod - def serialize_compile_artifacts(cls, compiled_fn): - raise NotImplementedError("serialization not implemented") - - @classmethod - def deserialize_compile_artifacts(cls, data): - raise NotImplementedError("deserialization not implemented") - - class VllmBackend: """The compilation backend for `torch.compile` with vLLM. 
It is used for compilation level of `CompilationLevel.PIECEWISE`, @@ -578,7 +552,11 @@ def configure_post_pass(self): self.post_grad_pass_manager.add(inductor_config[PASS_KEY]) inductor_config[PASS_KEY] = self.post_grad_pass_manager - def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: + def __call__( + self, graph: fx.GraphModule, example_inputs + ) -> VllmSerializableFunction: + from .caching import _compute_code_hash, compilation_config_hash_factors + vllm_config = self.vllm_config if not self.compilation_config.cache_dir: # no provided cache dir, generate one based on the known factors @@ -586,39 +564,11 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: # the cache dir will be the same so that we can reuse the compiled # graph. - factors = [] - # 0. factors come from the env, for example, The values of - # VLLM_PP_LAYER_PARTITION will affect the computation graph. - env_hash = envs.compute_hash() - factors.append(env_hash) - - # 1. factors come from the vllm_config (it mainly summarizes how the - # model is created) - config_hash = vllm_config.compute_hash() - factors.append(config_hash) - + factors = compilation_config_hash_factors(vllm_config) # 2. factors come from the code files that are traced by Dynamo ( # it mainly summarizes how the model is used in forward pass) - forward_code_files = list(sorted(self.compilation_config.traced_files)) + code_hash = _compute_code_hash(self.compilation_config.traced_files) self.compilation_config.traced_files.clear() - logger.debug( - "Traced files (to be considered for compilation cache):\n%s", - "\n".join(forward_code_files), - ) - hash_content = [] - for filepath in forward_code_files: - hash_content.append(filepath) - if filepath == "": - # This means the function was dynamically generated, with - # e.g. exec(). We can't actually check these. - continue - with open(filepath) as f: - hash_content.append(f.read()) - import hashlib - - code_hash = hashlib.md5( - "\n".join(hash_content).encode(), usedforsecurity=False - ).hexdigest() factors.append(code_hash) # 3. 
compiler hash @@ -724,8 +674,9 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or not self.compilation_config.cudagraph_copy_inputs ): - return VllmCompiledFunction(graph, example_inputs, vllm_config, - self.split_gm) + return VllmSerializableFunction( + graph, example_inputs, self.prefix, self.split_gm + ) # if we need to copy input buffers for cudagraph from torch._guards import detect_fake_mode @@ -770,5 +721,6 @@ def copy_and_call(*args): list_args[index] = static_tensor return self.split_gm(*list_args) - return VllmCompiledFunction(graph, example_inputs, vllm_config, - copy_and_call) + return VllmSerializableFunction( + graph, example_inputs, self.prefix, copy_and_call + ) diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py new file mode 100644 index 000000000000..fc930e9b4f14 --- /dev/null +++ b/vllm/compilation/caching.py @@ -0,0 +1,176 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import hashlib +import inspect +import pickle +from unittest.mock import patch + +import torch +from torch.utils import _pytree as pytree + +import vllm.envs as envs +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.logger import init_logger + +try: + from torch._dynamo.aot_compile import SerializableCallable +except ImportError: + SerializableCallable = object + +assert isinstance(SerializableCallable, type) + +logger = init_logger(__name__) + + +class VllmSerializableFunction(SerializableCallable): + """ + A wrapper around a compiled function by vllm. It will forward the tensor + inputs to the compiled function and return the result. + It also implements a serialization interface to support PyTorch's precompile + with custom backend, so that we can save and load the compiled function on + disk. There's no need to wrap around the compiled function if we don't want + to serialize them in particular cases. + Right now serialization for the custom backend is done via + serializing the Dynamo fx graph plus example inputs. 
+ """ + + def __init__(self, graph_module, example_inputs, prefix, optimized_call): + assert isinstance(graph_module, torch.fx.GraphModule) + self.graph_module = graph_module + self.example_inputs = example_inputs + self.prefix = prefix + self.optimized_call = optimized_call + self.shape_env = None + sym_input = next( + (i for i in self.example_inputs if isinstance(i, torch.SymInt)), None + ) + if sym_input is not None: + self.shape_env = sym_input.node.shape_env + + def __call__(self, *args, **kwargs): + return self.optimized_call(*args, **kwargs) + + @classmethod + def serialize_compile_artifacts( + cls, compiled_fn: "VllmSerializableFunction" + ) -> bytes: + import sympy + from torch._subclasses import FakeTensorMode + from torch.fx._graph_pickler import GraphPickler, Options + + state = compiled_fn.__dict__.copy() + state.pop("optimized_call") + state.pop("shape_env") + for node in state["graph_module"].graph.nodes: + node.meta.pop("source_fn_stack", None) + node.meta.pop("nn_module_stack", None) + + graph_reducer_override = GraphPickler.reducer_override + + def _graph_reducer_override(self, obj): + if ( + inspect.isclass(obj) + and issubclass(obj, sympy.Function) + and hasattr(obj, "_torch_unpickler") + ): + return obj._torch_unpickler, (obj._torch_handler_name,) + if isinstance(obj, FakeTensorMode): + return type(None), () + return graph_reducer_override(self, obj) + + # Mask off tensor inputs since they are large and not needed. + state["example_inputs"] = pytree.tree_map_only( + torch.Tensor, lambda _: None, state["example_inputs"] + ) + with patch.object(GraphPickler, "reducer_override", _graph_reducer_override): + state["graph_module"] = GraphPickler.dumps( + state["graph_module"], Options(ops_filter=None) + ) + state["example_inputs"] = GraphPickler.dumps(state["example_inputs"]) + return pickle.dumps(state) + + @classmethod + def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction": + from torch._guards import TracingContext, tracing + from torch._subclasses import FakeTensorMode + from torch.fx._graph_pickler import GraphPickler + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + from vllm.compilation.backends import VllmBackend + + state = pickle.loads(data) + fake_mode = FakeTensorMode(shape_env=ShapeEnv()) + state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode) + state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode) + vllm_backend = VllmBackend(get_current_vllm_config(), state["prefix"]) + + def optimized_call(*example_inputs): + """ + On the first run of the optimized call, we rerun the compiler + backend which should result in a cache hit. After the backend + call returns, we just do a one-time replacement of the optimized + call with the compiled function, so that subsequent calls are on + the AOT compiled path. + """ + compile_inputs = [ + inp or example_inputs[i] for i, inp in enumerate(fn.example_inputs) + ] + with tracing(TracingContext(fake_mode)): + fn.optimized_call = vllm_backend( + state["graph_module"], compile_inputs + ).optimized_call + return fn.optimized_call(*example_inputs) + + fn = cls(**state, optimized_call=optimized_call) + return fn + + @property + def co_name(self): + """ + Used for depyf debugging. + """ + return "VllmSerializableFunction" + + +def compilation_config_hash_factors(vllm_config: VllmConfig) -> list[str]: + factors = [] + # 0. factors come from the env, for example, The values of + # VLLM_PP_LAYER_PARTITION will affect the computation graph. 
+ env_hash = envs.compute_hash() + factors.append(env_hash) + + # 1. factors come from the vllm_config (it mainly summarizes how the + # model is created) + config_hash = vllm_config.compute_hash() + factors.append(config_hash) + return factors + + +def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str: + items = list(sorted(file_contents.items(), key=lambda x: x[0])) + hash_content = [] + for filepath, content in items: + hash_content.append(filepath) + if filepath == "": + # This means the function was dynamically generated, with + # e.g. exec(). We can't actually check these. + continue + hash_content.append(content) + return hashlib.md5( + "\n".join(hash_content).encode(), usedforsecurity=False + ).hexdigest() + + +def _compute_code_hash(files: set[str]) -> str: + logger.debug( + "Traced files (to be considered for compilation cache):\n%s", "\n".join(files) + ) + file_contents = {} + for filepath in files: + if filepath == "": + file_contents[filepath] = "" + else: + with open(filepath) as f: + file_contents[filepath] = f.read() + return _compute_code_hash_with_content(file_contents) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 4b1893887ac8..e5fa2518b87b 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -199,6 +199,7 @@ def compile( if compiler_config is not None: current_config.update(compiler_config) set_inductor_config(current_config, runtime_shape) + set_functorch_config() if isinstance(runtime_shape, int): dynamic_shapes = "from_example_inputs" @@ -307,6 +308,7 @@ def compile( current_config["fx_graph_remote_cache"] = False set_inductor_config(current_config, runtime_shape) + set_functorch_config() # inductor can inplace modify the graph, so we need to copy it # see https://github.com/pytorch/pytorch/issues/138980 @@ -596,6 +598,10 @@ def set_inductor_config(config, runtime_shape): ) +def set_functorch_config(): + torch._functorch.config.bundled_autograd_cache = False + + class EagerAdaptor(CompilerInterface): name = "eager" diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 74090a00c7e2..20bf63c80401 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -2,7 +2,10 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib +import hashlib import inspect +import os +import sys from typing import Callable, Optional, TypeVar, Union, overload from unittest.mock import patch @@ -14,7 +17,7 @@ import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import CompilationLevel, VllmConfig +from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.sequence import IntermediateTensors from vllm.utils import resolve_obj_by_qualname, supports_dynamo @@ -177,6 +180,33 @@ def cls_decorator_helper(cls: _T) -> _T: return cls_decorator_helper +def _model_hash_key(fn) -> str: + import vllm + + sha256_hash = hashlib.sha256() + sha256_hash.update(vllm.__version__.encode()) + sha256_hash.update(fn.__qualname__.encode()) + sha256_hash.update(str(fn.__code__.co_firstlineno).encode()) + return sha256_hash.hexdigest() + + +def _verify_source_unchanged(source_info, vllm_config) -> None: + from .caching import _compute_code_hash, _compute_code_hash_with_content + + file_contents = {} + for source in 
source_info.inlined_sources: + module = sys.modules[source.module] + file = inspect.getfile(module) + vllm_config.compilation_config.traced_files.add(file) + file_contents[file] = source.content + expected_checksum = _compute_code_hash_with_content(file_contents) + actual_checksum = _compute_code_hash(set(file_contents.keys())) + if expected_checksum != actual_checksum: + raise RuntimeError( + "Source code has changed since the last compilation. Recompiling the model." + ) + + def _support_torch_compile( cls: _T, dynamic_arg_dims: dict[str, Union[int, list[int]]], @@ -231,6 +261,61 @@ def __call__(self, *args, **kwargs): if getattr(self, "aot_compiled_fn", None) is not None: return self.aot_compiled_fn(self, *args, **kwargs) + cache_dir = None + aot_compilation_path = None + if envs.VLLM_USE_AOT_COMPILE: + """ + When using torch.compile in AOT mode, we store the cache artifacts + under VLLM_CACHE_ROOT/torch_aot_compile/{hash}/rank_i_j. The {hash} + contains all of the factors except for the source files being + traced through, because we don't actually know which source files + to check at this point (before dynamo runs). + On loading we will actually look at the source files being traced + through. If any source file have changed (compared with the + serialized backend artifacts), then we need to generate a new AOT + compile artifact from scratch. + """ + from .caching import compilation_config_hash_factors + + factors: list[str] = compilation_config_hash_factors(self.vllm_config) + + factors.append(_model_hash_key(self.forward)) + hash_key = hashlib.sha256(str(factors).encode()).hexdigest() + + cache_dir = os.path.join( + envs.VLLM_CACHE_ROOT, + "torch_aot_compile", + hash_key, + ) + + rank = self.vllm_config.parallel_config.rank + dp_rank = self.vllm_config.parallel_config.data_parallel_rank + cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}") + aot_compilation_path = os.path.join(cache_dir, "model") + try: + with ( + set_current_vllm_config(self.vllm_config), + open(aot_compilation_path, "rb") as f, + ): + start_monitoring_torch_compile(self.vllm_config) + loaded_fn = torch.compiler.load_compiled_function(f) + _verify_source_unchanged(loaded_fn.source_info(), self.vllm_config) + self.aot_compiled_fn = loaded_fn + except Exception as e: + if os.path.exists(aot_compilation_path): + logger.warning( + "Cannot load aot compilation from path %s, error: %s", + aot_compilation_path, + str(e), + ) + if envs.VLLM_FORCE_AOT_LOAD: + raise e + if getattr(self, "aot_compiled_fn", None) is not None: + logger.info( + "Directly load AOT compilation from path %s", aot_compilation_path + ) + return self.aot_compiled_fn(self, *args, **kwargs) + # the first compilation needs to have dynamic shapes marked if len(self.compiled_codes) < 1: sig = inspect.signature(self.__class__.forward) @@ -279,15 +364,15 @@ def __call__(self, *args, **kwargs): ) # 2. 
every time Dynamo sees a function call, it will inline - # the function by calling InliningInstructionTranslator.inline_call + # the function by calling InliningInstructionTranslator.inline_call_ # we hijack this function to know all the functions called # during Dynamo tracing, and their corresponding files - inline_call = InliningInstructionTranslator.inline_call + inline_call = InliningInstructionTranslator.inline_call_ - def patched_inline_call(parent, func, args, kwargs): - code = func.get_code() + def patched_inline_call(self_): + code = self_.f_code self.vllm_config.compilation_config.traced_files.add(code.co_filename) - return inline_call(parent, func, args, kwargs) + return inline_call(self_) # Disable the C++ compilation of symbolic shape guards. C++-fication # of symbolic shape guards can improve guard overhead. But, since @@ -304,7 +389,7 @@ def patched_inline_call(parent, func, args, kwargs): with ( patch.object( - InliningInstructionTranslator, "inline_call", patched_inline_call + InliningInstructionTranslator, "inline_call_", patched_inline_call ), torch._dynamo.config.patch(**dynamo_config_patches), maybe_use_cudagraph_partition_wrapper(self.vllm_config), @@ -313,6 +398,10 @@ def patched_inline_call(parent, func, args, kwargs): if envs.VLLM_USE_AOT_COMPILE: self.aot_compiled_fn = self.aot_compile(*args, **kwargs) output = self.aot_compiled_fn(self, *args, **kwargs) + assert aot_compilation_path is not None + assert cache_dir is not None + os.makedirs(cache_dir, exist_ok=True) + self.aot_compiled_fn.save_compiled_function(aot_compilation_path) else: output = self.compiled_callable(*args, **kwargs) return output diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 4202a78b3105..2007b655e264 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -42,18 +42,15 @@ def __init__( backend = vllm_config.compilation_config.init_backend(vllm_config) options = None if isinstance(backend, str) and backend == "inductor": -<<<<<<< HEAD options = ( get_current_vllm_config().compilation_config.inductor_compile_config ) -======= - options = get_current_vllm_config( - ).compilation_config.inductor_compile_config if envs.VLLM_USE_AOT_COMPILE: options = options or {} - options["guard_filter_fn"] = lambda guards: [ - False for _ in guards - ] + # This effectively drop all the guards. + # We need this because bytecode hook is not used any more to + # drop guards in the AOT compile mode. + options["guard_filter_fn"] = lambda guards: [False for _ in guards] if hasattr(torch._dynamo.config, "enable_aot_compile"): torch._dynamo.config.enable_aot_compile = True else: @@ -61,7 +58,6 @@ def __init__( msg += "available. AOT compile is disabled and please " msg += "upgrade PyTorch version to use AOT compile." logger.warning(msg) ->>>>>>> 6fc29676a (AOT compilation workflow [1/n]) compiled_callable = torch.compile( self.forward, fullgraph=True, backend=backend, options=options @@ -82,9 +78,10 @@ def __init__( def aot_compile(self, *args, **kwargs): if not hasattr(self.compiled_callable, "aot_compile"): raise RuntimeError( - "aot_compile is not supported by the current configuration. " + - "Please make sure torch.compile is enabled with the latest " + - "version of PyTorch") + "aot_compile is not supported by the current configuration. 
" + + "Please make sure torch.compile is enabled with the latest " + + f"version of PyTorch (current using torch: {torch.__version__})" + ) return self.compiled_callable.aot_compile((args, kwargs)) def __call__(self, *args, **kwargs): diff --git a/vllm/envs.py b/vllm/envs.py index afcebef8b12a..ab8548cf5066 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -89,6 +89,8 @@ VLLM_TORCH_PROFILER_DIR: Optional[str] = None VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False + VLLM_USE_AOT_COMPILE: bool = False + VLLM_FORCE_AOT_LOAD: bool = False VLLM_TORCH_PROFILER_WITH_STACK: bool = True VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False VLLM_USE_TRITON_AWQ: bool = False @@ -235,6 +237,13 @@ def maybe_convert_bool(value: Optional[str]) -> Optional[bool]: return bool(int(value)) +def use_aot_compile() -> bool: + from vllm.utils import is_torch_equal_or_newer + + default_value = "1" if is_torch_equal_or_newer("2.10.0.dev") else "0" + return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1" + + def env_with_choices( env_name: str, default: Optional[str], @@ -494,13 +503,14 @@ def get_vllm_port() -> Optional[int]: # Dump fx graphs to the given directory. # It will override CompilationConfig.debug_dump_path if set. "VLLM_DEBUG_DUMP_PATH": lambda: os.environ.get("VLLM_DEBUG_DUMP_PATH", None), - # Feature flag to enable/disable AOT compilation. This will ensure # compilation is done in warmup phase and the compilation will be # reused in subsequent calls. - "VLLM_USE_AOT_COMPILE": - lambda: os.environ.get("VLLM_USE_AOT_COMPILE", "0") == "1", - + "VLLM_USE_AOT_COMPILE": use_aot_compile, + # Force vllm to always load AOT compiled models from disk. Failure + # to load will result in a hard error when this is enabled. + # Will be ignored when VLLM_USE_AOT_COMPILE is disabled. + "VLLM_FORCE_AOT_LOAD": lambda: os.environ.get("VLLM_FORCE_AOT_LOAD", "0") == "1", # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")),