-
-
Notifications
You must be signed in to change notification settings - Fork 11.7k
Add evaluate_guards option to DynamicShapesConfig #27432
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,7 +4,7 @@ | |
| import os | ||
| import sys | ||
| from abc import abstractmethod | ||
| from contextlib import contextmanager | ||
| from contextlib import contextmanager, nullcontext | ||
| from types import CodeType | ||
| from typing import Any | ||
|
|
||
|
|
@@ -13,6 +13,7 @@ | |
|
|
||
| import vllm.envs as envs | ||
| from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config | ||
| from vllm.config.compilation import DynamicShapesType | ||
| from vllm.logger import init_logger | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
@@ -98,6 +99,7 @@ def __init__(self): | |
| vllm_config = get_current_vllm_config() | ||
| self.vllm_config = vllm_config | ||
| mode = vllm_config.compilation_config.mode | ||
|
|
||
| if mode is None: | ||
| raise RuntimeError("Compilation mode cannot be NO_COMPILATION") | ||
|
|
||
|
|
@@ -107,23 +109,53 @@ def __init__(self): | |
| if isinstance(backend, str) and backend == "inductor": | ||
| options = vllm_config.compilation_config.inductor_compile_config | ||
|
|
||
| self.first_compile = True | ||
|
|
||
| ds_type = vllm_config.compilation_config.dynamic_shapes_config.type | ||
|
|
||
| if mode != CompilationMode.STOCK_TORCH_COMPILE: | ||
| # Drop all the guards. | ||
| options["guard_filter_fn"] = lambda x: [False for _ in x] | ||
| if vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards: | ||
| assert not envs.VLLM_USE_BYTECODE_HOOK, ( | ||
| "compilation_config.dynamic_shapes_config.evaluate_guards " | ||
| "requires VLLM_USE_BYTECODE_HOOK=0. " | ||
| ) | ||
|
|
||
| # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False | ||
| from vllm.compilation.decorators import DynamicShapesType | ||
| if envs.VLLM_USE_AOT_COMPILE: | ||
| # disabled until https://github.com/pytorch/pytorch/pull/169239 | ||
| # is picked up. | ||
| assert ds_type != DynamicShapesType.BACKED, ( | ||
| "evaluate_guards for backed shapes requires " | ||
| "VLLM_USE_AOT_COMPILE=False. " | ||
| ) | ||
|
|
||
| assert not envs.VLLM_USE_BYTECODE_HOOK, ( | ||
| "compilation_config.dynamic_shapes_config.evaluate_guards " | ||
| "requires VLLM_USE_BYTECODE_HOOK=0. " | ||
| ) | ||
|
|
||
| options["guard_filter_fn"] = lambda x: [ | ||
| entry.guard_type == "SHAPE_ENV" for entry in x | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this seems brittle..
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeh we shall if we dont have any. I will file issue and assign to me.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. actually vllm testd would fail I also found a an internal test that checks it def test_symbool_guards(
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sgtm |
||
| ] | ||
| else: | ||
| options["guard_filter_fn"] = lambda x: [False for _ in x] | ||
|
|
||
| ds_type = vllm_config.compilation_config.dynamic_shapes_config.type | ||
| compiled_ptr: Any = self.forward | ||
| # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False | ||
|
|
||
| if ds_type == DynamicShapesType.UNBACKED: | ||
| if envs.VLLM_USE_BYTECODE_HOOK: | ||
| # reason is that bytecode does this hack torch._dynamo.eval_frame. | ||
| # remove_from_cache(self.original_code_object()) to force a new | ||
| # re-compilation. | ||
| raise ValueError( | ||
| "UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. " | ||
| ) | ||
| # reason is that bytecode does torch._dynamo.eval_frame. | ||
| # remove_from_cache(self.original_code_object()) to force a new | ||
| # re-compilation. And if we use | ||
| # compiled_ptr = self.check_invariants_and_forward | ||
| # it will reset all entries. | ||
| assert envs.VLLM_USE_BYTECODE_HOOK, ( | ||
| "UNBACKED dynamic shapes requires VLLM_USE_BYTECODE_HOOK=0. " | ||
| ) | ||
| assert ( | ||
| not vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards | ||
| ), "UNBACKED dynamic shapes do not add guards" | ||
|
|
||
| compiled_ptr = self.check_invariants_and_forward | ||
|
|
||
| if envs.VLLM_USE_AOT_COMPILE: | ||
|
|
@@ -173,7 +205,13 @@ def __call__(self, *args, **kwargs): | |
| with self._dispatch_to_compiled_code(): | ||
| return self.forward(*args, **kwargs) | ||
| else: | ||
| with _compilation_context(): | ||
| ctx = ( | ||
| nullcontext() | ||
| if self.first_compile | ||
| else torch.compiler.set_stance("fail_on_recompile") | ||
| ) | ||
| self.first_compile = False | ||
| with _compilation_context(), ctx: | ||
| return self._compiled_callable(*args, **kwargs) | ||
|
|
||
| @abstractmethod | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -249,7 +249,18 @@ class DynamicShapesConfig: | |
| backed/unbacked. | ||
| """ | ||
|
|
||
| # TODO add a debug mode to fail | ||
| evaluate_guards: bool = False | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I want to get into a state where this is always on during vLLM warmup, and then optionally during runtime. This involves making sure the guards are actually correct (maybe this requires unbacked). Not for this PR, but we should do this in a follow-up.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. using unbacked basically just avoids the need of this, once we address the perf gp |
||
| """ | ||
| A debug mode to detect and fail if Dynamo ever specializes a dynamic shape by | ||
| guarding on it. When True, dynamic shape guards are not dropped from dynamo. | ||
| And a failure will be triggered if a recompilation ever happens due to that. | ||
| This mode requires VLLM_USE_BYTECODE_HOOK to be 0. | ||
| Enabling this allow observing the dynamic shapes guards in the tlparse | ||
| artifacts also. | ||
| When type is backed, aot_compile must be disabled for this mode to work. | ||
| until this change picked up https://github.com/pytorch/pytorch/pull/169239. | ||
|
|
||
| """ | ||
|
|
||
| def compute_hash(self) -> str: | ||
| """ | ||
|
|
@@ -358,8 +369,8 @@ class CompilationConfig: | |
| We use string to avoid serialization issues when using compilation in a | ||
| distributed setting. When the compilation mode is 1 or 2, the backend is | ||
| used for the compilation directly (it sees the whole graph). When the | ||
| compilation mode is 3, the backend supports both whole graph and piecewise | ||
| compilation, available backends include eager, inductor, and custom backends, | ||
| compilation mode is 3, the backend supports both whole graph and piecewise | ||
| compilation, available backends include eager, inductor, and custom backends, | ||
| the latter of which can be defined via `get_compile_backend`. Furthermore, | ||
| compilation is only piecewise if splitting ops is set accordingly and | ||
| use_inductor_graph_partition is off. Note that the default options for | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we only do this with
evaluate_guardson?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh is it because otherwise we drop all guards anyway?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also for backed size oblivious we do no t want to drop 0/1 guards