
Commit 2c2eb08

add evaluate_guards option to DynamicShapesConfig
Signed-off-by: Laith Sakka <lsakka@meta.com>
1 parent de75b0b commit 2c2eb08

File tree

6 files changed, +230 -30 lines changed


tests/compile/test_dynamic_shapes_compilation.py

Lines changed: 133 additions & 7 deletions
@@ -2,12 +2,21 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import gc
+import tempfile
+from contextlib import contextmanager
 
 import pytest
 import torch
 
 from vllm import LLM, SamplingParams
-from vllm.config.compilation import CompilationMode, DynamicShapesType
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
+from vllm.config.compilation import (
+    CompilationMode,
+    DynamicShapesConfig,
+    DynamicShapesType,
+)
+from vllm.forward_context import set_forward_context
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 
@@ -29,18 +38,19 @@ def get_test_models():
 )
 @pytest.mark.parametrize("use_aot_compile", ["0"])
 @pytest.mark.parametrize("use_bytecode_hook", [True, False])
+@pytest.mark.parametrize("evaluate_guards", [False, True])
 @pytest.mark.skipif(
     not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
 )
 def test_dynamic_shapes_compilation(
-    monkeypatch, model_name, shapes_type, use_aot_compile, use_bytecode_hook
+    monkeypatch,
+    model_name,
+    shapes_type,
+    use_aot_compile,
+    use_bytecode_hook,
+    evaluate_guards,
 ):
     """Test that all dynamic shapes types compile successfully"""
-    print(
-        f"\nTesting model: {model_name} with {shapes_type.name}, "
-        f"AOT compile: {use_aot_compile}, "
-        f"Bytecode hook: {use_bytecode_hook}"
-    )
     if use_bytecode_hook and shapes_type == DynamicShapesType.UNBACKED:
         pytest.skip("UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0")
 
@@ -58,6 +68,7 @@ def test_dynamic_shapes_compilation(
             "mode": CompilationMode.VLLM_COMPILE,
             "dynamic_shapes_config": {
                 "type": shapes_type.value,
+                "evaluate_guards": evaluate_guards,
             },
         },
     )
@@ -86,3 +97,118 @@ def test_dynamic_shapes_compilation(
     torch.cuda.empty_cache()
     torch.cuda.synchronize()
     print("GPU memory cleared")
+
+
+@pytest.mark.parametrize("use_aot_compile", ["0", "1"])
+@pytest.mark.parametrize(
+    "dynamic_shapes_type",
+    [
+        DynamicShapesType.BACKED,
+        DynamicShapesType.BACKED_SIZE_OBLIVIOUS,
+    ],
+)
+@pytest.mark.parametrize("evaluate_guards", [False, True])
+def test_model_specialization_with_evaluate_guards(
+    monkeypatch, use_aot_compile, dynamic_shapes_type, evaluate_guards
+):
+    """Test that evaluate_guards correctly detects shape specialization
+    violations.
+    """
+
+    if (
+        use_aot_compile == "1"
+        and dynamic_shapes_type == DynamicShapesType.BACKED
+        and evaluate_guards
+    ):
+        pytest.skip("evaluate_guards for backed does not work with aot_compile=1")
+
+    @support_torch_compile
+    class ModelWithSizeCheck(torch.nn.Module):
+        def __init__(self, **kwargs):
+            super().__init__()
+
+        def forward(self, x: torch.Tensor):
+            # This will cause specialization - torch.compile will guard on
+            # x.shape[0]
+            if x.shape[0] >= 10:
+                return x * 10
+            else:
+                return x * 10
+
+    @support_torch_compile
+    class ModelWithOneSizeCheck(torch.nn.Module):
+        def __init__(self, **kwargs):
+            super().__init__()
+
+        def forward(self, x: torch.Tensor):
+            # This will cause 0/1 specializations.
+            if x.shape[0] == 0:
+                return x * 10
+            if x.shape[0] == 1:
+                return x * 10
+            else:
+                return x * 10
+
+    @contextmanager
+    def use_vllm_config(vllm_config: VllmConfig):
+        with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config):
+            yield
+
+    monkeypatch.setenv("TOKENIZERS_PARALLELISM", "true")
+    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", use_aot_compile)
+    monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "0")
+
+    # Create vllm config with the desired settings
+    from vllm.config import CompilationMode
+
+    vllm_config = VllmConfig(
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+            dynamic_shapes_config=DynamicShapesConfig(
+                type=dynamic_shapes_type,
+                evaluate_guards=evaluate_guards,
+            ),
+        )
+    )
+
+    def test(model_class, input1, input2, is_01_specialization=False):
+        with (
+            torch.no_grad(),
+            use_vllm_config(vllm_config),
+            tempfile.TemporaryDirectory() as tmpdirname,
+        ):
+            monkeypatch.setenv("VLLM_CACHE_ROOT", tmpdirname)
+
+            model = model_class(vllm_config=vllm_config).cuda()
+
+            model(input1)
+
+            if evaluate_guards and (
+                not (
+                    is_01_specialization
+                    and dynamic_shapes_type == DynamicShapesType.BACKED
+                )
+            ):
+                # This should fail because guards were added.
+                try:
+                    model(input2)
+                    raise RuntimeError("expected guard violation to occur")
+                except RuntimeError as e:
+                    # Expected failure - guard was violated
+                    error_msg = str(e)
+                    assert (
+                        "GuardManager check failed" in error_msg
+                        or "Detected recompile when torch.compile stance" in error_msg
+                    ), error_msg
+
+            else:
+                model(input2)
+
+    # test(ModelWithSizeCheck, torch.randn(20, 10).cuda(), torch.randn(5, 10).cuda())
+    # test(ModelWithSizeCheck, torch.randn(5, 10).cuda(), torch.randn(20, 10).cuda())
+    test(
+        ModelWithOneSizeCheck,
+        torch.randn(20, 10).cuda(),
+        torch.randn(1, 10).cuda(),
+        is_01_specialization=True,
+    )
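
Note (editorial, not part of the diff): the flag exercised by the parametrized test above can also be set from the public entry point by passing the same nested compilation_config dict. A minimal, hedged sketch; the model name is illustrative and it assumes the dict form of compilation_config accepted by LLM:

from vllm import LLM
from vllm.config.compilation import CompilationMode, DynamicShapesType

llm = LLM(
    model="facebook/opt-125m",  # illustrative model, not taken from the diff
    compilation_config={
        "mode": CompilationMode.VLLM_COMPILE,
        "dynamic_shapes_config": {
            "type": DynamicShapesType.BACKED.value,
            # New in this commit: fail loudly if a dynamic shape gets specialized.
            # Per wrapper.py below, this requires VLLM_USE_BYTECODE_HOOK=0.
            "evaluate_guards": True,
        },
    },
)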

vllm/compilation/backends.py

Lines changed: 23 additions & 3 deletions
@@ -18,6 +18,7 @@
 import torch
 import torch.fx as fx
 from torch._dispatch.python import enable_python_dispatcher
+from torch.utils._sympy.value_ranges import ValueRanges
 
 import vllm.envs as envs
 from vllm.compilation.inductor_pass import pass_context
@@ -26,6 +27,7 @@
     should_split,
 )
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
+from vllm.config.compilation import DynamicShapesType
 from vllm.config.utils import hash_factors
 from vllm.logger import init_logger
 from vllm.logging_utils import lazy
@@ -752,6 +754,27 @@ def __call__(
             self.split_gm, submod_names_to_compile, self.vllm_config, self
         ).run(*example_inputs)
 
+        from torch._guards import detect_fake_mode
+
+        fake_mode = detect_fake_mode()
+
+        if (
+            self.compilation_config.dynamic_shapes_config.evaluate_guards
+            and self.compilation_config.dynamic_shapes_config.type
+            == DynamicShapesType.BACKED
+        ):
+            # Counteract 0/1 specialization guards: for backed dynamic shapes,
+            # torch.compile either specializes for 0/1 inputs or guards that the
+            # shape is >= 2, because it is really hard not to hit a check
+            # against 0/1. When we evaluate shape guards, we exclude those
+            # guards (we would always fail otherwise).
+
+            # We do that by widening the range of any backed size whose
+            # recorded minimum is 2 down to 0.
+            for s, r in fake_mode.shape_env.var_to_range.items():
+                if r.lower == 2:
+                    fake_mode.shape_env.var_to_range[s] = ValueRanges(0, r.upper)
+
         graph_path = os.path.join(local_cache_dir, "computation_graph.py")
         if not os.path.exists(graph_path):
             # code adapted from
@@ -780,9 +803,6 @@ def __call__(
         )
 
         # if we need to copy input buffers for cudagraph
-        from torch._guards import detect_fake_mode
-
-        fake_mode = detect_fake_mode()
         fake_args = [
             fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
             for t in example_inputs
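
For context on the range widening above, here is a standalone sketch (plain torch.compile, not vLLM code) of the 0/1 specialization it compensates for: with backed dynamic shapes, a dynamic dimension is typically guarded as >= 2, so sizes 0 and 1 fall outside the compiled range and force a recompile, which is why those guards have to be excluded or widened when evaluate_guards is on:

import torch

@torch.compile(dynamic=True)
def f(x):
    return x + 1

f(torch.randn(8))  # compiles with a backed symbol for dim 0, roughly guarded as s0 >= 2
f(torch.randn(1))  # size 1 misses that range and triggers a 0/1 specialization (recompile)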

vllm/compilation/decorators.py

Lines changed: 2 additions & 2 deletions
@@ -392,7 +392,6 @@ def __call__(self, *args, **kwargs):
 
         factors.append(_model_hash_key(self.forward))
         hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
-
         cache_dir = os.path.join(
             envs.VLLM_CACHE_ROOT,
             "torch_aot_compile",
@@ -411,7 +410,8 @@ def __call__(self, *args, **kwargs):
                 start_monitoring_torch_compile(self.vllm_config)
                 loaded_fn = torch.compiler.load_compiled_function(f)
                 _verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
-                loaded_fn.disable_guard_check()
+                if not self.compilation_config.dynamic_shapes_config.evaluate_guards:
+                    loaded_fn.disable_guard_check()
                 self.aot_compiled_fn = loaded_fn
             except Exception as e:
                 if os.path.exists(aot_compilation_path):

vllm/compilation/wrapper.py

Lines changed: 58 additions & 14 deletions
@@ -4,14 +4,15 @@
 import os
 import sys
 from abc import abstractmethod
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from types import CodeType
 from typing import Any
 
 import torch
 import torch._C._dynamo.guards
 
 import vllm.envs as envs
+from vllm.compilation.decorators import DynamicShapesType
 from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
 from vllm.logger import init_logger
 
@@ -98,6 +99,7 @@ def __init__(self):
         vllm_config = get_current_vllm_config()
         self.vllm_config = vllm_config
         mode = vllm_config.compilation_config.mode
+
         if mode is None:
             raise RuntimeError("Compilation mode cannot be NO_COMPILATION")
 
@@ -107,23 +109,55 @@ def __init__(self):
         if isinstance(backend, str) and backend == "inductor":
             options = vllm_config.compilation_config.inductor_compile_config
 
+        self.first_compile = True
+
+        ds_type = vllm_config.compilation_config.dynamic_shapes_config.type
+
         if mode != CompilationMode.STOCK_TORCH_COMPILE:
             # Drop all the guards.
-            options["guard_filter_fn"] = lambda x: [False for _ in x]
+            if vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards:
+                assert not envs.VLLM_USE_BYTECODE_HOOK, (
+                    "compilation_config.dynamic_shapes_config.evaluate_guards "
+                    "requires VLLM_USE_BYTECODE_HOOK=0. "
+                )
 
-        # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
-        from vllm.compilation.decorators import DynamicShapesType
+                # aot_compile installs dynamic shape guards before the vllm
+                # backend, so we cannot avoid the 0/1 specialization checks.
+                # aot_compile is fine with BACKED_SIZE_OBLIVIOUS dynamic shapes,
+                # since those do not specialize on 0/1.
+                if envs.VLLM_USE_AOT_COMPILE:
+                    assert ds_type != DynamicShapesType.BACKED, (
+                        "evaluate_guards for backed shapes requires "
+                        "VLLM_USE_AOT_COMPILE=False. "
+                    )
+
+                assert not envs.VLLM_USE_BYTECODE_HOOK, (
+                    "compilation_config.dynamic_shapes_config.evaluate_guards "
+                    "requires VLLM_USE_BYTECODE_HOOK=0. "
+                )
+
+                options["guard_filter_fn"] = lambda x: [
+                    entry.guard_type == "SHAPE_ENV" for entry in x
+                ]
+            else:
+                options["guard_filter_fn"] = lambda x: [False for _ in x]
 
-        ds_type = vllm_config.compilation_config.dynamic_shapes_config.type
         compiled_ptr: Any = self.forward
+        # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
+
         if ds_type == DynamicShapesType.UNBACKED:
-            if envs.VLLM_USE_BYTECODE_HOOK:
-                # reason is that bytecode does this hack torch._dynamo.eval_frame.
-                # remove_from_cache(self.original_code_object()) to force a new
-                # re-compilation.
-                raise ValueError(
-                    "UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. "
-                )
+            # The reason is that the bytecode hook calls torch._dynamo.eval_frame.
+            # remove_from_cache(self.original_code_object()) to force a new
+            # re-compilation. And if we use
+            # compiled_ptr = self.check_invariants_and_forward
+            # it will reset all entries.
+            assert not envs.VLLM_USE_BYTECODE_HOOK, (
+                "UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. "
+            )
+            assert (
+                not vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards
+            ), "UNBACKED dynamic shapes do not add guards"
+
             compiled_ptr = self.check_invariants_and_forward
 
         if envs.VLLM_USE_AOT_COMPILE:
@@ -154,7 +188,11 @@ def aot_compile(self, *args, **kwargs):
             + "Please make sure torch.compile is enabled with the latest "
             + f"version of PyTorch (current using torch: {torch.__version__})"
         )
-        return self._compiled_callable.aot_compile((args, kwargs))
+        prev = self.first_compile
+        self.first_compile = False
+        ctx = nullcontext() if prev else torch.compiler.set_stance("fail_on_recompile")
+        with ctx:
+            return self._compiled_callable.aot_compile((args, kwargs))
 
     def __call__(self, *args, **kwargs):
         if envs.VLLM_USE_BYTECODE_HOOK:
@@ -173,7 +211,13 @@ def __call__(self, *args, **kwargs):
             with self._dispatch_to_compiled_code():
                 return self.forward(*args, **kwargs)
         else:
-            with _compilation_context():
+            ctx = (
+                nullcontext()
+                if self.first_compile
+                else torch.compiler.set_stance("fail_on_recompile")
+            )
+            self.first_compile = False
+            with _compilation_context(), ctx:
                 return self._compiled_callable(*args, **kwargs)
 
     @abstractmethod
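
The first_compile bookkeeping above leans on torch.compiler.set_stance("fail_on_recompile"): after the first compilation, a guard miss raises instead of silently recompiling, which is exactly the error message the new test asserts on. A hedged standalone sketch of that behavior (assumes a torch version where set_stance is available, roughly 2.6+):

import torch

@torch.compile
def f(x):
    # torch.compile guards on the size check below
    if x.shape[0] >= 10:
        return x * 2
    return x * 3

f(torch.randn(20))  # first compile
with torch.compiler.set_stance("fail_on_recompile"):
    f(torch.randn(20))  # fine: the existing guards still pass
    f(torch.randn(5))   # raises: this input would need a recompile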

vllm/config/compilation.py

Lines changed: 13 additions & 3 deletions
@@ -227,7 +227,17 @@ class DynamicShapesConfig:
     backed/unbacked.
     """
 
-    # TODO add a debug mode to fail
+    evaluate_guards: bool = False
+    """
+    A debug mode to detect and fail if Dynamo ever specializes a dynamic shape by
+    guarding on it. When True, dynamic shape guards are not dropped from Dynamo,
+    and a failure is triggered if a recompilation ever happens because of one.
+
+    Enabling this also makes the dynamic shape guards visible in the tlparse
+    artifacts.
+    When type is backed, aot_compile must be disabled for this mode to work.
+    This mode also requires VLLM_USE_BYTECODE_HOOK to be 0.
+    """
 
     def compute_hash(self) -> str:
         """
@@ -330,8 +340,8 @@ class CompilationConfig:
     We use string to avoid serialization issues when using compilation in a
     distributed setting. When the compilation mode is 1 or 2, the backend is
     used for the compilation directly (it sees the whole graph). When the
-    compilation mode is 3, the backend supports both whole graph and piecewise 
-    compilation, available backends include eager, inductor, and custom backends, 
+    compilation mode is 3, the backend supports both whole graph and piecewise
+    compilation, available backends include eager, inductor, and custom backends,
     the latter of which can be defined via `get_compile_backend`. Furthermore,
     compilation is only piecewise if splitting ops is set accordingly and
     use_inductor_graph_partition is off. Note that the default options for
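
A hedged usage sketch of the new field (it simply mirrors the test added in this commit and is not additional committed code): building a VllmConfig that enables the debug mode programmatically.

from vllm.config import CompilationConfig, VllmConfig
from vllm.config.compilation import (
    CompilationMode,
    DynamicShapesConfig,
    DynamicShapesType,
)

vllm_config = VllmConfig(
    compilation_config=CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        dynamic_shapes_config=DynamicShapesConfig(
            type=DynamicShapesType.BACKED_SIZE_OBLIVIOUS,
            # Keep Dynamo's SHAPE_ENV guards and fail on any recompile.
            evaluate_guards=True,
        ),
    )
)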

0 commit comments
