1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -403,6 +403,7 @@ steps:
- pytest -v -s compile/test_fusion_all_reduce.py
- pytest -v -s compile/test_decorator.py
- pytest -v -s compile/test_noop_elimination.py
- pytest -v -s compile/test_aot_compile.py

- label: PyTorch Fullgraph Smoke Test # 15min
timeout_in_minutes: 30
139 changes: 139 additions & 0 deletions tests/compile/test_aot_compile.py
@@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import tempfile
from contextlib import contextmanager

import pytest
import torch

from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
CompilationConfig,
CompilationLevel,
VllmConfig,
set_current_vllm_config,
)
from vllm.forward_context import set_forward_context
from vllm.utils import is_torch_equal_or_newer

Collaborator:
NB: vLLM does run vLLM main x PyTorch nightly, but that CI is currently broken (it has been failing for weeks), so we won't get signal that way. Zhengxu has verified that this test file passes locally, and this change is opt-in and still undergoing integration testing.


def reference_fn(x: torch.Tensor):
assert x.shape[0] <= 42
assert x.shape[0] % 2 == 0
for _ in range(3000):
x = x + x.shape[0]
return x


@support_torch_compile
class CompiledMod(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__()

def forward(self, x: torch.Tensor):
return reference_fn(x)


def make_vllm_config() -> VllmConfig:
return VllmConfig(
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
)
)


@contextmanager
def use_vllm_config(vllm_config: VllmConfig):
with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config):
yield


@pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
vllm_config = make_vllm_config()
args = (torch.randn(10, 10),)
expected = reference_fn(*args)
with use_vllm_config(vllm_config):
m.setenv("VLLM_USE_AOT_COMPILE", "0")
with (
pytest.raises(RuntimeError, match="Detected recompile"),
torch.compiler.set_stance("fail_on_recompile"),
):
CompiledMod(vllm_config=vllm_config)(*args)

m.setenv("VLLM_USE_AOT_COMPILE", "1")
torch._dynamo.reset()
with torch.compiler.set_stance("fail_on_recompile"):
actual = CompiledMod(vllm_config=vllm_config)(*args)
Collaborator:
Why doesn't this fail? Where does the compiled code come from? Does the previous run that raised a recompile error create it, or does it come from the cache?

Contributor (author):
I think the name of the API torch.compiler.set_stance('fail_on_recompile') is the source of confusion here. torch.compile() now has two modes: JIT and AOT. torch.compiler.set_stance('fail_on_recompile') means torch.compile() will fail when it recompiles in JIT mode.

Here, by setting VLLM_USE_AOT_COMPILE=1, we're testing that torch.compile() JIT mode is not triggered. We're not testing the loading behavior in this unit test yet (the loading path is covered by the following tests). In other words, we're just verifying that we call the correct AOT compile API from torch.

It might be possible to address the confusion by naming the API something like set_stance("fail_on_new_cache_entry") or better, but the behavior here is just about JIT vs. AOT.
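For illustration only (not part of this diff): a minimal standalone sketch of the JIT-mode behavior that fail_on_recompile guards against, using only public torch APIs; dynamic=False is passed just to force a recompile on the second shape.

import torch

@torch.compile(dynamic=False)
def f(x):
    return x + x.shape[0]

f(torch.randn(10, 10))  # first call: JIT compiles and caches an entry for shape (10, 10)
with torch.compiler.set_stance("fail_on_recompile"):
    f(torch.randn(10, 10))  # reuses the existing cache entry, no recompile
    f(torch.randn(12, 10))  # new shape needs a JIT recompile -> RuntimeError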

assert torch.allclose(actual, expected)


@pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_force_aot_load(monkeypatch: pytest.MonkeyPatch):
with tempfile.TemporaryDirectory() as tmpdirname, monkeypatch.context() as m:
args = (torch.randn(10, 10),)
m.setenv("VLLM_USE_AOT_COMPILE", "1")
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
vllm_config = make_vllm_config()
with use_vllm_config(vllm_config), pytest.raises(FileNotFoundError):
CompiledMod(vllm_config=vllm_config)(*args)


@pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
args = (torch.randn(10, 10),)

with tempfile.TemporaryDirectory() as tmpdirname:
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
m.setenv("VLLM_USE_AOT_COMPILE", "1")
vllm_config = make_vllm_config()
with use_vllm_config(vllm_config):
expected = CompiledMod(vllm_config=vllm_config)(*args)

m.setenv("VLLM_FORCE_AOT_LOAD", "1")
vllm_config = make_vllm_config()
with use_vllm_config(vllm_config):
ret = CompiledMod(vllm_config=vllm_config)(*args)
assert torch.allclose(ret, expected)


@pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_shape_env(monkeypatch: pytest.MonkeyPatch):
"""
Test that the shape environment is correctly serialized and preserved
when loading from cache.
"""
with monkeypatch.context() as m:
args = (torch.randn(10, 10),)

with tempfile.TemporaryDirectory() as tmpdirname:
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
m.setenv("VLLM_USE_AOT_COMPILE", "1")
vllm_config = make_vllm_config()
with use_vllm_config(vllm_config):
compiled_mod = CompiledMod(vllm_config=vllm_config)
compiled_mod(*args)
artifacts = compiled_mod.aot_compiled_fn._artifacts
guards_string = artifacts.compiled_fn.shape_env.format_guards()
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"

m.setenv("VLLM_FORCE_AOT_LOAD", "1")
vllm_config = make_vllm_config()
with use_vllm_config(vllm_config):
compiled_mod = CompiledMod(vllm_config=vllm_config)
compiled_mod(*args)
artifacts = compiled_mod.aot_compiled_fn._artifacts
guards_string = artifacts.compiled_fn.shape_env.format_guards()
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
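For context, an illustration that is not part of this diff: the two guards checked above come directly from the asserts in reference_fn. Once the batch dimension is treated as a symbolic size, assert x.shape[0] <= 42 and assert x.shape[0] % 2 == 0 are recorded in the shape environment as guards on that symbol (spelled s77 here; the exact symbol name is an implementation detail). A minimal standalone sketch of the same effect with plain torch.compile and torch._dynamo.mark_dynamic:

import torch

def f(x):
    assert x.shape[0] <= 42
    assert x.shape[0] % 2 == 0
    return x + x.shape[0]

x = torch.randn(10, 10)
torch._dynamo.mark_dynamic(x, 0)  # treat dim 0 as a symbolic size rather than specializing on 10
torch.compile(f)(x)               # the asserts show up as shape guards, e.g. s0 <= 42 and Eq(Mod(s0, 2), 0)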
1 change: 1 addition & 0 deletions tools/pre_commit/check_pickle_imports.py
@@ -22,6 +22,7 @@
"vllm/multimodal/hasher.py",
"vllm/transformers_utils/config.py",
"vllm/model_executor/models/registry.py",
"vllm/compilation/caching.py",
"tests/utils_/test_utils.py",
"tests/tokenization/test_cached_tokenizer.py",
"vllm/distributed/utils.py",
49 changes: 16 additions & 33 deletions vllm/compilation/backends.py
@@ -3,6 +3,7 @@

import ast
import dataclasses
import hashlib
import os
import pprint
import time
@@ -25,6 +26,7 @@
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname

from .caching import VllmSerializableFunction
from .compiler_interface import (
CompilerInterface,
EagerAdaptor,
@@ -195,6 +197,7 @@ def compile(
# there can be multiple graphs due to piecewise compilation.
now = time.time()
elapsed = now - compilation_start_time
compilation_config.compilation_time += elapsed
if runtime_shape is None:
logger.info(
"Directly load the compiled graph(s) for dynamic shape "
@@ -549,47 +552,23 @@ def configure_post_pass(self):
self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
inductor_config[PASS_KEY] = self.post_grad_pass_manager

def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
def __call__(
self, graph: fx.GraphModule, example_inputs
) -> VllmSerializableFunction:
from .caching import _compute_code_hash, compilation_config_hash_factors

vllm_config = self.vllm_config
if not self.compilation_config.cache_dir:
# no provided cache dir, generate one based on the known factors
# that affects the compilation. if none of the factors change,
# the cache dir will be the same so that we can reuse the compiled
# graph.

factors = []
# 0. factors come from the env, for example, The values of
# VLLM_PP_LAYER_PARTITION will affect the computation graph.
env_hash = envs.compute_hash()
factors.append(env_hash)

# 1. factors come from the vllm_config (it mainly summarizes how the
# model is created)
config_hash = vllm_config.compute_hash()
factors.append(config_hash)

factors = compilation_config_hash_factors(vllm_config)
# 2. factors come from the code files that are traced by Dynamo (
# it mainly summarizes how the model is used in forward pass)
forward_code_files = list(sorted(self.compilation_config.traced_files))
code_hash = _compute_code_hash(self.compilation_config.traced_files)
self.compilation_config.traced_files.clear()
logger.debug(
"Traced files (to be considered for compilation cache):\n%s",
"\n".join(forward_code_files),
)
hash_content = []
for filepath in forward_code_files:
hash_content.append(filepath)
if filepath == "<string>":
# This means the function was dynamically generated, with
# e.g. exec(). We can't actually check these.
continue
with open(filepath) as f:
hash_content.append(f.read())
import hashlib

code_hash = hashlib.md5(
"\n".join(hash_content).encode(), usedforsecurity=False
).hexdigest()
factors.append(code_hash)

# 3. compiler hash
@@ -695,7 +674,9 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
or not self.compilation_config.cudagraph_copy_inputs
):
return self.split_gm
return VllmSerializableFunction(
graph, example_inputs, self.prefix, self.split_gm
)

# if we need to copy input buffers for cudagraph
from torch._guards import detect_fake_mode
@@ -740,4 +721,6 @@ def copy_and_call(*args):
list_args[index] = static_tensor
return self.split_gm(*list_args)

return copy_and_call
return VllmSerializableFunction(
graph, example_inputs, self.prefix, copy_and_call
)
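For context, an illustrative sketch that is not part of this diff: from the two call sites above, VllmSerializableFunction is constructed from the Dynamo fx graph, the example inputs, the compilation prefix, and the final callable (split_gm or copy_and_call), and is returned in place of the bare callable so the result can be serialized and cached. A hypothetical minimal wrapper of that shape (the field names below are assumptions, not the actual vllm/compilation/caching.py implementation) might look like:

from dataclasses import dataclass
from typing import Any, Callable, Sequence

import torch.fx as fx


@dataclass
class SerializableFunctionSketch:
    graph: fx.GraphModule          # the Dynamo-captured graph for this piece
    example_inputs: Sequence[Any]  # inputs the graph was traced/compiled with
    prefix: str                    # identifies this compiled piece in the cache
    runnable: Callable             # the compiled callable to dispatch to at runtime

    def __call__(self, *args):
        return self.runnable(*args)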