
Commit d297def

AOT compilation workflow [1/n]

Signed-off-by: zhxchen17 <zhxchen17@fb.com>
1 parent fc67969

File tree

6 files changed: +132 -3 lines

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions

@@ -404,6 +404,7 @@ steps:
     - pytest -v -s compile/test_fusion_all_reduce.py
     - pytest -v -s compile/test_decorator.py
     - pytest -v -s compile/test_noop_elimination.py
+    - pytest -v -s compile/test_aot_compile.py
 
 - label: PyTorch Fullgraph Smoke Test # 15min
   timeout_in_minutes: 30

tests/compile/test_aot_compile.py

Lines changed: 61 additions & 0 deletions (new file)

@@ -0,0 +1,61 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from contextlib import contextmanager
+
+import pytest
+import torch
+
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
+                         set_current_vllm_config)
+from vllm.forward_context import set_forward_context
+
+
+class MyMod(torch.nn.Module):
+
+    def __init__(self, **kwargs):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor):
+        for _ in range(3000):
+            x = x + x.shape[0]
+        return x
+
+
+def make_vllm_config() -> VllmConfig:
+    return VllmConfig(compilation_config=CompilationConfig(
+        level=CompilationLevel.PIECEWISE, ))
+
+
+@contextmanager
+def use_vllm_config(vllm_config: VllmConfig):
+    with set_forward_context(
+            {}, vllm_config), set_current_vllm_config(vllm_config):
+        yield
+
+
+def test_no_eval_frame(monkeypatch: pytest.MonkeyPatch):
+    with monkeypatch.context() as m:
+        mod = MyMod()
+        args = (torch.randn(10, 10), )
+        expected = mod(*args)
+        CompiledMod = support_torch_compile(MyMod)
+
+        vllm_config = make_vllm_config()
+        m.setenv("VLLM_USE_AOT_COMPILE", "0")
+        try:
+            with use_vllm_config(vllm_config), torch.compiler.set_stance(
+                    "fail_on_recompile"):
+                CompiledMod(vllm_config=vllm_config)(*args)
+        except RuntimeError as e:
+            assert "Detected recompile" in str(e)
+        else:
+            raise AssertionError("Expected exception to be raised")
+
+        m.setenv("VLLM_USE_AOT_COMPILE", "1")
+        torch._dynamo.reset()
+        with use_vllm_config(vllm_config), torch.compiler.set_stance(
+                "fail_on_recompile"):
+            ret = CompiledMod(vllm_config=vllm_config)(*args)
+        assert torch.allclose(ret, expected)
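
For readers unfamiliar with the stance API used in this test: torch.compiler.set_stance("fail_on_recompile") turns any Dynamo recompilation into a hard error, which is how the test distinguishes the guard-free AOT path (second half) from the normal guarded path (first half). A minimal standalone illustration, separate from the commit and using a made-up function:

import torch

@torch.compile
def f(x):
    return x + x.shape[0]

f(torch.randn(4))  # first call compiles, specialized on the input shape
with torch.compiler.set_stance("fail_on_recompile"):
    f(torch.randn(4))    # same shape: cached code is reused, no error
    # f(torch.randn(8))  # a new shape would recompile -> RuntimeError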

vllm/compilation/backends.py

Lines changed: 33 additions & 2 deletions

@@ -441,6 +441,35 @@ def set_model_tag(tag: str):
         model_tag = old_tag
 
 
+try:
+    from torch._dynamo.aot_compile import SerializableCallable
+except ImportError:
+    SerializableCallable = object
+
+assert isinstance(SerializableCallable, type)
+
+
+class VllmCompiledFunction(SerializableCallable):
+
+    def __init__(self, graph_module, example_inputs, vllm_config,
+                 optimized_call):
+        self.graph_module = graph_module
+        self.example_inputs = example_inputs
+        self.vllm_config = vllm_config
+        self.optimized_call = optimized_call
+
+    def __call__(self, *args, **kwargs):
+        return self.optimized_call(*args, **kwargs)
+
+    @classmethod
+    def serialize_compile_artifacts(cls, compiled_fn):
+        raise NotImplementedError("serialization not implemented")
+
+    @classmethod
+    def deserialize_compile_artifacts(cls, data):
+        raise NotImplementedError("deserialization not implemented")
+
+
 class VllmBackend:
     """The compilation backend for `torch.compile` with vLLM.
     It is used for compilation level of `CompilationLevel.PIECEWISE`,
@@ -659,7 +688,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
             self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
             or not self.compilation_config.cudagraph_copy_inputs
         ):
-            return self.split_gm
+            return VllmCompiledFunction(graph, example_inputs, vllm_config,
+                                        self.split_gm)
 
         # if we need to copy input buffers for cudagraph
         from torch._guards import detect_fake_mode
@@ -704,4 +734,5 @@ def copy_and_call(*args):
                 list_args[index] = static_tensor
             return self.split_gm(*list_args)
 
-        return copy_and_call
+        return VllmCompiledFunction(graph, example_inputs, vllm_config,
+                                    copy_and_call)
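
VllmCompiledFunction is a transparent callable wrapper: call sites that previously received split_gm or copy_and_call see the same calling convention, while the wrapper carries the metadata (graph module, example inputs, config) that the serialization hooks will need once a later part of this stack fills them in. A hedged sketch of the wrapper's behavior, with toy stand-ins for the graph module and config rather than real vLLM objects:

import torch

def toy_optimized_call(x: torch.Tensor) -> torch.Tensor:
    # stand-in for self.split_gm / copy_and_call
    return x * 2

fn = VllmCompiledFunction(
    graph_module=None,              # placeholder for the split fx.GraphModule
    example_inputs=(torch.ones(2),),
    vllm_config=None,               # placeholder for the active VllmConfig
    optimized_call=toy_optimized_call,
)
assert torch.equal(fn(torch.ones(2)), 2 * torch.ones(2))
# serialize_compile_artifacts / deserialize_compile_artifacts are stubs in
# this [1/n] change and raise NotImplementedError if called.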

vllm/compilation/decorators.py

Lines changed: 9 additions & 1 deletion

@@ -11,6 +11,7 @@
 from packaging import version
 from torch._dynamo.symbolic_convert import InliningInstructionTranslator
 
+import vllm.envs as envs
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
 from vllm.config import CompilationLevel, VllmConfig
@@ -227,6 +228,9 @@ def __call__(self, *args, **kwargs):
         if self.do_not_compile or torch.compiler.is_compiling():
             return self.forward(*args, **kwargs)
 
+        if getattr(self, "aot_compiled_fn", None) is not None:
+            return self.aot_compiled_fn(self, *args, **kwargs)
+
         # the first compilation needs to have dynamic shapes marked
         if len(self.compiled_codes) < 1:
             sig = inspect.signature(self.__class__.forward)
@@ -306,7 +310,11 @@ def patched_inline_call(parent, func, args, kwargs):
             maybe_use_cudagraph_partition_wrapper(self.vllm_config),
             _torch27_patch_tensor_subclasses(),
         ):
-            output = self.compiled_callable(*args, **kwargs)
+            if envs.VLLM_USE_AOT_COMPILE:
+                self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
+                output = self.aot_compiled_fn(self, *args, **kwargs)
+            else:
+                output = self.compiled_callable(*args, **kwargs)
         return output
 
     # usually, capturing the model once is enough, and then we can
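
These hunks implement a compile-once, dispatch-forever pattern: the first traced call produces an AOT artifact, which is cached on the instance as aot_compiled_fn and used for every later call without re-entering Dynamo. A stripped-down sketch of that control flow, using a toy class with hypothetical names rather than vLLM's actual wrapper:

class AotDispatchSketch:
    """Toy model of the aot_compiled_fn caching in __call__ above."""

    def __init__(self, forward, use_aot: bool):
        self.forward = forward
        self.use_aot = use_aot
        self.aot_compiled_fn = None

    def _aot_compile(self):
        # stand-in for self.aot_compile(*args, **kwargs), which asks the
        # torch.compile wrapper for a guard-free compiled artifact
        return self.forward

    def __call__(self, *args, **kwargs):
        if self.aot_compiled_fn is not None:
            # fast path: skip Dynamo frame evaluation entirely
            return self.aot_compiled_fn(*args, **kwargs)
        if self.use_aot:
            self.aot_compiled_fn = self._aot_compile()
            return self.aot_compiled_fn(*args, **kwargs)
        return self.forward(*args, **kwargs)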

vllm/compilation/wrapper.py

Lines changed: 21 additions & 0 deletions

@@ -10,6 +10,7 @@
 
 import torch
 
+import vllm.envs as envs
 from vllm.config import CompilationLevel, CUDAGraphMode, get_current_vllm_config
 from vllm.logger import init_logger
 
@@ -41,9 +42,21 @@ def __init__(
         backend = vllm_config.compilation_config.init_backend(vllm_config)
         options = None
         if isinstance(backend, str) and backend == "inductor":
             options = (
                 get_current_vllm_config().compilation_config.inductor_compile_config
             )
+            if envs.VLLM_USE_AOT_COMPILE:
+                options = options or {}
+                options["guard_filter_fn"] = lambda guards: [
+                    False for _ in guards
+                ]
+                if hasattr(torch._dynamo.config, "enable_aot_compile"):
+                    torch._dynamo.config.enable_aot_compile = True
+                else:
+                    logger.warning(
+                        "torch._dynamo.config.enable_aot_compile is not "
+                        "available; AOT compile is disabled. Please "
+                        "upgrade PyTorch to use AOT compile.")
 
         compiled_callable = torch.compile(
             self.forward, fullgraph=True, backend=backend, options=options
@@ -61,6 +74,14 @@ def __init__(
             compilation_level >= CompilationLevel.DYNAMO_ONCE
         )
 
+    def aot_compile(self, *args, **kwargs):
+        if not hasattr(self.compiled_callable, "aot_compile"):
+            raise RuntimeError(
+                "aot_compile is not supported by the current configuration. "
+                "Please make sure torch.compile is enabled with the latest "
+                "version of PyTorch.")
+        return self.compiled_callable.aot_compile((args, kwargs))
+
     def __call__(self, *args, **kwargs):
         """Implement the dispatch logic here, beyond the torch.compile level.
         NOTE: this function can have additional arguments beyond the forward
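
On the all-False guard filter: guard_filter_fn receives the guards Dynamo would install and returns a parallel list of booleans, where False drops the corresponding guard (this contract is inferred from how the lambda is used above, not stated in the diff). Dropping every guard makes the compiled artifact unconditionally reusable, which is exactly what the fail_on_recompile test verifies:

# toy illustration of the assumed keep/drop contract
drop_all_guards = lambda guards: [False for _ in guards]

assert drop_all_guards(["TENSOR_MATCH", "SHAPE_ENV"]) == [False, False]
assert drop_all_guards([]) == []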

vllm/envs.py

Lines changed: 7 additions & 0 deletions

@@ -499,6 +499,13 @@ def get_vllm_port() -> Optional[int]:
     # Dump fx graphs to the given directory.
     # It will override CompilationConfig.debug_dump_path if set.
     "VLLM_DEBUG_DUMP_PATH": lambda: os.environ.get("VLLM_DEBUG_DUMP_PATH", None),
+
+    # Feature flag to enable/disable AOT compilation. This ensures that
+    # compilation happens during the warmup phase and the compiled artifact
+    # is reused on subsequent calls.
+    "VLLM_USE_AOT_COMPILE":
+    lambda: os.environ.get("VLLM_USE_AOT_COMPILE", "0") == "1",
+
     # local rank of the process in the distributed setting, used to determine
     # the GPU device id
     "LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")),
