
Commit b2b8e71

add evaluate_guards option to DynamicShapesConfig
Signed-off-by: Laith Sakka <lsakka@meta.com>
1 parent de75b0b commit b2b8e71

File tree: 4 files changed (+173, -19 lines)


tests/compile/test_dynamic_shapes_compilation.py

Lines changed: 114 additions & 7 deletions
@@ -2,12 +2,20 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import gc
+from contextlib import contextmanager
 
 import pytest
 import torch
 
 from vllm import LLM, SamplingParams
-from vllm.config.compilation import CompilationMode, DynamicShapesType
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
+from vllm.config.compilation import (
+    CompilationMode,
+    DynamicShapesConfig,
+    DynamicShapesType,
+)
+from vllm.forward_context import set_forward_context
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 
@@ -29,18 +37,19 @@ def get_test_models():
 )
 @pytest.mark.parametrize("use_aot_compile", ["0"])
 @pytest.mark.parametrize("use_bytecode_hook", [True, False])
+@pytest.mark.parametrize("evaluate_guards", [False, True])
 @pytest.mark.skipif(
     not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
 )
 def test_dynamic_shapes_compilation(
-    monkeypatch, model_name, shapes_type, use_aot_compile, use_bytecode_hook
+    monkeypatch,
+    model_name,
+    shapes_type,
+    use_aot_compile,
+    use_bytecode_hook,
+    evaluate_guards,
 ):
     """Test that all dynamic shapes types compile successfully"""
-    print(
-        f"\nTesting model: {model_name} with {shapes_type.name}, "
-        f"AOT compile: {use_aot_compile}, "
-        f"Bytecode hook: {use_bytecode_hook}"
-    )
     if use_bytecode_hook and shapes_type == DynamicShapesType.UNBACKED:
         pytest.skip("UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0")
 
@@ -58,6 +67,7 @@ def test_dynamic_shapes_compilation(
             "mode": CompilationMode.VLLM_COMPILE,
             "dynamic_shapes_config": {
                 "type": shapes_type.value,
+                "evaluate_guards": evaluate_guards,
             },
         },
     )
@@ -86,3 +96,100 @@ def test_dynamic_shapes_compilation(
     torch.cuda.empty_cache()
     torch.cuda.synchronize()
     print("GPU memory cleared")
+
+
+@pytest.mark.parametrize("use_aot_compile", ["0", "1"])
+@pytest.mark.parametrize(
+    "dynamic_shapes_type",
+    [
+        DynamicShapesType.BACKED,
+        DynamicShapesType.BACKED_SIZE_OBLIVIOUS,
+        DynamicShapesType.UNBACKED,
+    ],
+)
+@pytest.mark.parametrize("evaluate_guards", [False, True])
+def test_model_specialization_with_evaluate_guards(
+    monkeypatch, use_aot_compile, dynamic_shapes_type, evaluate_guards
+):
+    """Test that evaluate_guards correctly detects shape specialization
+    violations.
+    """
+    if use_aot_compile and dynamic_shapes_type == DynamicShapesType.UNBACKED:
+        pytest.skip("UNBACKED dynamic shapes require use_aot_compile=0")
+
+    @support_torch_compile
+    class ModelWithSizeCheck(torch.nn.Module):
+        def __init__(self, **kwargs):
+            super().__init__()
+            self.linear = torch.nn.Linear(10, 10)
+
+        def forward(self, x: torch.Tensor):
+            x = self.linear(x)
+            # This will cause specialization - torch.compile will guard on x.shape[0]
+            if x.shape[0] >= 10:
+                return x
+            else:
+                return x
+
+    @contextmanager
+    def use_vllm_config(vllm_config: VllmConfig):
+        with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config):
+            yield
+
+    monkeypatch.setenv("TOKENIZERS_PARALLELISM", "true")
+    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", use_aot_compile)
+    monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "0")
+
+    # Create vllm config with the desired settings
+    from vllm.config import CompilationMode
+
+    vllm_config = VllmConfig(
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+            dynamic_shapes_config=DynamicShapesConfig(
+                type=dynamic_shapes_type,
+                evaluate_guards=evaluate_guards,
+            ),
+        )
+    )
+
+    def test(model_class, input1, input2, is_01_specialization=False):
+        with torch.no_grad(), use_vllm_config(vllm_config):
+            model = model_class(vllm_config=vllm_config).cuda()
+
+            model(input1)
+
+            if evaluate_guards and not is_01_specialization:
+                # This should fail because guards were added.
+                try:
+                    model(input2)
+                    raise RuntimeError("expected guard violation to occur")
+                except RuntimeError as e:
+                    # Expected failure - guard was violated
+                    error_msg = str(e)
+                    if "guard" in error_msg.lower() or "recompile" in error_msg.lower():
+                        pass
+                    else:
+                        raise e
+
+            else:
+                model(input2)
+
+    test(ModelWithSizeCheck, torch.randn(20, 10).cuda(), torch.randn(5, 10).cuda())
+    test(ModelWithSizeCheck, torch.randn(5, 10).cuda(), torch.randn(20, 10).cuda())
+
+    @support_torch_compile
+    class ModelWithOneSizeCheck(torch.nn.Module):
+        def __init__(self, **kwargs):
+            super().__init__()
+            self.linear = torch.nn.Linear(10, 10)
+
+        def forward(self, x: torch.Tensor):
+            x = self.linear(x)
+            # This will cause 0/1 specializations.
+            if x.shape[0] >= 2:
+                return x
+            else:
+                return x
+
+    test(ModelWithOneSizeCheck, torch.randn(20, 10).cuda(), torch.randn(1, 10).cuda())
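Not part of the commit: a minimal usage sketch of how the new flag could be turned on from user code, mirroring the dict-based compilation_config exercised by the first test above. The model name is a placeholder and the exact LLM(...) keyword plumbing is an assumption rather than something shown in this diff.

from vllm import LLM
from vllm.config.compilation import CompilationMode, DynamicShapesType

# Hypothetical sketch: enable evaluate_guards so that a Dynamo specialization
# of a dynamic shape fails loudly instead of silently recompiling.
llm = LLM(
    model="facebook/opt-125m",  # placeholder model
    compilation_config={
        "mode": CompilationMode.VLLM_COMPILE,
        "dynamic_shapes_config": {
            "type": DynamicShapesType.BACKED.value,
            "evaluate_guards": True,
        },
    },
)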

vllm/compilation/backends.py

Lines changed: 23 additions & 3 deletions
@@ -18,6 +18,7 @@
 import torch
 import torch.fx as fx
 from torch._dispatch.python import enable_python_dispatcher
+from torch.utils._sympy.value_ranges import ValueRanges
 
 import vllm.envs as envs
 from vllm.compilation.inductor_pass import pass_context
@@ -26,6 +27,7 @@
     should_split,
 )
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
+from vllm.config.compilation import DynamicShapesType
 from vllm.config.utils import hash_factors
 from vllm.logger import init_logger
 from vllm.logging_utils import lazy
@@ -752,6 +754,27 @@ def __call__(
             self.split_gm, submod_names_to_compile, self.vllm_config, self
         ).run(*example_inputs)
 
+        from torch._guards import detect_fake_mode
+
+        fake_mode = detect_fake_mode()
+
+        if (
+            self.compilation_config.dynamic_shapes_config.evaluate_guards
+            and self.compilation_config.dynamic_shapes_config.type
+            == DynamicShapesType.BACKED
+        ):
+            # Drop the 0/1-specialization guards: for backed dynamic shapes,
+            # torch.compile either specializes for 0/1 inputs or guards that the
+            # shape is >= 2, because it is very hard not to hit a check against
+            # 0/1. When we evaluate shape guards, we exclude those guards
+            # (we would always fail otherwise).
+
+            # We do that by widening the ranges of backed sizes: whenever the
+            # recorded minimum is 2, we assume the size can be 0.
+            for s, r in fake_mode.shape_env.var_to_range.items():
+                if r.lower == 2:
+                    fake_mode.shape_env.var_to_range[s] = ValueRanges(0, r.upper)
+
         graph_path = os.path.join(local_cache_dir, "computation_graph.py")
         if not os.path.exists(graph_path):
             # code adapted from
@@ -780,9 +803,6 @@ def __call__(
         )
 
         # if we need to copy input buffers for cudagraph
-        from torch._guards import detect_fake_mode
-
-        fake_mode = detect_fake_mode()
         fake_args = [
             fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in example_inputs
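Not part of the commit: a toy sketch of what the range-widening loop above does, run on a hand-built var_to_range mapping instead of a real ShapeEnv. The symbol names and bounds are invented for illustration.

import sympy
from torch.utils._sympy.value_ranges import ValueRanges

# Invented example ranges: s0 carries the implicit backed-size lower bound of 2
# that comes from 0/1 specialization; s1 does not.
var_to_range = {
    sympy.Symbol("s0", integer=True): ValueRanges(2, 8192),
    sympy.Symbol("s1", integer=True): ValueRanges(4, 4096),
}

for s, r in var_to_range.items():
    if r.lower == 2:
        # Widen the lower bound to 0 so the implicit "size >= 2" guard is not
        # reported as a violation when shape guards are evaluated.
        var_to_range[s] = ValueRanges(0, r.upper)

# s0's range is now ValueRanges(0, 8192); s1 is untouched.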

vllm/compilation/wrapper.py

Lines changed: 27 additions & 6 deletions
@@ -4,7 +4,7 @@
 import os
 import sys
 from abc import abstractmethod
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from types import CodeType
 from typing import Any
 
@@ -98,6 +98,7 @@ def __init__(self):
         vllm_config = get_current_vllm_config()
         self.vllm_config = vllm_config
         mode = vllm_config.compilation_config.mode
+
         if mode is None:
             raise RuntimeError("Compilation mode cannot be NO_COMPILATION")
 
@@ -106,10 +107,20 @@ def __init__(self):
 
         if isinstance(backend, str) and backend == "inductor":
             options = vllm_config.compilation_config.inductor_compile_config
-
+        self.first_compile = True
         if mode != CompilationMode.STOCK_TORCH_COMPILE:
             # Drop all the guards.
-            options["guard_filter_fn"] = lambda x: [False for _ in x]
+            if vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards:
+                assert not envs.VLLM_USE_BYTECODE_HOOK, (
+                    "compilation_config.dynamic_shapes_config.evaluate_guards "
+                    "requires VLLM_USE_BYTECODE_HOOK=0. "
+                )
+
+                options["guard_filter_fn"] = lambda x: [
+                    entry.guard_type == "SHAPE_ENV" for entry in x
+                ]
+            else:
+                options["guard_filter_fn"] = lambda x: [False for _ in x]
 
         # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
         from vllm.compilation.decorators import DynamicShapesType
@@ -122,7 +133,7 @@ def __init__(self):
                 # remove_from_cache(self.original_code_object()) to force a new
                 # re-compilation.
                 raise ValueError(
-                    "UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. "
+                    "UNBACKED dynamic shapes requires VLLM_USE_BYTECODE_HOOK=0. "
                 )
             compiled_ptr = self.check_invariants_and_forward
@@ -154,7 +165,11 @@ def aot_compile(self, *args, **kwargs):
                 + "Please make sure torch.compile is enabled with the latest "
                 + f"version of PyTorch (current using torch: {torch.__version__})"
             )
-        return self._compiled_callable.aot_compile((args, kwargs))
+        prev = self.first_compile
+        self.first_compile = False
+        ctx = nullcontext() if prev else torch.compiler.set_stance("fail_on_recompile")
+        with ctx:
+            return self._compiled_callable.aot_compile((args, kwargs))
 
     def __call__(self, *args, **kwargs):
         if envs.VLLM_USE_BYTECODE_HOOK:
@@ -173,7 +188,13 @@ def __call__(self, *args, **kwargs):
             with self._dispatch_to_compiled_code():
                 return self.forward(*args, **kwargs)
         else:
-            with _compilation_context():
+            ctx = (
+                nullcontext()
+                if self.first_compile
+                else torch.compiler.set_stance("fail_on_recompile")
+            )
+            self.first_compile = False
+            with _compilation_context(), ctx:
                 return self._compiled_callable(*args, **kwargs)
 
     @abstractmethod
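Not part of the commit: a standalone sketch of the "first compile is allowed, any later recompilation fails" pattern that wrapper.py now wraps around its compiled callable. The toy function and input shapes are made up; torch.compiler.set_stance requires a recent PyTorch release.

from contextlib import nullcontext

import torch


@torch.compile(dynamic=True)
def toy_forward(x: torch.Tensor) -> torch.Tensor:
    return x * 2


first_compile = True
for batch in (torch.randn(8, 4), torch.randn(16, 4)):
    # After the first call, "fail_on_recompile" turns any further Dynamo
    # recompilation (for example, one caused by a shape guard that survived
    # guard filtering) into a hard error instead of a silent specialization.
    ctx = (
        nullcontext()
        if first_compile
        else torch.compiler.set_stance("fail_on_recompile")
    )
    first_compile = False
    with ctx:
        toy_forward(batch)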

vllm/config/compilation.py

Lines changed: 9 additions & 3 deletions
@@ -227,7 +227,13 @@ class DynamicShapesConfig:
     backed/unbacked.
     """
 
-    # TODO add a debug mode to fail
+    evaluate_guards: bool = False
+    """
+    A debug mode to detect and fail if Dynamo ever specializes a dynamic shape by
+    guarding on it. When True, dynamic shape guards are not dropped from Dynamo,
+    and a failure is triggered if a recompilation ever happens because of one.
+    Enabling this allows observing the dynamic shape guards in the tlparse artifact.
+    """
 
     def compute_hash(self) -> str:
         """
@@ -330,8 +336,8 @@ class CompilationConfig:
     We use string to avoid serialization issues when using compilation in a
     distributed setting. When the compilation mode is 1 or 2, the backend is
     used for the compilation directly (it sees the whole graph). When the
-    compilation mode is 3, the backend supports both whole graph and piecewise
-    compilation, available backends include eager, inductor, and custom backends,
+    compilation mode is 3, the backend supports both whole graph and piecewise
+    compilation, available backends include eager, inductor, and custom backends,
     the latter of which can be defined via `get_compile_backend`. Furthermore,
     compilation is only piecewise if splitting ops is set accordingly and
     use_inductor_graph_partition is off. Note that the default options for
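Not part of the commit: for reference next to the new docstring, the dataclass-level construction used by the new unit test, which is the other way to set this option besides the dict form shown earlier.

from vllm.config import CompilationConfig, VllmConfig
from vllm.config.compilation import (
    CompilationMode,
    DynamicShapesConfig,
    DynamicShapesType,
)

# Mirrors the construction in the new test: evaluate_guards=True keeps dynamic
# shape guards so that a guard-triggered recompile fails instead of passing.
vllm_config = VllmConfig(
    compilation_config=CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        dynamic_shapes_config=DynamicShapesConfig(
            type=DynamicShapesType.BACKED,
            evaluate_guards=True,
        ),
    )
)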
