139 changes: 132 additions & 7 deletions tests/compile/test_dynamic_shapes_compilation.py
@@ -2,12 +2,21 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import gc
import tempfile
from contextlib import contextmanager

import pytest
import torch

from vllm import LLM, SamplingParams
from vllm.config.compilation import CompilationMode, DynamicShapesType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.config.compilation import (
CompilationMode,
DynamicShapesConfig,
DynamicShapesType,
)
from vllm.forward_context import set_forward_context
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils.torch_utils import is_torch_equal_or_newer

@@ -29,18 +38,19 @@ def get_test_models():
)
@pytest.mark.parametrize("use_aot_compile", ["0"])
@pytest.mark.parametrize("use_bytecode_hook", [True, False])
@pytest.mark.parametrize("evaluate_guards", [False, True])
@pytest.mark.skipif(
not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_dynamic_shapes_compilation(
monkeypatch, model_name, shapes_type, use_aot_compile, use_bytecode_hook
monkeypatch,
model_name,
shapes_type,
use_aot_compile,
use_bytecode_hook,
evaluate_guards,
):
"""Test that all dynamic shapes types compile successfully"""
print(
f"\nTesting model: {model_name} with {shapes_type.name}, "
f"AOT compile: {use_aot_compile}, "
f"Bytecode hook: {use_bytecode_hook}"
)
if use_bytecode_hook and shapes_type == DynamicShapesType.UNBACKED:
pytest.skip("UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0")

@@ -58,6 +68,7 @@ def test_dynamic_shapes_compilation(
"mode": CompilationMode.VLLM_COMPILE,
"dynamic_shapes_config": {
"type": shapes_type.value,
"evaluate_guards": evaluate_guards,
},
},
)
@@ -86,3 +97,117 @@ def test_dynamic_shapes_compilation(
torch.cuda.empty_cache()
torch.cuda.synchronize()
print("GPU memory cleared")


@pytest.mark.parametrize("use_aot_compile", ["0", "1"])
@pytest.mark.parametrize(
"dynamic_shapes_type",
[
DynamicShapesType.BACKED,
DynamicShapesType.BACKED_SIZE_OBLIVIOUS,
],
)
@pytest.mark.parametrize("evaluate_guards", [False, True])
def test_model_specialization_with_evaluate_guards(
monkeypatch, use_aot_compile, dynamic_shapes_type, evaluate_guards
):
"""Test that evaluate_guards correctly detects shape specialization
violations.
"""

if (
use_aot_compile == "1"
and dynamic_shapes_type == DynamicShapesType.BACKED
and evaluate_guards
):
        pytest.skip("evaluate_guards for backed does not work with aot_compile=1")

@support_torch_compile
class ModelWithSizeCheck(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__()

def forward(self, x: torch.Tensor):
            # This will cause specialization - torch.compile will guard on
            # x.shape[0] (both branches return the same value; only the guard matters)
if x.shape[0] >= 10:
return x * 10
else:
return x * 10

@support_torch_compile
class ModelWithOneSizeCheck(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__()

def forward(self, x: torch.Tensor):
# This will cause 0/1 specializations.
if x.shape[0] == 0:
return x * 10
if x.shape[0] == 1:
return x * 10
else:
return x * 10

@contextmanager
def use_vllm_config(vllm_config: VllmConfig):
with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config):
yield

monkeypatch.setenv("TOKENIZERS_PARALLELISM", "true")
monkeypatch.setenv("VLLM_USE_AOT_COMPILE", use_aot_compile)
monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "0")

# Create vllm config with the desired settings
from vllm.config import CompilationMode

vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
dynamic_shapes_config=DynamicShapesConfig(
type=dynamic_shapes_type,
evaluate_guards=evaluate_guards,
),
)
)

def test(model_class, input1, input2, is_01_specialization=False):
with (
torch.no_grad(),
use_vllm_config(vllm_config),
tempfile.TemporaryDirectory() as tmpdirname,
):
monkeypatch.setenv("VLLM_CACHE_ROOT", tmpdirname)

model = model_class(vllm_config=vllm_config).cuda()

model(input1)

if evaluate_guards and (
not (
is_01_specialization
and dynamic_shapes_type == DynamicShapesType.BACKED
)
):
# This should fail because guards were added.
with pytest.raises(RuntimeError) as excinfo:
model(input2)

# Expected failure - guard was violated
error_msg = str(excinfo.value)
assert (
"GuardManager check failed" in error_msg
or "Detected recompile when torch.compile stance" in error_msg
), error_msg

else:
model(input2)

test(ModelWithSizeCheck, torch.randn(20, 10).cuda(), torch.randn(5, 10).cuda())
test(ModelWithSizeCheck, torch.randn(5, 10).cuda(), torch.randn(20, 10).cuda())
test(
ModelWithOneSizeCheck,
torch.randn(20, 10).cuda(),
torch.randn(1, 10).cuda(),
is_01_specialization=True,
)
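
For context, here is a minimal standalone sketch (plain PyTorch, outside vLLM, assuming typical torch.compile dynamic-shape behavior) of the mechanism the tests above exercise: a branch on a tensor's size makes Dynamo record a shape guard, and the fail_on_recompile stance turns the resulting recompilation into the RuntimeError the test asserts on.

import torch


def branchy(x: torch.Tensor) -> torch.Tensor:
    # The comparison below makes Dynamo guard on x.shape[0] (roughly "s0 >= 10").
    if x.shape[0] >= 10:
        return x * 10
    return x * 10


compiled = torch.compile(branchy, dynamic=True)
compiled(torch.randn(20, 4))  # first call compiles and records the shape guard

with torch.compiler.set_stance("fail_on_recompile"):
    try:
        compiled(torch.randn(5, 4))  # violates the guard, so a recompile is needed
    except RuntimeError as err:
        print("recompile detected:", err)
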
27 changes: 24 additions & 3 deletions vllm/compilation/backends.py
@@ -26,6 +26,7 @@
should_split,
)
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.compilation import DynamicShapesType
from vllm.config.utils import hash_factors
from vllm.logger import init_logger
from vllm.logging_utils import lazy
@@ -752,6 +753,29 @@ def __call__(
self.split_gm, submod_names_to_compile, self.vllm_config, self
).run(*example_inputs)

from torch._guards import detect_fake_mode

fake_mode = detect_fake_mode()

Collaborator: Why do we only do this with evaluate_guards on?
Collaborator: Oh is it because otherwise we drop all guards anyway?
Contributor Author: yes
Contributor Author: also for backed size oblivious we do not want to drop 0/1 guards

        if (
self.compilation_config.dynamic_shapes_config.evaluate_guards
and self.compilation_config.dynamic_shapes_config.type
== DynamicShapesType.BACKED
):
from torch.utils._sympy.value_ranges import ValueRanges

            # Relax the 0/1-specialization guards: for backed dynamic shapes,
            # torch.compile either specializes for 0/1 inputs or guards that the
            # shape is >= 2, because it is very hard not to hit a check against
            # 0/1. When we evaluate shape guards, we exclude those guards (we
            # would always fail otherwise) by widening the range of any backed
            # size whose recorded minimum is 2 down to 0.
for s, r in fake_mode.shape_env.var_to_range.items():
if r.lower == 2:
fake_mode.shape_env.var_to_range[s] = ValueRanges(0, r.upper)

graph_path = os.path.join(local_cache_dir, "computation_graph.py")
if not os.path.exists(graph_path):
# code adapted from
@@ -780,9 +804,6 @@ def __call__(
)

# if we need to copy input buffers for cudagraph
from torch._guards import detect_fake_mode

fake_mode = detect_fake_mode()
fake_args = [
fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
for t in example_inputs
4 changes: 2 additions & 2 deletions vllm/compilation/decorators.py
@@ -392,7 +392,6 @@ def __call__(self, *args, **kwargs):

factors.append(_model_hash_key(self.forward))
hash_key = hashlib.sha256(str(factors).encode()).hexdigest()

cache_dir = os.path.join(
envs.VLLM_CACHE_ROOT,
"torch_aot_compile",
Expand All @@ -411,7 +410,8 @@ def __call__(self, *args, **kwargs):
start_monitoring_torch_compile(self.vllm_config)
loaded_fn = torch.compiler.load_compiled_function(f)
_verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
loaded_fn.disable_guard_check()
if not self.compilation_config.dynamic_shapes_config.evaluate_guards:
loaded_fn.disable_guard_check()
self.aot_compiled_fn = loaded_fn
except Exception as e:
if os.path.exists(aot_compilation_path):
64 changes: 51 additions & 13 deletions vllm/compilation/wrapper.py
@@ -4,7 +4,7 @@
import os
import sys
from abc import abstractmethod
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from types import CodeType
from typing import Any

@@ -13,6 +13,7 @@

import vllm.envs as envs
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
from vllm.config.compilation import DynamicShapesType
from vllm.logger import init_logger

logger = init_logger(__name__)
@@ -98,6 +99,7 @@ def __init__(self):
vllm_config = get_current_vllm_config()
self.vllm_config = vllm_config
mode = vllm_config.compilation_config.mode

if mode is None:
raise RuntimeError("Compilation mode cannot be NO_COMPILATION")

@@ -107,23 +109,53 @@ def __init__(self):
if isinstance(backend, str) and backend == "inductor":
options = vllm_config.compilation_config.inductor_compile_config

self.first_compile = True

ds_type = vllm_config.compilation_config.dynamic_shapes_config.type

if mode != CompilationMode.STOCK_TORCH_COMPILE:
# Drop all the guards.
options["guard_filter_fn"] = lambda x: [False for _ in x]
if vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards:
assert not envs.VLLM_USE_BYTECODE_HOOK, (
"compilation_config.dynamic_shapes_config.evaluate_guards "
"requires VLLM_USE_BYTECODE_HOOK=0. "
)

# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
from vllm.compilation.decorators import DynamicShapesType
if envs.VLLM_USE_AOT_COMPILE:
# disabled until https://github.com/pytorch/pytorch/pull/169239
# is picked up.
assert ds_type != DynamicShapesType.BACKED, (
"evaluate_guards for backed shapes requires "
"VLLM_USE_AOT_COMPILE=False. "
)

assert not envs.VLLM_USE_BYTECODE_HOOK, (
"compilation_config.dynamic_shapes_config.evaluate_guards "
"requires VLLM_USE_BYTECODE_HOOK=0. "
)

options["guard_filter_fn"] = lambda x: [
entry.guard_type == "SHAPE_ENV" for entry in x
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems brittle..
can we add a test in pytorch for this or something? "if you change how the dynamic shape guards look then you need to change vLLM"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeh we shall if we dont have any. I will file issue and assign to me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually vllm testd would fail I also found a an internal test that checks it def test_symbool_guards(

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sgtm

]
else:
options["guard_filter_fn"] = lambda x: [False for _ in x]

ds_type = vllm_config.compilation_config.dynamic_shapes_config.type
compiled_ptr: Any = self.forward
# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False

if ds_type == DynamicShapesType.UNBACKED:
if envs.VLLM_USE_BYTECODE_HOOK:
# reason is that bytecode does this hack torch._dynamo.eval_frame.
# remove_from_cache(self.original_code_object()) to force a new
# re-compilation.
raise ValueError(
"UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. "
)
# reason is that bytecode does torch._dynamo.eval_frame.
# remove_from_cache(self.original_code_object()) to force a new
# re-compilation. And if we use
# compiled_ptr = self.check_invariants_and_forward
# it will reset all entries.
            assert not envs.VLLM_USE_BYTECODE_HOOK, (
                "UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. "
            )
assert (
not vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards
), "UNBACKED dynamic shapes do not add guards"

compiled_ptr = self.check_invariants_and_forward

if envs.VLLM_USE_AOT_COMPILE:
@@ -173,7 +205,13 @@ def __call__(self, *args, **kwargs):
with self._dispatch_to_compiled_code():
return self.forward(*args, **kwargs)
else:
with _compilation_context():
ctx = (
nullcontext()
if self.first_compile
else torch.compiler.set_stance("fail_on_recompile")
)
self.first_compile = False
with _compilation_context(), ctx:
return self._compiled_callable(*args, **kwargs)

@abstractmethod
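For reference, a minimal standalone sketch of the guard filter used above, assuming the torch.compile options={"guard_filter_fn": ...} hook and the guard_type field behave as this diff relies on (the filter returns one boolean per guard entry, and False drops that guard):

import torch


class Tiny(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1


def keep_only_shape_guards(entries):
    # Keep dynamic-shape (SHAPE_ENV) guards and drop everything else.
    return [entry.guard_type == "SHAPE_ENV" for entry in entries]


compiled = torch.compile(
    Tiny(),
    dynamic=True,
    options={"guard_filter_fn": keep_only_shape_guards},
)
compiled(torch.randn(8, 4))
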
17 changes: 14 additions & 3 deletions vllm/config/compilation.py
@@ -249,7 +249,18 @@ class DynamicShapesConfig:
backed/unbacked.
"""

# TODO add a debug mode to fail
Collaborator: I think I want to get into a state where this is always on during vLLM warmup, and then optionally during runtime. This involves making sure the guards are actually correct (maybe this requires unbacked). Not for this PR, but we should do this in a follow-up.
Contributor Author: using unbacked basically just avoids the need for this, once we address the perf gap

    evaluate_guards: bool = False
"""
A debug mode to detect and fail if Dynamo ever specializes a dynamic shape by
guarding on it. When True, dynamic shape guards are not dropped from dynamo.
And a failure will be triggered if a recompilation ever happens due to that.
This mode requires VLLM_USE_BYTECODE_HOOK to be 0.
Enabling this allow observing the dynamic shapes guards in the tlparse
artifacts also.
When type is backed, aot_compile must be disabled for this mode to work.
until this change picked up https://github.com/pytorch/pytorch/pull/169239.

"""

def compute_hash(self) -> str:
"""
@@ -358,8 +369,8 @@ class CompilationConfig:
We use string to avoid serialization issues when using compilation in a
distributed setting. When the compilation mode is 1 or 2, the backend is
used for the compilation directly (it sees the whole graph). When the
compilation mode is 3, the backend supports both whole graph and piecewise
compilation, available backends include eager, inductor, and custom backends,
the latter of which can be defined via `get_compile_backend`. Furthermore,
compilation is only piecewise if splitting ops is set accordingly and
use_inductor_graph_partition is off. Note that the default options for
2 changes: 1 addition & 1 deletion vllm/config/vllm.py
@@ -65,7 +65,7 @@ class OptimizationLevel(IntEnum):
"""O0 : No optimization. no compilation, no cudagraphs, no other
optimization, just starting up immediately"""
O1 = 1
"""O1: Quick optimizations. Dynamo+Inductor compilation and Piecewise
"""O1: Quick optimizations. Dynamo+Inductor compilation and Piecewise
cudagraphs"""
O2 = 2
"""O2: Full optimizations. -O1 as well as Full and Piecewise cudagraphs."""