
Commit df40fe6

AOT compilation workflow [2/n]
Signed-off-by: zhxchen17 <zhxchen17@fb.com>
1 parent ab7acd7 commit df40fe6

8 files changed: +410 -114 lines changed


tests/compile/test_aot_compile.py

Lines changed: 93 additions & 28 deletions
@@ -1,61 +1,126 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import tempfile
 from contextlib import contextmanager
 
 import pytest
 import torch
 
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
-                         set_current_vllm_config)
+from vllm.config import (
+    CompilationConfig,
+    CompilationLevel,
+    VllmConfig,
+    set_current_vllm_config,
+)
 from vllm.forward_context import set_forward_context
 
 
-class MyMod(torch.nn.Module):
+def reference_fn(x: torch.Tensor):
+    assert x.shape[0] <= 42
+    assert x.shape[0] % 2 == 0
+    for _ in range(3000):
+        x = x + x.shape[0]
+    return x
 
+
+@support_torch_compile
+class CompiledMod(torch.nn.Module):
     def __init__(self, **kwargs):
         super().__init__()
 
     def forward(self, x: torch.Tensor):
-        for _ in range(3000):
-            x = x + x.shape[0]
-        return x
+        return reference_fn(x)
 
 
 def make_vllm_config() -> VllmConfig:
-    return VllmConfig(compilation_config=CompilationConfig(
-        level=CompilationLevel.PIECEWISE, ))
+    return VllmConfig(
+        compilation_config=CompilationConfig(
+            level=CompilationLevel.PIECEWISE,
+        )
+    )
 
 
 @contextmanager
 def use_vllm_config(vllm_config: VllmConfig):
-    with set_forward_context(
-            {}, vllm_config), set_current_vllm_config(vllm_config):
+    with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config):
         yield
 
 
-def test_no_eval_frame(monkeypatch: pytest.MonkeyPatch):
+def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
     with monkeypatch.context() as m:
-        mod = MyMod()
-        args = (torch.randn(10, 10), )
-        expected = mod(*args)
-        CompiledMod = support_torch_compile(MyMod)
-
         vllm_config = make_vllm_config()
-        m.setenv("VLLM_USE_AOT_COMPILE", "0")
-        try:
-            with use_vllm_config(vllm_config), torch.compiler.set_stance(
-                    "fail_on_recompile"):
+        args = (torch.randn(10, 10),)
+        expected = reference_fn(*args)
+        with use_vllm_config(vllm_config):
+            m.setenv("VLLM_USE_AOT_COMPILE", "0")
+            with (
+                pytest.raises(RuntimeError, match="Detected recompile"),
+                torch.compiler.set_stance("fail_on_recompile"),
+            ):
                 CompiledMod(vllm_config=vllm_config)(*args)
-        except RuntimeError as e:
-            assert "Detected recompile" in str(e)
-        else:
-            raise AssertionError("Expected exception to be raised")
 
+            m.setenv("VLLM_USE_AOT_COMPILE", "1")
+            torch._dynamo.reset()
+            with torch.compiler.set_stance("fail_on_recompile"):
+                actual = CompiledMod(vllm_config=vllm_config)(*args)
+            assert torch.allclose(actual, expected)
+
+
+def test_force_aot_load(monkeypatch: pytest.MonkeyPatch):
+    with tempfile.TemporaryDirectory() as tmpdirname, monkeypatch.context() as m:
+        args = (torch.randn(10, 10),)
         m.setenv("VLLM_USE_AOT_COMPILE", "1")
-        torch._dynamo.reset()
-        with use_vllm_config(vllm_config), torch.compiler.set_stance(
-                "fail_on_recompile"):
-            ret = CompiledMod(vllm_config=vllm_config)(*args)
+        m.setenv("VLLM_FORCE_AOT_LOAD", "1")
+        m.setenv("VLLM_CACHE_ROOT", tmpdirname)
+        vllm_config = make_vllm_config()
+        with use_vllm_config(vllm_config), pytest.raises(FileNotFoundError):
+            CompiledMod(vllm_config=vllm_config)(*args)
+
+
+def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
+    with monkeypatch.context() as m:
+        args = (torch.randn(10, 10),)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            m.setenv("VLLM_CACHE_ROOT", tmpdirname)
+            m.setenv("VLLM_USE_AOT_COMPILE", "1")
+            vllm_config = make_vllm_config()
+            with use_vllm_config(vllm_config):
+                expected = CompiledMod(vllm_config=vllm_config)(*args)
+
+            m.setenv("VLLM_FORCE_AOT_LOAD", "1")
+            vllm_config = make_vllm_config()
+            with use_vllm_config(vllm_config):
+                ret = CompiledMod(vllm_config=vllm_config)(*args)
             assert torch.allclose(ret, expected)
+
+
+def test_shape_env(monkeypatch: pytest.MonkeyPatch):
+    """
+    Test that the shape environment is correctly serialized and preserved
+    when loading from cache.
+    """
+    with monkeypatch.context() as m:
+        args = (torch.randn(10, 10),)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            m.setenv("VLLM_CACHE_ROOT", tmpdirname)
+            m.setenv("VLLM_USE_AOT_COMPILE", "1")
+            vllm_config = make_vllm_config()
+            with use_vllm_config(vllm_config):
+                compiled_mod = CompiledMod(vllm_config=vllm_config)
+                compiled_mod(*args)
+                artifacts = compiled_mod.aot_compiled_fn._artifacts
+                guards_string = artifacts.compiled_fn.shape_env.format_guards()
+                assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
+
+            m.setenv("VLLM_FORCE_AOT_LOAD", "1")
+            vllm_config = make_vllm_config()
+            with use_vllm_config(vllm_config):
+                compiled_mod = CompiledMod(vllm_config=vllm_config)
+                compiled_mod(*args)
+                artifacts = compiled_mod.aot_compiled_fn._artifacts
+                guards_string = artifacts.compiled_fn.shape_env.format_guards()
+                assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
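
The three environment variables these tests toggle are the user-facing surface of the workflow: VLLM_USE_AOT_COMPILE turns on ahead-of-time compilation and saving of artifacts under VLLM_CACHE_ROOT, while VLLM_FORCE_AOT_LOAD requires a cache hit and raises FileNotFoundError otherwise. As a minimal sketch distilled from test_save_and_load above (a hypothetical standalone driver reusing the helpers defined in this test file), the save-then-load round trip looks like:

import os
import tempfile

import torch

with tempfile.TemporaryDirectory() as cache_dir:
    # First run: compile ahead of time and persist the artifacts.
    os.environ["VLLM_CACHE_ROOT"] = cache_dir
    os.environ["VLLM_USE_AOT_COMPILE"] = "1"
    args = (torch.randn(10, 10),)

    vllm_config = make_vllm_config()
    with use_vllm_config(vllm_config):
        expected = CompiledMod(vllm_config=vllm_config)(*args)

    # Second run: refuse to recompile; load the saved artifacts instead.
    os.environ["VLLM_FORCE_AOT_LOAD"] = "1"
    vllm_config = make_vllm_config()
    with use_vllm_config(vllm_config):
        ret = CompiledMod(vllm_config=vllm_config)(*args)
    assert torch.allclose(ret, expected)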

tools/pre_commit/check_pickle_imports.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@
     "vllm/multimodal/hasher.py",
     "vllm/transformers_utils/config.py",
     "vllm/model_executor/models/registry.py",
+    "vllm/compilation/caching.py",
     "tests/utils_/test_utils.py",
     "tests/tokenization/test_cached_tokenizer.py",
     "vllm/distributed/utils.py",

vllm/compilation/backends.py

Lines changed: 16 additions & 64 deletions
@@ -3,6 +3,7 @@
 
 import ast
 import dataclasses
+import hashlib
 import os
 import pprint
 import time
@@ -25,6 +26,7 @@
 from vllm.platforms import current_platform
 from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
 
+from .caching import VllmSerializableFunction
 from .compiler_interface import (
     CompilerInterface,
     EagerAdaptor,
@@ -195,6 +197,7 @@ def compile(
         # there can be multiple graphs due to piecewise compilation.
         now = time.time()
         elapsed = now - compilation_start_time
+        compilation_config.compilation_time += elapsed
         if runtime_shape is None:
             logger.info(
                 "Directly load the compiled graph(s) for dynamic shape "
@@ -472,35 +475,6 @@ def set_model_tag(tag: str):
         model_tag = old_tag
 
 
-try:
-    from torch._dynamo.aot_compile import SerializableCallable
-except ImportError:
-    SerializableCallable = object
-
-assert isinstance(SerializableCallable, type)
-
-
-class VllmCompiledFunction(SerializableCallable):
-
-    def __init__(self, graph_module, example_inputs, vllm_config,
-                 optimized_call):
-        self.graph_module = graph_module
-        self.example_inputs = example_inputs
-        self.vllm_config = vllm_config
-        self.optimized_call = optimized_call
-
-    def __call__(self, *args, **kwargs):
-        return self.optimized_call(*args, **kwargs)
-
-    @classmethod
-    def serialize_compile_artifacts(cls, compiled_fn):
-        raise NotImplementedError("serialization not implemented")
-
-    @classmethod
-    def deserialize_compile_artifacts(cls, data):
-        raise NotImplementedError("deserialization not implemented")
-
-
 class VllmBackend:
     """The compilation backend for `torch.compile` with vLLM.
     It is used for compilation level of `CompilationLevel.PIECEWISE`,
@@ -578,47 +552,23 @@ def configure_post_pass(self):
         self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
         inductor_config[PASS_KEY] = self.post_grad_pass_manager
 
-    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
+    def __call__(
+        self, graph: fx.GraphModule, example_inputs
+    ) -> VllmSerializableFunction:
+        from .caching import _compute_code_hash, compilation_config_hash_factors
+
         vllm_config = self.vllm_config
         if not self.compilation_config.cache_dir:
             # no provided cache dir, generate one based on the known factors
             # that affects the compilation. if none of the factors change,
             # the cache dir will be the same so that we can reuse the compiled
             # graph.
 
-            factors = []
-            # 0. factors come from the env, for example, The values of
-            # VLLM_PP_LAYER_PARTITION will affect the computation graph.
-            env_hash = envs.compute_hash()
-            factors.append(env_hash)
-
-            # 1. factors come from the vllm_config (it mainly summarizes how the
-            # model is created)
-            config_hash = vllm_config.compute_hash()
-            factors.append(config_hash)
-
+            factors = compilation_config_hash_factors(vllm_config)
             # 2. factors come from the code files that are traced by Dynamo (
             # it mainly summarizes how the model is used in forward pass)
-            forward_code_files = list(sorted(self.compilation_config.traced_files))
+            code_hash = _compute_code_hash(self.compilation_config.traced_files)
             self.compilation_config.traced_files.clear()
-            logger.debug(
-                "Traced files (to be considered for compilation cache):\n%s",
-                "\n".join(forward_code_files),
-            )
-            hash_content = []
-            for filepath in forward_code_files:
-                hash_content.append(filepath)
-                if filepath == "<string>":
-                    # This means the function was dynamically generated, with
-                    # e.g. exec(). We can't actually check these.
-                    continue
-                with open(filepath) as f:
-                    hash_content.append(f.read())
-            import hashlib
-
-            code_hash = hashlib.md5(
-                "\n".join(hash_content).encode(), usedforsecurity=False
-            ).hexdigest()
             factors.append(code_hash)
 
             # 3. compiler hash
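
The inline hashing logic deleted above presumably moved into the new vllm/compilation/caching.py helpers. Based only on the removed lines (the real implementations live in caching.py, which is not shown in this excerpt), a rough sketch of what the two helpers would need to do:

import hashlib

import vllm.envs as envs


def compilation_config_hash_factors(vllm_config) -> list[str]:
    # Factor 0: environment variables that affect the computation graph,
    # e.g. VLLM_PP_LAYER_PARTITION.
    # Factor 1: the vllm_config, which summarizes how the model is created.
    return [envs.compute_hash(), vllm_config.compute_hash()]


def _compute_code_hash(traced_files: set[str]) -> str:
    # Hash the source of every file traced by Dynamo, i.e. how the model is
    # used in its forward pass. "<string>" entries come from exec()'d code
    # and have no file on disk to read.
    hash_content = []
    for filepath in sorted(traced_files):
        hash_content.append(filepath)
        if filepath == "<string>":
            continue
        with open(filepath) as f:
            hash_content.append(f.read())
    return hashlib.md5(
        "\n".join(hash_content).encode(), usedforsecurity=False
    ).hexdigest()
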
@@ -724,8 +674,9 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
             self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
             or not self.compilation_config.cudagraph_copy_inputs
         ):
-            return VllmCompiledFunction(graph, example_inputs, vllm_config,
-                                        self.split_gm)
+            return VllmSerializableFunction(
+                graph, example_inputs, self.prefix, self.split_gm
+            )
 
         # if we need to copy input buffers for cudagraph
         from torch._guards import detect_fake_mode
@@ -770,5 +721,6 @@ def copy_and_call(*args):
                 list_args[index] = static_tensor
             return self.split_gm(*list_args)
 
-        return VllmCompiledFunction(graph, example_inputs, vllm_config,
-                                    copy_and_call)
+        return VllmSerializableFunction(
+            graph, example_inputs, self.prefix, copy_and_call
+        )
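
VllmSerializableFunction itself is defined in the new vllm/compilation/caching.py, one of the eight changed files but not shown in this excerpt. Judging from the call sites above and from the VllmCompiledFunction placeholder it replaces, its callable surface is roughly the following sketch; the real class would additionally implement the serialize_compile_artifacts / deserialize_compile_artifacts hooks that the deleted placeholder left as NotImplementedError:

class VllmSerializableFunction:
    """Sketch only: wraps a compiled graph so it can be saved to and loaded
    from the vLLM cache. The constructor signature is inferred from the call
    sites VllmSerializableFunction(graph, example_inputs, self.prefix, self.split_gm)."""

    def __init__(self, graph_module, example_inputs, prefix, optimized_call):
        self.graph_module = graph_module
        self.example_inputs = example_inputs
        self.prefix = prefix
        self.optimized_call = optimized_call

    def __call__(self, *args, **kwargs):
        # Delegate to the optimized callable, as the old VllmCompiledFunction did.
        return self.optimized_call(*args, **kwargs)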
