Commit 955e518

AOT compilation workflow [2/n]
Signed-off-by: zhxchen17 <zhxchen17@fb.com>
1 parent d297def commit 955e518

File tree

8 files changed: +394 -114 lines changed

tests/compile/test_aot_compile.py

Lines changed: 96 additions & 28 deletions
@@ -1,61 +1,129 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import tempfile
 from contextlib import contextmanager
 
 import pytest
 import torch
 
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
-                         set_current_vllm_config)
+from vllm.config import (
+    CompilationConfig,
+    CompilationLevel,
+    VllmConfig,
+    set_current_vllm_config,
+)
 from vllm.forward_context import set_forward_context
 
 
-class MyMod(torch.nn.Module):
+def reference_fn(x: torch.Tensor):
+    assert x.shape[0] <= 42
+    assert x.shape[0] % 2 == 0
+    for _ in range(3000):
+        x = x + x.shape[0]
+    return x
 
+
+@support_torch_compile
+class CompiledMod(torch.nn.Module):
     def __init__(self, **kwargs):
         super().__init__()
 
     def forward(self, x: torch.Tensor):
-        for _ in range(3000):
-            x = x + x.shape[0]
-        return x
+        return reference_fn(x)
 
 
 def make_vllm_config() -> VllmConfig:
-    return VllmConfig(compilation_config=CompilationConfig(
-        level=CompilationLevel.PIECEWISE, ))
+    return VllmConfig(
+        compilation_config=CompilationConfig(
+            level=CompilationLevel.PIECEWISE,
+        )
+    )
 
 
 @contextmanager
 def use_vllm_config(vllm_config: VllmConfig):
-    with set_forward_context(
-            {}, vllm_config), set_current_vllm_config(vllm_config):
+    with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config):
         yield
 
 
-def test_no_eval_frame(monkeypatch: pytest.MonkeyPatch):
+def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
     with monkeypatch.context() as m:
-        mod = MyMod()
-        args = (torch.randn(10, 10), )
-        expected = mod(*args)
-        CompiledMod = support_torch_compile(MyMod)
-
         vllm_config = make_vllm_config()
-        m.setenv("VLLM_USE_AOT_COMPILE", "0")
-        try:
-            with use_vllm_config(vllm_config), torch.compiler.set_stance(
-                    "fail_on_recompile"):
+        args = (torch.randn(10, 10),)
+        expected = reference_fn(*args)
+        with use_vllm_config(vllm_config):
+            m.setenv("VLLM_USE_AOT_COMPILE", "0")
+            with (
+                pytest.raises(RuntimeError, match="Detected recompile"),
+                torch.compiler.set_stance("fail_on_recompile"),
+            ):
                 CompiledMod(vllm_config=vllm_config)(*args)
-        except RuntimeError as e:
-            assert "Detected recompile" in str(e)
-        else:
-            raise AssertionError("Expected exception to be raised")
 
+        m.setenv("VLLM_USE_AOT_COMPILE", "1")
+        torch._dynamo.reset()
+        with (
+            use_vllm_config(vllm_config),
+            torch.compiler.set_stance("fail_on_recompile"),
+        ):
+            actual = CompiledMod(vllm_config=vllm_config)(*args)
+        assert torch.allclose(actual, expected)
+
+
+def test_force_aot_load(monkeypatch: pytest.MonkeyPatch):
+    with tempfile.TemporaryDirectory() as tmpdirname, monkeypatch.context() as m:
+        args = (torch.randn(10, 10),)
         m.setenv("VLLM_USE_AOT_COMPILE", "1")
-        torch._dynamo.reset()
-        with use_vllm_config(vllm_config), torch.compiler.set_stance(
-                "fail_on_recompile"):
-            ret = CompiledMod(vllm_config=vllm_config)(*args)
+        m.setenv("VLLM_FORCE_AOT_LOAD", "1")
+        m.setenv("VLLM_CACHE_ROOT", tmpdirname)
+        vllm_config = make_vllm_config()
+        with use_vllm_config(vllm_config), pytest.raises(FileNotFoundError):
+            CompiledMod(vllm_config=vllm_config)(*args)
+
+
+def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
+    with monkeypatch.context() as m:
+        args = (torch.randn(10, 10),)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            m.setenv("VLLM_CACHE_ROOT", tmpdirname)
+            m.setenv("VLLM_USE_AOT_COMPILE", "1")
+            vllm_config = make_vllm_config()
+            with use_vllm_config(vllm_config):
+                expected = CompiledMod(vllm_config=vllm_config)(*args)
+
+            m.setenv("VLLM_FORCE_AOT_LOAD", "1")
+            vllm_config = make_vllm_config()
+            with use_vllm_config(vllm_config):
+                ret = CompiledMod(vllm_config=vllm_config)(*args)
             assert torch.allclose(ret, expected)
+
+
+def test_shape_env(monkeypatch: pytest.MonkeyPatch):
+    """
+    Test that the shape environment is correctly serialized and preserved
+    when loading from cache.
+    """
+    with monkeypatch.context() as m:
+        args = (torch.randn(10, 10),)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            m.setenv("VLLM_CACHE_ROOT", tmpdirname)
+            m.setenv("VLLM_USE_AOT_COMPILE", "1")
+            vllm_config = make_vllm_config()
+            with use_vllm_config(vllm_config):
+                compiled_mod = CompiledMod(vllm_config=vllm_config)
+                compiled_mod(*args)
+                artifacts = compiled_mod.aot_compiled_fn._artifacts
+                guards_string = artifacts.compiled_fn.shape_env.format_guards()
+                assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
+
+            m.setenv("VLLM_FORCE_AOT_LOAD", "1")
+            vllm_config = make_vllm_config()
+            with use_vllm_config(vllm_config):
+                compiled_mod = CompiledMod(vllm_config=vllm_config)
+                compiled_mod(*args)
+                artifacts = compiled_mod.aot_compiled_fn._artifacts
+                guards_string = artifacts.compiled_fn.shape_env.format_guards()
+                assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
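Taken together, these tests drive the AOT workflow entirely through three environment variables: VLLM_USE_AOT_COMPILE turns the path on, VLLM_CACHE_ROOT decides where artifacts are persisted, and VLLM_FORCE_AOT_LOAD requires that a previously saved artifact already exists. A condensed save-then-load sketch, reusing the helpers defined in this test file, is shown below; it is an illustration only, not part of the commit, and assumes the env vars are read at call time so plain os.environ assignments suffice outside pytest.

# Illustrative sketch modeled on test_save_and_load above; not part of the commit.
import os
import tempfile

import torch

# Helpers defined in tests/compile/test_aot_compile.py (see diff above); the
# import path assumes the tests package is importable in your environment.
from tests.compile.test_aot_compile import (
    CompiledMod,
    make_vllm_config,
    use_vllm_config,
)

args = (torch.randn(10, 10),)

with tempfile.TemporaryDirectory() as cache_root:
    os.environ["VLLM_CACHE_ROOT"] = cache_root

    # First run: compile ahead of time and persist the serialized artifacts.
    os.environ["VLLM_USE_AOT_COMPILE"] = "1"
    vllm_config = make_vllm_config()
    with use_vllm_config(vllm_config):
        expected = CompiledMod(vllm_config=vllm_config)(*args)

    # Second run: forbid fresh compilation and load the artifacts from cache.
    os.environ["VLLM_FORCE_AOT_LOAD"] = "1"
    vllm_config = make_vllm_config()
    with use_vllm_config(vllm_config):
        loaded = CompiledMod(vllm_config=vllm_config)(*args)

    assert torch.allclose(loaded, expected)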

tools/pre_commit/check_pickle_imports.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@
     "vllm/multimodal/hasher.py",
     "vllm/transformers_utils/config.py",
     "vllm/model_executor/models/registry.py",
+    "vllm/compilation/caching.py",
     "tests/utils_/test_utils.py",
     "tests/tokenization/test_cached_tokenizer.py",
     "vllm/distributed/utils.py",

vllm/compilation/backends.py

Lines changed: 16 additions & 64 deletions
@@ -3,6 +3,7 @@
 
 import ast
 import dataclasses
+import hashlib
 import os
 import pprint
 import time
@@ -20,6 +21,7 @@
 from vllm.platforms import current_platform
 from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
 
+from .caching import VllmSerializableFunction
 from .compiler_interface import (
     CompilerInterface,
     EagerAdaptor,
@@ -175,6 +177,7 @@ def compile(
         # there can be multiple graphs due to piecewise compilation.
         now = time.time()
         elapsed = now - compilation_start_time
+        compilation_config.compilation_time += elapsed
         if runtime_shape is None:
             logger.info(
                 "Directly load the compiled graph(s) for dynamic shape "
@@ -441,35 +444,6 @@ def set_model_tag(tag: str):
         model_tag = old_tag
 
 
-try:
-    from torch._dynamo.aot_compile import SerializableCallable
-except ImportError:
-    SerializableCallable = object
-
-assert isinstance(SerializableCallable, type)
-
-
-class VllmCompiledFunction(SerializableCallable):
-
-    def __init__(self, graph_module, example_inputs, vllm_config,
-                 optimized_call):
-        self.graph_module = graph_module
-        self.example_inputs = example_inputs
-        self.vllm_config = vllm_config
-        self.optimized_call = optimized_call
-
-    def __call__(self, *args, **kwargs):
-        return self.optimized_call(*args, **kwargs)
-
-    @classmethod
-    def serialize_compile_artifacts(cls, compiled_fn):
-        raise NotImplementedError("serialization not implemented")
-
-    @classmethod
-    def deserialize_compile_artifacts(cls, data):
-        raise NotImplementedError("deserialization not implemented")
-
-
 class VllmBackend:
     """The compilation backend for `torch.compile` with vLLM.
     It is used for compilation level of `CompilationLevel.PIECEWISE`,
@@ -547,47 +521,23 @@ def configure_post_pass(self):
         self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
         inductor_config[PASS_KEY] = self.post_grad_pass_manager
 
-    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
+    def __call__(
+        self, graph: fx.GraphModule, example_inputs
+    ) -> VllmSerializableFunction:
+        from .caching import _compute_code_hash, compilation_config_hash_factors
+
         vllm_config = self.vllm_config
         if not self.compilation_config.cache_dir:
             # no provided cache dir, generate one based on the known factors
             # that affects the compilation. if none of the factors change,
             # the cache dir will be the same so that we can reuse the compiled
             # graph.
 
-            factors = []
-            # 0. factors come from the env, for example, The values of
-            # VLLM_PP_LAYER_PARTITION will affect the computation graph.
-            env_hash = envs.compute_hash()
-            factors.append(env_hash)
-
-            # 1. factors come from the vllm_config (it mainly summarizes how the
-            # model is created)
-            config_hash = vllm_config.compute_hash()
-            factors.append(config_hash)
-
+            factors = compilation_config_hash_factors(vllm_config)
             # 2. factors come from the code files that are traced by Dynamo (
             # it mainly summarizes how the model is used in forward pass)
-            forward_code_files = list(sorted(self.compilation_config.traced_files))
+            code_hash = _compute_code_hash(self.compilation_config.traced_files)
             self.compilation_config.traced_files.clear()
-            logger.debug(
-                "Traced files (to be considered for compilation cache):\n%s",
-                "\n".join(forward_code_files),
-            )
-            hash_content = []
-            for filepath in forward_code_files:
-                hash_content.append(filepath)
-                if filepath == "<string>":
-                    # This means the function was dynamically generated, with
-                    # e.g. exec(). We can't actually check these.
-                    continue
-                with open(filepath) as f:
-                    hash_content.append(f.read())
-            import hashlib
-
-            code_hash = hashlib.md5(
-                "\n".join(hash_content).encode(), usedforsecurity=False
-            ).hexdigest()
             factors.append(code_hash)
 
             # 3. compiler hash
@@ -688,8 +638,9 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
             self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
             or not self.compilation_config.cudagraph_copy_inputs
         ):
-            return VllmCompiledFunction(graph, example_inputs, vllm_config,
-                                        self.split_gm)
+            return VllmSerializableFunction(
+                graph, example_inputs, self.prefix, self.split_gm
+            )
 
         # if we need to copy input buffers for cudagraph
         from torch._guards import detect_fake_mode
@@ -734,5 +685,6 @@ def copy_and_call(*args):
                 list_args[index] = static_tensor
             return self.split_gm(*list_args)
 
-        return VllmCompiledFunction(graph, example_inputs, vllm_config,
-                                    copy_and_call)
+        return VllmSerializableFunction(
+            graph, example_inputs, self.prefix, copy_and_call
+        )