Skip to content

Commit d670ef3

Browse files
committed
add evaluate_guards option to DynamicShapesConfig
Signed-off-by: Laith Sakka <lsakka@meta.com>
1 parent 3f78104 commit d670ef3

File tree

4 files changed

+188
-39
lines changed

4 files changed

+188
-39
lines changed

tests/compile/test_dynamic_shapes_compilation.py

Lines changed: 137 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from torch.torch_version import TorchVersion
99

1010
from vllm import LLM, SamplingParams
11+
from vllm.config import set_current_vllm_config
1112
from vllm.config.compilation import DynamicShapesType
1213

1314

@@ -35,9 +36,10 @@ def get_test_models():
3536

3637

3738
@pytest.mark.parametrize("model_name", get_test_models())
38-
def test_dynamic_shapes_compilation(monkeypatch, model_name):
39+
@pytest.mark.parametrize("evaluate_guards", [False, True])
40+
def test_dynamic_shapes_compilation(monkeypatch, model_name, evaluate_guards):
3941
"""Test that all dynamic shapes types produce compiles"""
40-
print(f"\nTesting model: {model_name}")
42+
print(f"\nTesting model: {model_name} with evaluate_guards={evaluate_guards}")
4143

4244
monkeypatch.setenv("TOKENIZERS_PARALLELISM", "true")
4345
# Note USE_AOT_COMPILE fails https://github.com/vllm-project/vllm/issues/27040.
@@ -76,7 +78,10 @@ def test_dynamic_shapes_compilation(monkeypatch, model_name):
7678
DynamicShapesType.UNBACKED,
7779
DynamicShapesType.BACKED_SIZE_OBLIVIOUS,
7880
]:
79-
print(f"Testing {shapes_type.name} dynamic shapes...")
81+
print(
82+
f"Testing {shapes_type.name} dynamic shapes with "
83+
f"evaluate_guards={evaluate_guards}..."
84+
)
8085

8186
# Initialize the model with specific dynamic shapes configuration
8287
model = LLM(
@@ -85,7 +90,7 @@ def test_dynamic_shapes_compilation(monkeypatch, model_name):
8590
"level": 3, # PIECEWISE compilation
8691
"dynamic_shapes_config": {
8792
"dynamic_shapes_type": shapes_type.value,
88-
"eval_dynamo_ds_guards": False,
93+
"eval_dynamo_ds_guards": evaluate_guards,
8994
},
9095
},
9196
# gpu_memory_utilization=0.2,
@@ -110,36 +115,136 @@ def test_dynamic_shapes_compilation(monkeypatch, model_name):
110115
print(f"{shape_type}: '{result}'")
111116

112117

113-
if __name__ == "__main__":
114-
"""Run the test directly as a Python script"""
115-
import os
116-
117-
print("Running dynamic shapes compilation test...")
118-
119-
# Get test models based on PyTorch version
120-
test_models = get_test_models()
121-
print(f"Testing {len(test_models)} models: {test_models}")
122-
123-
# Create a mock monkeypatch object for environment variables
124-
class MockMonkeypatch:
125-
def setenv(self, key, value):
126-
os.environ[key] = value
127-
128-
monkeypatch = MockMonkeypatch()
118+
@pytest.mark.parametrize("use_aot_compile", ["0", "1"])
@pytest.mark.parametrize(
    "dynamic_shapes_type",
    [
        DynamicShapesType.BACKED,
        DynamicShapesType.BACKED_SIZE_OBLIVIOUS,
    ],
)
@pytest.mark.parametrize("evaluate_guards", [False, True])
def test_model_specialization_with_evaluate_guards(
    monkeypatch, use_aot_compile, dynamic_shapes_type, evaluate_guards
):
    """Test that evaluate_guards correctly detects shape specialization violations.

    A tiny module that branches on ``x.shape[0]`` is compiled and called first
    with a batch of 20 rows and then with a batch of 5 rows.

    * ``evaluate_guards=True``: the second call is expected to raise an error
      whose message mentions a guard or a recompile.
    * ``evaluate_guards=False``: the second call is expected to succeed (an
      error mentioning ``fail_on_recompile`` is tolerated).
    """
    from contextlib import contextmanager

    from vllm.compilation.decorators import support_torch_compile
    from vllm.config import CompilationConfig, CompilationMode, VllmConfig
    from vllm.config.compilation import DynamicShapesConfig
    from vllm.forward_context import set_forward_context

    @support_torch_compile
    class ModelWithSizeCheck(torch.nn.Module):
        def __init__(self, **kwargs):
            super().__init__()
            self.linear = torch.nn.Linear(10, 10)

        def forward(self, x: torch.Tensor):
            x = self.linear(x)
            # Branching on x.shape[0] is what makes torch.compile specialize:
            # it has to guard on the batch dimension to pick a branch.
            if x.shape[0] >= 10:
                return x
            else:
                return x

    @contextmanager
    def use_vllm_config(vllm_config: VllmConfig):
        with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config):
            yield

    monkeypatch.setenv("TOKENIZERS_PARALLELISM", "true")
    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", use_aot_compile)

    # Reset torch dynamo to clear any cached compilation state
    torch._dynamo.reset()

    config_desc = (
        f"AOT={use_aot_compile}, shapes={dynamic_shapes_type.name}, "
        f"eval_guards={evaluate_guards}"
    )
    print(f"\n{'=' * 60}")
    print(f"Testing: {config_desc}")
    print(f"{'=' * 60}")

    # Create vllm config with the desired settings
    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            dynamic_shapes_config=DynamicShapesConfig(
                dynamic_shapes_type=dynamic_shapes_type,
                evaluate_guards=evaluate_guards,
            ),
        )
    )

    assert (
        vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards
        == evaluate_guards
    )
    with torch.no_grad(), use_vllm_config(vllm_config):
        model = ModelWithSizeCheck(vllm_config=vllm_config).cuda()

        # First call with 20 rows (takes the ``>= 10`` branch) - should always work
        input_20 = torch.randn(20, 10).cuda()
        model(input_20)

        # Second call with 5 rows (the other branch) - behavior depends on
        # evaluate_guards
        input_5 = torch.randn(5, 10).cuda()

        if evaluate_guards:
            # With evaluate_guards=True the shape guard is kept, so a batch
            # that lands on the other side of the ``>= 10`` check must fail.
            try:
                model(input_5)
            except Exception as e:
                # Expected failure - guard was violated
                error_msg = str(e)
                if "guard" in error_msg.lower() or "recompile" in error_msg.lower():
                    print(f"✅ {config_desc}: Expected failure due to guard violation")
                    print(f"   Error (truncated): {error_msg[:150]}")
                else:
                    # Unexpected error type
                    print(f"❌ {config_desc}: Unexpected error type")
                    print(f"   Error: {e}")
                    raise
            else:
                # No guard violation occurred: this is a TEST FAILURE. Kept in
                # the ``else`` clause so it can never be swallowed by the
                # ``except`` above.
                pytest.fail(
                    f"{config_desc}: Expected guard violation did "
                    f"not occur! evaluate_guards=True should fail "
                    f"when shape changes from 20 to 5, but the "
                    f"model ran successfully without error."
                )
        else:
            # With evaluate_guards=False, guards are dropped, so this should work
            try:
                output_5 = model(input_5)
            except RuntimeError as e:
                # A ``fail_on_recompile`` error is tolerated in this mode; any
                # other RuntimeError is a genuine failure.
                if "fail_on_recompile" in str(e).lower():
                    print(f"✅ {config_desc}: Recompile occurred (expected behavior)")
                    print("   Recompiles are allowed when evaluate_guards=False")
                else:
                    print(f"❌ {config_desc}: Unexpected failure")
                    print(f"   Error: {e}")
                    raise
            else:
                assert output_5.shape == (
                    5,
                    10,
                ), "Output shape should match input"
                print(f"✅ {config_desc}: Passed without guard violations")
                print("   Second call (size 5): Success")

    # NOTE(review): ``cleanup_gpu_memory`` is defined elsewhere in this module;
    # its original indentation is not recoverable here - confirm it belongs at
    # function level rather than inside the ``with`` block.
    cleanup_gpu_memory()

vllm/compilation/backends.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import torch
1515
import torch.fx as fx
1616
from torch._dispatch.python import enable_python_dispatcher
17+
from torch.utils._sympy.value_ranges import ValueRanges
1718

1819
import vllm.envs as envs
1920
from vllm.compilation.inductor_pass import pass_context
@@ -22,6 +23,7 @@
2223
resolve_defined_ops,
2324
)
2425
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
26+
from vllm.config.compilation import DynamicShapesType
2527
from vllm.logger import init_logger
2628
from vllm.platforms import current_platform
2729
from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -659,6 +661,27 @@ def __call__(
659661
self.split_gm, submod_names_to_compile, self.vllm_config, self
660662
).run(*example_inputs)
661663

664+
from torch._guards import detect_fake_mode
665+
666+
fake_mode = detect_fake_mode()
667+
668+
# Debug aid for ``evaluate_guards`` with BACKED dynamic shapes.
#
# The condition must look at ``dynamic_shapes_config.dynamic_shapes_type``:
# comparing the config object itself to the enum is never true and would
# silently skip the relaxation below.
shapes_config = self.compilation_config.dynamic_shapes_config
if (
    shapes_config.evaluate_guards
    and shapes_config.dynamic_shapes_type == DynamicShapesType.BACKED
):
    # Drop the 0/1-specialization guards. For backed dynamic shapes,
    # torch.compile either specializes on 0/1 inputs or guards that the
    # size is >= 2, because it is very hard not to hit a check against
    # 0/1 while tracing. When shape guards are evaluated those guards must
    # be excluded, otherwise every run would fail.
    #
    # This is done by widening the range of every backed size whose lower
    # bound is 2 so that it starts at 0 instead.
    shape_env = fake_mode.shape_env
    # Iterate over a snapshot so the mapping is not mutated mid-iteration.
    for symbol, value_range in list(shape_env.var_to_range.items()):
        if value_range.lower == 2:
            shape_env.var_to_range[symbol] = ValueRanges(0, value_range.upper)
684+
662685
graph_path = os.path.join(local_cache_dir, "computation_graph.py")
663686
if not os.path.exists(graph_path):
664687
# code adapted from
@@ -685,9 +708,6 @@ def __call__(
685708
)
686709

687710
# if we need to copy input buffers for cudagraph
688-
from torch._guards import detect_fake_mode
689-
690-
fake_mode = detect_fake_mode()
691711
fake_args = [
692712
fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
693713
for t in example_inputs

vllm/compilation/wrapper.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
from abc import abstractmethod
5+
from contextlib import nullcontext
56
from types import CodeType
67

78
import torch
@@ -33,6 +34,7 @@ def __init__(self):
3334

3435
vllm_config = get_current_vllm_config()
3536
mode = vllm_config.compilation_config.mode
37+
3638
if mode is None:
3739
raise RuntimeError("Compilation mode cannot be NO_COMPILATION")
3840

@@ -41,10 +43,15 @@ def __init__(self):
4143

4244
if isinstance(backend, str) and backend == "inductor":
4345
options = vllm_config.compilation_config.inductor_compile_config
44-
46+
self.first_compile = True
4547
if mode != CompilationMode.STOCK_TORCH_COMPILE:
4648
# Drop all the guards.
47-
options["guard_filter_fn"] = lambda x: [False for _ in x]
49+
if vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards:
50+
options["guard_filter_fn"] = lambda x: [
51+
entry.guard_type == "SHAPE_ENV" for entry in x
52+
]
53+
else:
54+
options["guard_filter_fn"] = lambda x: [False for _ in x]
4855

4956
if envs.VLLM_USE_AOT_COMPILE:
5057
if hasattr(torch._dynamo.config, "enable_aot_compile"):
@@ -69,10 +76,18 @@ def aot_compile(self, *args, **kwargs):
6976
+ "Please make sure torch.compile is enabled with the latest "
7077
+ f"version of PyTorch (current using torch: {torch.__version__})"
7178
)
72-
return self._compiled_callable.aot_compile((args, kwargs))
79+
prev = self.first_compile
80+
self.first_compile = False
81+
ctx = nullcontext() if prev else torch.compiler.set_stance("fail_on_recompile")
82+
with ctx:
83+
return self._compiled_callable.aot_compile((args, kwargs))
7384

7485
def __call__(self, *args, **kwargs):
    """Invoke the compiled callable, forbidding recompiles after the first call.

    The first invocation runs the compiled callable as-is. Every later
    invocation runs it under ``torch.compiler.set_stance("fail_on_recompile")``.
    ``self.first_compile`` is cleared before the callable is entered.
    """
    is_first_call, self.first_compile = self.first_compile, False
    if is_first_call:
        return self._compiled_callable(*args, **kwargs)
    with torch.compiler.set_stance("fail_on_recompile"):
        return self._compiled_callable(*args, **kwargs)
7691

7792
@abstractmethod
7893
def forward(self, *args, **kwargs): ...

vllm/config/compilation.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,14 @@ class DynamicShapesConfig:
208208
backed/unbacked.
209209
"""
210210

211+
evaluate_guards: bool = False
212+
"""
213+
A debug mode to detect and fail if Dynamo ever specializes a dynamic shape by
214+
guarding on it. When True, dynamic shape guards are not dropped from Dynamo.
215+
And a failure will be triggered if recompilation ever happens due to that.
216+
Enabling this allows observing the dynamic shape guards in the tlparse artifacts.
217+
"""
218+
211219
# TODO add a debug mode to fail
212220

213221
def compute_hash(self) -> str:
@@ -224,6 +232,7 @@ def compute_hash(self) -> str:
224232
"""
225233
factors: list[Any] = []
226234
factors.append(self.dynamic_shapes_type.value)
235+
factors.append(self.evaluate_guards)
227236
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
228237
return hash_str
229238

0 commit comments

Comments
 (0)