From a46f6bf3f19e8377c455b7b3e47cd79153e99db5 Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Thu, 23 Oct 2025 11:32:53 -0700 Subject: [PATCH] add evaluate_guards option to DynamicShapesConfig Signed-off-by: Laith Sakka --- .../test_dynamic_shapes_compilation.py | 139 +++++++++++++++++- vllm/compilation/backends.py | 27 +++- vllm/compilation/decorators.py | 4 +- vllm/compilation/wrapper.py | 64 ++++++-- vllm/config/compilation.py | 17 ++- vllm/config/vllm.py | 2 +- 6 files changed, 224 insertions(+), 29 deletions(-) diff --git a/tests/compile/test_dynamic_shapes_compilation.py b/tests/compile/test_dynamic_shapes_compilation.py index c20aea822fe8..045d2fab9092 100644 --- a/tests/compile/test_dynamic_shapes_compilation.py +++ b/tests/compile/test_dynamic_shapes_compilation.py @@ -2,12 +2,21 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import gc +import tempfile +from contextlib import contextmanager import pytest import torch from vllm import LLM, SamplingParams -from vllm.config.compilation import CompilationMode, DynamicShapesType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config +from vllm.config.compilation import ( + CompilationMode, + DynamicShapesConfig, + DynamicShapesType, +) +from vllm.forward_context import set_forward_context from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils.torch_utils import is_torch_equal_or_newer @@ -29,18 +38,19 @@ def get_test_models(): ) @pytest.mark.parametrize("use_aot_compile", ["0"]) @pytest.mark.parametrize("use_bytecode_hook", [True, False]) +@pytest.mark.parametrize("evaluate_guards", [False, True]) @pytest.mark.skipif( not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10" ) def test_dynamic_shapes_compilation( - monkeypatch, model_name, shapes_type, use_aot_compile, use_bytecode_hook + monkeypatch, + model_name, + shapes_type, + use_aot_compile, + use_bytecode_hook, + evaluate_guards, ): """Test that all dynamic shapes types compile successfully""" - print( - f"\nTesting model: {model_name} with {shapes_type.name}, " - f"AOT compile: {use_aot_compile}, " - f"Bytecode hook: {use_bytecode_hook}" - ) if use_bytecode_hook and shapes_type == DynamicShapesType.UNBACKED: pytest.skip("UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0") @@ -58,6 +68,7 @@ def test_dynamic_shapes_compilation( "mode": CompilationMode.VLLM_COMPILE, "dynamic_shapes_config": { "type": shapes_type.value, + "evaluate_guards": evaluate_guards, }, }, ) @@ -86,3 +97,117 @@ def test_dynamic_shapes_compilation( torch.cuda.empty_cache() torch.cuda.synchronize() print("GPU memory cleared") + + +@pytest.mark.parametrize("use_aot_compile", ["0", "1"]) +@pytest.mark.parametrize( + "dynamic_shapes_type", + [ + DynamicShapesType.BACKED, + DynamicShapesType.BACKED_SIZE_OBLIVIOUS, + ], +) +@pytest.mark.parametrize("evaluate_guards", [False, True]) +def test_model_specialization_with_evaluate_guards( + monkeypatch, use_aot_compile, dynamic_shapes_type, evaluate_guards +): + """Test that evaluate_guards correctly detects shape specialization + violations. + """ + + if ( + use_aot_compile == "1" + and dynamic_shapes_type == DynamicShapesType.BACKED + and evaluate_guards + ): + pytest.skip("evaluate_guards for backed does not work with aot_compile =1") + + @support_torch_compile + class ModelWithSizeCheck(torch.nn.Module): + def __init__(self, **kwargs): + super().__init__() + + def forward(self, x: torch.Tensor): + # This will cause specialization - torch.compile will guard on + # sx.shape[0] + if x.shape[0] >= 10: + return x * 10 + else: + return x * 10 + + @support_torch_compile + class ModelWithOneSizeCheck(torch.nn.Module): + def __init__(self, **kwargs): + super().__init__() + + def forward(self, x: torch.Tensor): + # This will cause 0/1 specializations. + if x.shape[0] == 0: + return x * 10 + if x.shape[0] == 1: + return x * 10 + else: + return x * 10 + + @contextmanager + def use_vllm_config(vllm_config: VllmConfig): + with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config): + yield + + monkeypatch.setenv("TOKENIZERS_PARALLELISM", "true") + monkeypatch.setenv("VLLM_USE_AOT_COMPILE", use_aot_compile) + monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "0") + + # Create vllm config with the desired settings + from vllm.config import CompilationMode + + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + dynamic_shapes_config=DynamicShapesConfig( + type=dynamic_shapes_type, + evaluate_guards=evaluate_guards, + ), + ) + ) + + def test(model_class, input1, input2, is_01_specialization=False): + with ( + torch.no_grad(), + use_vllm_config(vllm_config), + tempfile.TemporaryDirectory() as tmpdirname, + ): + monkeypatch.setenv("VLLM_CACHE_ROOT", tmpdirname) + + model = model_class(vllm_config=vllm_config).cuda() + + model(input1) + + if evaluate_guards and ( + not ( + is_01_specialization + and dynamic_shapes_type == DynamicShapesType.BACKED + ) + ): + # This should fail because guards were added. + with pytest.raises(RuntimeError) as excinfo: + model(input2) + + # Expected failure - guard was violated + error_msg = str(excinfo.value) + assert ( + "GuardManager check failed" in error_msg + or "Detected recompile when torch.compile stance" in error_msg + ), error_msg + + else: + model(input2) + + test(ModelWithSizeCheck, torch.randn(20, 10).cuda(), torch.randn(5, 10).cuda()) + test(ModelWithSizeCheck, torch.randn(5, 10).cuda(), torch.randn(20, 10).cuda()) + test( + ModelWithOneSizeCheck, + torch.randn(20, 10).cuda(), + torch.randn(1, 10).cuda(), + is_01_specialization=True, + ) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 1773913d0b6c..359f36ee881b 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -26,6 +26,7 @@ should_split, ) from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig +from vllm.config.compilation import DynamicShapesType from vllm.config.utils import hash_factors from vllm.logger import init_logger from vllm.logging_utils import lazy @@ -752,6 +753,29 @@ def __call__( self.split_gm, submod_names_to_compile, self.vllm_config, self ).run(*example_inputs) + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode() + + if ( + self.compilation_config.dynamic_shapes_config.evaluate_guards + and self.compilation_config.dynamic_shapes_config.type + == DynamicShapesType.BACKED + ): + from torch.utils._sympy.value_ranges import ValueRanges + + # Drop counter-0/1 specializations guards; for backed dynamic shapes, + # torch.compile will specialize for 0/1 inputs or otherwise guards that + # shape is >= 2. This is because it's really hard not to hit a check + # against 0/1. When we evaluate shape guards, we exclude checking those + # guards (We would fail always otherwise). + + # We avoid that by updating the ranges of backed sizes when the min is + # 2 for any, we assume it's 0. + for s, r in fake_mode.shape_env.var_to_range.items(): + if r.lower == 2: + fake_mode.shape_env.var_to_range[s] = ValueRanges(0, r.upper) + graph_path = os.path.join(local_cache_dir, "computation_graph.py") if not os.path.exists(graph_path): # code adapted from @@ -780,9 +804,6 @@ def __call__( ) # if we need to copy input buffers for cudagraph - from torch._guards import detect_fake_mode - - fake_mode = detect_fake_mode() fake_args = [ fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in example_inputs diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 6d9da1c488c6..bb67ef0c659d 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -392,7 +392,6 @@ def __call__(self, *args, **kwargs): factors.append(_model_hash_key(self.forward)) hash_key = hashlib.sha256(str(factors).encode()).hexdigest() - cache_dir = os.path.join( envs.VLLM_CACHE_ROOT, "torch_aot_compile", @@ -411,7 +410,8 @@ def __call__(self, *args, **kwargs): start_monitoring_torch_compile(self.vllm_config) loaded_fn = torch.compiler.load_compiled_function(f) _verify_source_unchanged(loaded_fn.source_info(), self.vllm_config) - loaded_fn.disable_guard_check() + if not self.compilation_config.dynamic_shapes_config.evaluate_guards: + loaded_fn.disable_guard_check() self.aot_compiled_fn = loaded_fn except Exception as e: if os.path.exists(aot_compilation_path): diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index b120c85bf232..024982386bdf 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -4,7 +4,7 @@ import os import sys from abc import abstractmethod -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from types import CodeType from typing import Any @@ -13,6 +13,7 @@ import vllm.envs as envs from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config +from vllm.config.compilation import DynamicShapesType from vllm.logger import init_logger logger = init_logger(__name__) @@ -98,6 +99,7 @@ def __init__(self): vllm_config = get_current_vllm_config() self.vllm_config = vllm_config mode = vllm_config.compilation_config.mode + if mode is None: raise RuntimeError("Compilation mode cannot be NO_COMPILATION") @@ -107,23 +109,53 @@ def __init__(self): if isinstance(backend, str) and backend == "inductor": options = vllm_config.compilation_config.inductor_compile_config + self.first_compile = True + + ds_type = vllm_config.compilation_config.dynamic_shapes_config.type + if mode != CompilationMode.STOCK_TORCH_COMPILE: # Drop all the guards. - options["guard_filter_fn"] = lambda x: [False for _ in x] + if vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards: + assert not envs.VLLM_USE_BYTECODE_HOOK, ( + "compilation_config.dynamic_shapes_config.evaluate_guards " + "requires VLLM_USE_BYTECODE_HOOK=0. " + ) - # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False - from vllm.compilation.decorators import DynamicShapesType + if envs.VLLM_USE_AOT_COMPILE: + # disabled until https://github.com/pytorch/pytorch/pull/169239 + # is picked up. + assert ds_type != DynamicShapesType.BACKED, ( + "evaluate_guards for backed shapes requires " + "VLLM_USE_AOT_COMPILE=False. " + ) + + assert not envs.VLLM_USE_BYTECODE_HOOK, ( + "compilation_config.dynamic_shapes_config.evaluate_guards " + "requires VLLM_USE_BYTECODE_HOOK=0. " + ) + + options["guard_filter_fn"] = lambda x: [ + entry.guard_type == "SHAPE_ENV" for entry in x + ] + else: + options["guard_filter_fn"] = lambda x: [False for _ in x] - ds_type = vllm_config.compilation_config.dynamic_shapes_config.type compiled_ptr: Any = self.forward + # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False + if ds_type == DynamicShapesType.UNBACKED: - if envs.VLLM_USE_BYTECODE_HOOK: - # reason is that bytecode does this hack torch._dynamo.eval_frame. - # remove_from_cache(self.original_code_object()) to force a new - # re-compilation. - raise ValueError( - "UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. " - ) + # reason is that bytecode does torch._dynamo.eval_frame. + # remove_from_cache(self.original_code_object()) to force a new + # re-compilation. And if we use + # compiled_ptr = self.check_invariants_and_forward + # it will reset all entries. + assert envs.VLLM_USE_BYTECODE_HOOK, ( + "UNBACKED dynamic shapes requires VLLM_USE_BYTECODE_HOOK=0. " + ) + assert ( + not vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards + ), "UNBACKED dynamic shapes do not add guards" + compiled_ptr = self.check_invariants_and_forward if envs.VLLM_USE_AOT_COMPILE: @@ -173,7 +205,13 @@ def __call__(self, *args, **kwargs): with self._dispatch_to_compiled_code(): return self.forward(*args, **kwargs) else: - with _compilation_context(): + ctx = ( + nullcontext() + if self.first_compile + else torch.compiler.set_stance("fail_on_recompile") + ) + self.first_compile = False + with _compilation_context(), ctx: return self._compiled_callable(*args, **kwargs) @abstractmethod diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index da2c100dae3d..e742f9287d55 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -249,7 +249,18 @@ class DynamicShapesConfig: backed/unbacked. """ - # TODO add a debug mode to fail + evaluate_guards: bool = False + """ + A debug mode to detect and fail if Dynamo ever specializes a dynamic shape by + guarding on it. When True, dynamic shape guards are not dropped from dynamo. + And a failure will be triggered if a recompilation ever happens due to that. + This mode requires VLLM_USE_BYTECODE_HOOK to be 0. + Enabling this allow observing the dynamic shapes guards in the tlparse + artifacts also. + When type is backed, aot_compile must be disabled for this mode to work. + until this change picked up https://github.com/pytorch/pytorch/pull/169239. + + """ def compute_hash(self) -> str: """ @@ -358,8 +369,8 @@ class CompilationConfig: We use string to avoid serialization issues when using compilation in a distributed setting. When the compilation mode is 1 or 2, the backend is used for the compilation directly (it sees the whole graph). When the - compilation mode is 3, the backend supports both whole graph and piecewise - compilation, available backends include eager, inductor, and custom backends, + compilation mode is 3, the backend supports both whole graph and piecewise + compilation, available backends include eager, inductor, and custom backends, the latter of which can be defined via `get_compile_backend`. Furthermore, compilation is only piecewise if splitting ops is set accordingly and use_inductor_graph_partition is off. Note that the default options for diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 34e70e3e134b..0f66c3ff01c9 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -65,7 +65,7 @@ class OptimizationLevel(IntEnum): """O0 : No optimization. no compilation, no cudagraphs, no other optimization, just starting up immediately""" O1 = 1 - """O1: Quick optimizations. Dynamo+Inductor compilation and Piecewise + """O1: Quick optimizations. Dynamo+Inductor compilation and Piecewise cudagraphs""" O2 = 2 """O2: Full optimizations. -O1 as well as Full and Piecewise cudagraphs."""