
Commit 8d18941

zhxchen17 authored and dolpm committed

AOT compilation workflow [2/n]

Signed-off-by: zhxchen17 <zhxchen17@fb.com>

1 parent 308275f · commit 8d18941

File tree: 4 files changed, +174 −19 lines changed

tests/compile/test_aot_compile.py

Lines changed: 36 additions & 0 deletions
```diff
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import tempfile
 from contextlib import contextmanager
 
 import pytest
@@ -59,3 +60,38 @@ def test_no_eval_frame(monkeypatch: pytest.MonkeyPatch):
                 "fail_on_recompile"):
             ret = CompiledMod(vllm_config=vllm_config)(*args)
     assert torch.allclose(ret, expected)
+
+
+def test_force_aot_load(monkeypatch: pytest.MonkeyPatch):
+    with tempfile.TemporaryDirectory() as tmpdirname, monkeypatch.context(
+    ) as m:
+        args = (torch.randn(10, 10), )
+        m.setenv("VLLM_USE_AOT_COMPILE", "1")
+        m.setenv("VLLM_FORCE_AOT_LOAD", "1")
+        m.setenv("VLLM_CACHE_ROOT", tmpdirname)
+        vllm_config = make_vllm_config()
+        with use_vllm_config(vllm_config):
+            CompiledMod = support_torch_compile(MyMod)
+            try:
+                CompiledMod(vllm_config=vllm_config)(*args)
+            except Exception as e:
+                assert isinstance(e, FileNotFoundError)
+            else:
+                raise AssertionError(
+                    "Expected failed aot compilation with clean state.")
+
+
+def test_basic(monkeypatch: pytest.MonkeyPatch):
+    with monkeypatch.context() as m:
+        args = (torch.randn(10, 10), )
+        CompiledMod = support_torch_compile(MyMod)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            m.setenv("VLLM_CACHE_ROOT", tmpdirname)
+            m.setenv("VLLM_USE_AOT_COMPILE", "1")
+            vllm_config = make_vllm_config()
+            with use_vllm_config(vllm_config):
+                expected = CompiledMod(vllm_config=vllm_config)(*args)
+                m.setenv("VLLM_FORCE_AOT_LOAD", "1")
+                ret = CompiledMod(vllm_config=vllm_config)(*args)
+                assert torch.allclose(ret, expected)
```
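The two new tests pin down both sides of the cache contract: `test_force_aot_load` expects a `FileNotFoundError` when forced loading hits an empty cache, and `test_basic` checks that a compile-then-load round trip reproduces the original result. The sketch below isolates the env-handling pattern the tests share; it is a minimal stand-alone illustration, not vLLM code.

```python
# Minimal sketch of the isolation pattern used by both tests: a throwaway
# VLLM_CACHE_ROOT plus monkeypatch.context() so env mutations cannot leak
# into other tests. The body is a stand-in, not vLLM code.
import os
import tempfile

import pytest


def test_env_isolation(monkeypatch: pytest.MonkeyPatch):
    with tempfile.TemporaryDirectory() as tmpdirname, \
            monkeypatch.context() as m:
        m.setenv("VLLM_CACHE_ROOT", tmpdirname)
        m.setenv("VLLM_USE_AOT_COMPILE", "1")
        # Code under test now sees a scratch cache root...
        assert os.environ["VLLM_CACHE_ROOT"] == tmpdirname
    # ...and both variables are restored once the context exits.
    assert os.environ.get("VLLM_CACHE_ROOT") != tmpdirname
```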

vllm/compilation/backends.py

Lines changed: 85 additions & 19 deletions
```diff
@@ -3,16 +3,20 @@
 
 import ast
 import dataclasses
+import inspect
 import os
+import pickle
 import pprint
 import time
 from collections.abc import Sequence
 from contextlib import contextmanager
 from typing import Any, Callable, Optional
+from unittest.mock import patch
 
 import torch
 import torch.fx as fx
 from torch._dispatch.python import enable_python_dispatcher
+from torch.utils import _pytree as pytree
 
 import vllm.envs as envs
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
@@ -408,23 +412,94 @@ def set_model_tag(tag: str):
 
 class VllmCompiledFunction(SerializableCallable):
 
-    def __init__(self, graph_module, example_inputs, vllm_config,
+    def __init__(self, graph_module, example_inputs, vllm_config, prefix,
                  optimized_call):
+        assert isinstance(graph_module, torch.fx.GraphModule)
         self.graph_module = graph_module
         self.example_inputs = example_inputs
         self.vllm_config = vllm_config
+        self.prefix = prefix
         self.optimized_call = optimized_call
 
     def __call__(self, *args, **kwargs):
        return self.optimized_call(*args, **kwargs)
 
     @classmethod
-    def serialize_compile_artifacts(cls, compiled_fn):
-        raise NotImplementedError("serialization not implemented")
+    def serialize_compile_artifacts(
+            cls, compiled_fn: "VllmCompiledFunction") -> bytes:
+        import sympy
+        from torch._subclasses import FakeTensorMode
+        from torch.fx._graph_pickler import GraphPickler, Options
+        state = compiled_fn.__dict__.copy()
+        state.pop("optimized_call")
+        for node in state["graph_module"].graph.nodes:
+            node.meta.pop("source_fn_stack", None)
+            node.meta.pop("nn_module_stack", None)
+
+        graph_reducer_override = GraphPickler.reducer_override
+
+        def _graph_reducer_override(self, obj):
+            if (inspect.isclass(obj) and issubclass(obj, sympy.Function)
+                    and hasattr(obj, "_torch_unpickler")):
+                return obj._torch_unpickler, (obj._torch_handler_name, )
+            if isinstance(obj, FakeTensorMode):
+                return type(None), ()
+            return graph_reducer_override(self, obj)
+
+        # Mask off tensor inputs since they are large and not needed.
+        state["example_inputs"] = pytree.tree_map_only(torch.Tensor,
+                                                       lambda _: None,
+                                                       state["example_inputs"])
+        with patch.object(GraphPickler, 'reducer_override',
+                          _graph_reducer_override):
+            state["graph_module"] = GraphPickler.dumps(
+                state["graph_module"], Options(ops_filter=None))
+            state["example_inputs"] = GraphPickler.dumps(
+                state["example_inputs"])
+        return pickle.dumps(state)
 
     @classmethod
-    def deserialize_compile_artifacts(cls, data):
-        raise NotImplementedError("deserialization not implemented")
+    def deserialize_compile_artifacts(cls,
+                                      data: bytes) -> "VllmCompiledFunction":
+        from torch._guards import TracingContext, tracing
+        from torch._subclasses import FakeTensorMode
+        from torch.fx._graph_pickler import GraphPickler
+        from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
+        state = pickle.loads(data)
+        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
+        state["graph_module"] = GraphPickler.loads(state["graph_module"],
+                                                   fake_mode)
+        state["example_inputs"] = GraphPickler.loads(state["example_inputs"],
+                                                     fake_mode)
+        vllm_backend = VllmBackend(state["vllm_config"], state["prefix"])
+
+        def optimized_call(*example_inputs):
+            compile_inputs = [
+                inp or example_inputs[i]
+                for i, inp in enumerate(fn.example_inputs)
+            ]
+            with tracing(TracingContext(fake_mode)):
+                fn.optimized_call = vllm_backend(
+                    state["graph_module"], compile_inputs).optimized_call
+            return fn.optimized_call(*example_inputs)
+
+        fn = cls(**state, optimized_call=optimized_call)
+        return fn
+
+
+def compilation_config_hash_factors(vllm_config: VllmConfig) -> list[str]:
+    factors = []
+    # 0. factors come from the env, for example, the values of
+    # VLLM_PP_LAYER_PARTITION will affect the computation graph.
+    env_hash = envs.compute_hash()
+    factors.append(env_hash)
+
+    # 1. factors come from the vllm_config (it mainly summarizes how the
+    # model is created)
+    config_hash = vllm_config.compute_hash()
+    factors.append(config_hash)
+    return factors
 
 
 class VllmBackend:
@@ -502,7 +577,8 @@ def configure_post_pass(self):
         self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
         inductor_config[PASS_KEY] = self.post_grad_pass_manager
 
-    def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
+    def __call__(self, graph: fx.GraphModule,
+                 example_inputs) -> VllmCompiledFunction:
 
         vllm_config = self.vllm_config
         if not self.compilation_config.cache_dir:
@@ -511,17 +587,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         # the cache dir will be the same so that we can reuse the compiled
         # graph.
 
-        factors = []
-        # 0. factors come from the env, for example, The values of
-        # VLLM_PP_LAYER_PARTITION will affect the computation graph.
-        env_hash = envs.compute_hash()
-        factors.append(env_hash)
-
-        # 1. factors come from the vllm_config (it mainly summarizes how the
-        # model is created)
-        config_hash = vllm_config.compute_hash()
-        factors.append(config_hash)
-
+        factors = compilation_config_hash_factors(vllm_config)
         # 2. factors come from the code files that are traced by Dynamo (
         # it mainly summarizes how the model is used in forward pass)
         forward_code_files = list(
@@ -635,7 +701,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE or \
             not self.compilation_config.cudagraph_copy_inputs:
             return VllmCompiledFunction(graph, example_inputs, vllm_config,
-                                        self.split_gm)
+                                        self.prefix, self.split_gm)
 
         # if we need to copy input buffers for cudagraph
         from torch._guards import detect_fake_mode
@@ -678,4 +744,4 @@ def copy_and_call(*args):
             return self.split_gm(*list_args)
 
         return VllmCompiledFunction(graph, example_inputs, vllm_config,
-                                    copy_and_call)
+                                    self.prefix, copy_and_call)
```
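Worth noting in `serialize_compile_artifacts`: tensor leaves of `example_inputs` are masked out before pickling, and `deserialize_compile_artifacts` installs a lazy `optimized_call` that splices the first real call's arguments back into the masked slots before recompiling. Below is a minimal sketch of that masking round trip, independent of the `GraphPickler` machinery; the variables are illustrative.

```python
# Minimal sketch of the example-input masking used above: tensor leaves are
# large and recoverable from the first real call, so they are replaced with
# None before pickling and filled back in at load time.
import torch
from torch.utils import _pytree as pytree

example_inputs = [torch.randn(10, 10), 8, torch.randn(4)]

# Serialize side: blank out every tensor leaf, keep the pytree structure.
masked = pytree.tree_map_only(torch.Tensor, lambda _: None, example_inputs)
assert masked == [None, 8, None]

# Deserialize side: the live call's args stand in for the masked slots
# (the commit uses `inp or example_inputs[i]` for the same substitution).
live_args = (torch.zeros(10, 10), 8, torch.zeros(4))
compile_inputs = [
    inp if inp is not None else live_args[i] for i, inp in enumerate(masked)
]
assert torch.equal(compile_inputs[0], live_args[0])
```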

vllm/compilation/decorators.py

Lines changed: 47 additions & 0 deletions
```diff
@@ -2,7 +2,9 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import contextlib
+import hashlib
 import inspect
+import os
 from typing import Callable, Optional, TypeVar, Union, overload
 from unittest.mock import patch
 
@@ -176,6 +178,13 @@ def cls_decorator_helper(cls: _T) -> _T:
     return cls_decorator_helper
 
 
+def _model_hash_key(fn) -> str:
+    sha256_hash = hashlib.sha256()
+    sha256_hash.update(fn.__qualname__.encode())
+    sha256_hash.update(str(fn.__code__.co_firstlineno).encode())
+    return sha256_hash.hexdigest()
+
+
 def _support_torch_compile(
     cls: _T,
     dynamic_arg_dims: dict[str, Union[int, list[int]]],
@@ -227,6 +236,39 @@ def __call__(self, *args, **kwargs):
         if getattr(self, "aot_compiled_fn", None) is not None:
             return self.aot_compiled_fn(self, *args, **kwargs)
 
+        cache_dir = None
+        aot_compilation_path = None
+        if envs.VLLM_USE_AOT_COMPILE:
+            from .backends import compilation_config_hash_factors
+            factors: list[str] = compilation_config_hash_factors(
+                self.vllm_config)
+
+            factors.append(_model_hash_key(self.forward))
+            hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
+
+            cache_dir = os.path.join(
+                envs.VLLM_CACHE_ROOT,
+                "aot_compilation",
+                hash_key,
+            )
+
+            rank = self.vllm_config.parallel_config.rank
+            dp_rank = self.vllm_config.parallel_config.data_parallel_rank
+            cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
+            aot_compilation_path = os.path.join(cache_dir, "model")
+            try:
+                with open(aot_compilation_path, "rb") as f:
+                    loaded_fn = torch.compiler.load_compiled_function(f)
+                self.aot_compiled_fn = loaded_fn
+                return self.aot_compiled_fn(self, *args, **kwargs)
+            except Exception as e:
+                if os.path.exists(aot_compilation_path):
+                    logger.warning(
+                        "Cannot load aot compilation from path %s, error: %s",
+                        aot_compilation_path, str(e))
+                if envs.VLLM_FORCE_AOT_LOAD:
+                    raise e
+
         # the first compilation needs to have dynamic shapes marked
         if len(self.compiled_codes) < 1:
             sig = inspect.signature(self.__class__.forward)
@@ -312,6 +354,11 @@ def patched_inline_call(parent, func, args, kwargs):
         if envs.VLLM_USE_AOT_COMPILE:
             self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
             output = self.aot_compiled_fn(self, *args, **kwargs)
+            assert aot_compilation_path is not None
+            assert cache_dir is not None
+            os.makedirs(cache_dir, exist_ok=True)
+            self.aot_compiled_fn.save_compiled_function(
+                aot_compilation_path)
         else:
             output = self.compiled_callable(*args, **kwargs)
 
```
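Taken together, the lookup and save paths define a per-model, per-rank cache layout under `VLLM_CACHE_ROOT`: the key hashes the env and config factors plus the forward function's identity, and the artifact sits in a `rank_<rank>_<dp_rank>` subdirectory. The helper below recomputes that path for illustration; `aot_artifact_path` is a hypothetical name, not part of the commit.

```python
# Illustrative recomputation of the on-disk layout implied by the lookup
# above (this helper is not vLLM code):
#   $VLLM_CACHE_ROOT/aot_compilation/<sha256(factors)>/rank_<r>_<dp>/model
import hashlib
import os


def aot_artifact_path(cache_root: str, factors: list[str], rank: int,
                      dp_rank: int) -> str:
    hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
    return os.path.join(cache_root, "aot_compilation", hash_key,
                        f"rank_{rank}_{dp_rank}", "model")


# Example with placeholder factors: env hash, config hash, model hash key.
print(aot_artifact_path("/tmp/vllm", ["env", "cfg", "model"], 0, 0))
```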

vllm/envs.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -510,6 +510,12 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_USE_AOT_COMPILE":
     lambda: os.environ.get("VLLM_USE_AOT_COMPILE", "0") == "1",
 
+    # Force vllm to always load AOT compiled models from disk. Failure
+    # to load will result in a hard error when this is enabled.
+    # Will be ignored when VLLM_USE_AOT_COMPILE is disabled.
+    "VLLM_FORCE_AOT_LOAD":
+    lambda: os.environ.get("VLLM_FORCE_AOT_LOAD", "0") == "1",
+
     # local rank of the process in the distributed setting, used to determine
     # the GPU device id
     "LOCAL_RANK":
```