
Commit e925619

zhxchen17 authored and dolpm committed
AOT compilation workflow [1/n]
Signed-off-by: zhxchen17 <zhxchen17@fb.com>
1 parent 6b0fcbb commit e925619

File tree

6 files changed: +132 additions, −5 deletions

.buildkite/test-pipeline.yaml
tests/compile/test_aot_compile.py
vllm/compilation/backends.py
vllm/compilation/decorators.py
vllm/compilation/wrapper.py
vllm/envs.py


.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions
@@ -387,6 +387,7 @@ steps:
   - pytest -v -s compile/test_fusion_all_reduce.py
   - pytest -v -s compile/test_decorator.py
   - pytest -v -s compile/test_noop_elimination.py
+  - pytest -v -s compile/test_aot_compile.py

 - label: PyTorch Fullgraph Smoke Test # 15min
   timeout_in_minutes: 30

tests/compile/test_aot_compile.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from contextlib import contextmanager
+
+import pytest
+import torch
+
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
+                         set_current_vllm_config)
+from vllm.forward_context import set_forward_context
+
+
+class MyMod(torch.nn.Module):
+
+    def __init__(self, **kwargs):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor):
+        for _ in range(3000):
+            x = x + x.shape[0]
+        return x
+
+
+def make_vllm_config() -> VllmConfig:
+    return VllmConfig(compilation_config=CompilationConfig(
+        level=CompilationLevel.PIECEWISE, ))
+
+
+@contextmanager
+def use_vllm_config(vllm_config: VllmConfig):
+    with set_forward_context(
+            {}, vllm_config), set_current_vllm_config(vllm_config):
+        yield
+
+
+def test_no_eval_frame(monkeypatch: pytest.MonkeyPatch):
+    with monkeypatch.context() as m:
+        mod = MyMod()
+        args = (torch.randn(10, 10), )
+        expected = mod(*args)
+        CompiledMod = support_torch_compile(MyMod)
+
+        vllm_config = make_vllm_config()
+        m.setenv("VLLM_USE_AOT_COMPILE", "0")
+        try:
+            with use_vllm_config(vllm_config), torch.compiler.set_stance(
+                    "fail_on_recompile"):
+                CompiledMod(vllm_config=vllm_config)(*args)
+        except RuntimeError as e:
+            assert "Detected recompile" in str(e)
+        else:
+            raise AssertionError("Expected exception to be raised")
+
+        m.setenv("VLLM_USE_AOT_COMPILE", "1")
+        torch._dynamo.reset()
+        with use_vllm_config(vllm_config), torch.compiler.set_stance(
+                "fail_on_recompile"):
+            ret = CompiledMod(vllm_config=vllm_config)(*args)
+        assert torch.allclose(ret, expected)
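For context on what this test checks: `torch.compiler.set_stance("fail_on_recompile")` (available on recent PyTorch, roughly 2.6+) turns any would-be recompilation into a RuntimeError instead of a silent re-trace. Under the eager path every call re-enters Dynamo's guard checks, so the stance trips; under the AOT path the cached artifact is reused directly. A minimal standalone sketch of the stance itself, independent of vLLM:

```python
import torch


@torch.compile
def f(x: torch.Tensor) -> torch.Tensor:
    return x * x.shape[0]


f(torch.randn(4))  # first call: Dynamo traces, compiles, installs guards

with torch.compiler.set_stance("fail_on_recompile"):
    f(torch.randn(4))  # same shape: cached code is reused, no error
    # f(torch.randn(8)) here would raise "Detected recompile..." because the
    # shape guard fails and a fresh compilation is forbidden by the stance.
```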

vllm/compilation/backends.py

Lines changed: 33 additions & 2 deletions
@@ -398,6 +398,35 @@ def set_model_tag(tag: str):
         model_tag = old_tag


+try:
+    from torch._dynamo.aot_compile import SerializableCallable
+except ImportError:
+    SerializableCallable = object
+
+assert isinstance(SerializableCallable, type)
+
+
+class VllmCompiledFunction(SerializableCallable):
+
+    def __init__(self, graph_module, example_inputs, vllm_config,
+                 optimized_call):
+        self.graph_module = graph_module
+        self.example_inputs = example_inputs
+        self.vllm_config = vllm_config
+        self.optimized_call = optimized_call
+
+    def __call__(self, *args, **kwargs):
+        return self.optimized_call(*args, **kwargs)
+
+    @classmethod
+    def serialize_compile_artifacts(cls, compiled_fn):
+        raise NotImplementedError("serialization not implemented")
+
+    @classmethod
+    def deserialize_compile_artifacts(cls, data):
+        raise NotImplementedError("deserialization not implemented")
+
+
 class VllmBackend:
     """The compilation backend for `torch.compile` with vLLM.
     It is used for compilation level of `CompilationLevel.PIECEWISE`,

@@ -605,7 +634,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:

         if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or \
             not self.compilation_config.cudagraph_copy_inputs:
-            return self.split_gm
+            return VllmCompiledFunction(graph, example_inputs, vllm_config,
+                                        self.split_gm)

         # if we need to copy input buffers for cudagraph
         from torch._guards import detect_fake_mode

@@ -647,4 +677,5 @@ def copy_and_call(*args):
                 list_args[index] = static_tensor
             return self.split_gm(*list_args)

-        return copy_and_call
+        return VllmCompiledFunction(graph, example_inputs, vllm_config,
+                                    copy_and_call)
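Worth noting in this hunk is the guarded base class: if the running PyTorch lacks `torch._dynamo.aot_compile`, `VllmCompiledFunction` silently degrades to a plain callable wrapper (serialization is stubbed out for later PRs in this "[1/n]" series). A standalone sketch of that optional-base-class pattern, with a hypothetical `some_optional_dep` standing in for the torch module:

```python
# Hypothetical optional dependency standing in for
# torch._dynamo.aot_compile.SerializableCallable.
try:
    from some_optional_dep import FancyBase  # not installed in this sketch
except ImportError:
    FancyBase = object  # fall back to a no-op base on older installs

# Must be a class, or the subclass definition below would fail.
assert isinstance(FancyBase, type)


class Wrapper(FancyBase):
    """Behaves identically either way; extra hooks exist only if the base does."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


w = Wrapper(lambda x: x + 1)
print(w(41))  # -> 42
```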

vllm/compilation/decorators.py

Lines changed: 11 additions & 3 deletions
@@ -11,6 +11,7 @@
 from packaging import version
 from torch._dynamo.symbolic_convert import InliningInstructionTranslator

+import vllm.envs as envs
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
 from vllm.config import CompilationLevel, VllmConfig

@@ -34,11 +35,11 @@ def ignore_torch_compile(cls: _T) -> _T:
     a support_torch_compile decorator, but we don't want to
     compile the class `cls` that inherits the parent class.
     This only ignores compiling the forward of the class the
-    decorator is applied to.
+    decorator is applied to.

     If the parent has ignore_torch_compile but the child has
     support_torch_compile, the child will still be compiled.
-
+
     If the class has one or more submodules
     that have support_torch_compile decorator applied, compile will
     not be ignored for those submodules.

@@ -224,6 +225,9 @@ def __call__(self, *args, **kwargs):
         if self.do_not_compile or torch.compiler.is_compiling():
             return self.forward(*args, **kwargs)

+        if getattr(self, "aot_compiled_fn", None) is not None:
+            return self.aot_compiled_fn(self, *args, **kwargs)
+
         # the first compilation needs to have dynamic shapes marked
         if len(self.compiled_codes) < 1:
             sig = inspect.signature(self.__class__.forward)

@@ -307,7 +311,11 @@ def patched_inline_call(parent, func, args, kwargs):
                 **dynamo_config_patches
         ), maybe_use_cudagraph_partition_wrapper(
                 self.vllm_config), _torch27_patch_tensor_subclasses():
-            output = self.compiled_callable(*args, **kwargs)
+            if envs.VLLM_USE_AOT_COMPILE:
+                self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
+                output = self.aot_compiled_fn(self, *args, **kwargs)
+            else:
+                output = self.compiled_callable(*args, **kwargs)
         return output

     # usually, capturing the model once is enough, and then we can
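The dispatch change amounts to compile-once-then-memoize: the first call under the compilation context produces `aot_compiled_fn`, and every later `__call__` short-circuits to it before any Dynamo machinery runs. A generic sketch of the pattern (hypothetical names; plain `torch.compile` stands in for the AOT entry point used in the diff):

```python
import torch


class LazyCompiledModule:
    """Hypothetical stand-in; not vLLM's actual wrapper class."""

    def __init__(self, fn):
        self._fn = fn
        self._compiled_fn = None  # filled in on the first call

    def __call__(self, *args, **kwargs):
        if self._compiled_fn is not None:
            # Steady state: reuse the stored artifact and never re-enter
            # the compilation path (mirrors the aot_compiled_fn check).
            return self._compiled_fn(*args, **kwargs)
        # First call: compile once and cache the result.
        self._compiled_fn = torch.compile(self._fn)
        return self._compiled_fn(*args, **kwargs)


mod = LazyCompiledModule(lambda x: x * 2)
mod(torch.randn(3))  # compiles
mod(torch.randn(3))  # reuses the cached callable
```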

vllm/compilation/wrapper.py

Lines changed: 20 additions & 0 deletions
@@ -45,6 +45,18 @@ def __init__(self,
         if isinstance(backend, str) and backend == "inductor":
             options = get_current_vllm_config(
             ).compilation_config.inductor_compile_config
+            if envs.VLLM_USE_AOT_COMPILE:
+                options = options or {}
+                options["guard_filter_fn"] = lambda guards: [
+                    False for _ in guards
+                ]
+                if hasattr(torch._dynamo.config, "enable_aot_compile"):
+                    torch._dynamo.config.enable_aot_compile = True
+                else:
+                    msg = "torch._dynamo.config.enable_aot_compile is not "
+                    msg += "available. AOT compile is disabled and please "
+                    msg += "upgrade PyTorch version to use AOT compile."
+                    logger.warning(msg)

         compiled_callable = torch.compile(self.forward,
                                           fullgraph=True,

@@ -62,6 +74,14 @@ def __init__(self,
         self.use_custom_dispatcher: bool = \
             compilation_level >= CompilationLevel.DYNAMO_ONCE

+    def aot_compile(self, *args, **kwargs):
+        if not hasattr(self.compiled_callable, "aot_compile"):
+            raise RuntimeError(
+                "aot_compile is not supported by the current configuration. " +
+                "Please make sure torch.compile is enabled with the latest " +
+                "version of PyTorch")
+        return self.compiled_callable.aot_compile((args, kwargs))
+
     def __call__(self, *args, **kwargs):
         """Implement the dispatch logic here, beyond the torch.compile level.
         NOTE: this function can have additional arguments beyond the forward
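Two details here are easy to miss. First, the guard filter returns False for every guard, so the compiled artifact is installed with no guards and is reused unconditionally; this is what lets the new test pass under "fail_on_recompile". Second, `aot_compile` takes a single packed `(args, kwargs)` tuple rather than splatted arguments. A sketch of the guard-dropping option on a plain `torch.compile` call (hedged: `guard_filter_fn` is only honored on recent PyTorch builds, which is why the diff also probes `enable_aot_compile` before flipping it on):

```python
import torch


def drop_all_guards(guards):
    # One boolean per guard; False means "do not install this guard", so
    # the compiled code is reused for any later input that reaches it.
    return [False for _ in guards]


# Assumption: this PyTorch build accepts guard_filter_fn in the options
# dict, as the diff above relies on; older builds may reject it.
fn = torch.compile(lambda x: x + 1,
                   options={"guard_filter_fn": drop_all_guards})
```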

vllm/envs.py

Lines changed: 6 additions & 0 deletions
@@ -511,6 +511,12 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_PATTERN_MATCH_DEBUG":
     lambda: os.environ.get("VLLM_PATTERN_MATCH_DEBUG", None),

+    # Feature flag to enable/disable AOT compilation. This will ensure
+    # compilation is done in warmup phase and the compilation will be
+    # reused in subsequent calls.
+    "VLLM_USE_AOT_COMPILE":
+    lambda: os.environ.get("VLLM_USE_AOT_COMPILE", "0") == "1",
+
     # local rank of the process in the distributed setting, used to determine
     # the GPU device id
     "LOCAL_RANK":
