Skip to content

Commit 3f78104

Browse files
committed
add unbacked option
Signed-off-by: Laith Sakka <lsakka@meta.com>
1 parent b31059f commit 3f78104

File tree

6 files changed

+344
-21
lines changed

6 files changed

+344
-21
lines changed
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import gc
5+
6+
import pytest
7+
import torch
8+
from torch.torch_version import TorchVersion
9+
10+
from vllm import LLM, SamplingParams
11+
from vllm.config.compilation import DynamicShapesType
12+
13+
14+
def cleanup_gpu_memory():
    """Free Python garbage and cached CUDA memory between model runs.

    Always runs the Python garbage collector. The CUDA cache flush and
    device synchronization are skipped when no CUDA device is available,
    because ``torch.cuda.synchronize()`` raises on such hosts.
    """
    gc.collect()  # Drop unreachable Python objects first
    if not torch.cuda.is_available():
        # Nothing to release or wait for without a CUDA device
        return
    torch.cuda.empty_cache()  # Return cached blocks held by the allocator
    torch.cuda.synchronize()  # Wait for all GPU operations to complete
19+
20+
21+
def get_test_models(torch_version=None):
    """Return the model names to test for a given PyTorch version.

    Args:
        torch_version: Version string to evaluate, for example
            ``"2.10.0a0+git1234"``. Defaults to the installed
            ``torch.__version__``.

    Returns:
        A new list holding the three always-tested models, followed by three
        more models when the version's ``major.minor`` is at least 2.10.
    """
    if torch_version is None:
        torch_version = torch.__version__

    def leading_int(component):
        # "0a0", "0rc1" or "dev20251015" -> leading digits only (0 if none)
        digits = ""
        for char in component:
            if not "0" <= char <= "9":
                break
            digits += char
        return int(digits) if digits else 0

    models = ["microsoft/DialoGPT-small", "gpt2", "facebook/opt-125m"]

    # Compare only the numeric major.minor release. Pre-release and nightly
    # suffixes (a0, b1, rc1, .dev<date>) and local build tags (+cu128, +git…)
    # would otherwise order the version *before* the plain "2.10" release.
    components = str(torch_version).split("+")[0].split(".")
    major = leading_int(components[0])
    minor = leading_int(components[1]) if len(components) > 1 else 0

    if (major, minor) >= (2, 10):
        # Requires some fixes only available in PyTorch 2.10+
        models.extend(
            [
                "Qwen/Qwen2-1.5B-Instruct",
                "Qwen/Qwen2-7B-Instruct",
                "openlm-research/open_llama_13b",
            ]
        )

    return models
35+
36+
37+
@pytest.mark.parametrize("model_name", get_test_models())
def test_dynamic_shapes_compilation(monkeypatch, model_name):
    """Smoke-test text generation under every dynamic shapes type.

    Generates once in eager mode (compilation level 0) as a baseline, then
    once per ``DynamicShapesType`` at compilation level 3, and asserts that
    each run returned a non-empty string. The generated texts are printed
    but not compared with one another.
    """
    print(f"\nTesting model: {model_name}")

    monkeypatch.setenv("TOKENIZERS_PARALLELISM", "true")
    # Note USE_AOT_COMPILE fails https://github.com/vllm-project/vllm/issues/27040.
    monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "0")

    prompt = "Hello, my name is"
    # Deterministic sampling: greedy decoding with a fixed seed
    sampling_params = SamplingParams(max_tokens=10, temperature=0.0, seed=42)

    def generate_once(compilation_config):
        # Build an engine, generate for the prompt, then drop the engine
        # reference and clean up GPU memory before handing back the text.
        llm = LLM(model=model_name, compilation_config=compilation_config)
        completions = llm.generate(prompt, sampling_params=sampling_params)
        text = completions[0].outputs[0].text
        del llm
        cleanup_gpu_memory()
        return text

    results = {}

    print("Testing EAGER (no compilation) baseline...")
    cleanup_gpu_memory()
    results["EAGER"] = generate_once(
        {
            "level": 0,  # NO_COMPILATION - eager mode
        }
    )

    # Exercise every dynamic shapes type with compilation enabled
    compiled_shape_types = (
        DynamicShapesType.BACKED,
        DynamicShapesType.UNBACKED,
        DynamicShapesType.BACKED_SIZE_OBLIVIOUS,
    )
    for shapes_type in compiled_shape_types:
        print(f"Testing {shapes_type.name} dynamic shapes...")
        results[shapes_type.name] = generate_once(
            {
                "level": 3,  # PIECEWISE compilation
                "dynamic_shapes_config": {
                    "dynamic_shapes_type": shapes_type.value,
                    "eval_dynamo_ds_guards": False,
                },
            }
        )

    # Every run must have produced some non-blank text
    for label, text in results.items():
        assert isinstance(text, str), f"{label} should return a string"
        assert len(text.strip()) > 0, f"{label} should generate non-empty text"

    # Report what each configuration generated
    for label, text in results.items():
        print(f"{label}: '{text}'")
111+
112+
113+
if __name__ == "__main__":
    # Script entry point: run the test for every model without pytest.
    import os

    class MockMonkeypatch:
        """Stand-in for pytest's ``monkeypatch`` fixture; supports ``setenv`` only."""

        def setenv(self, key, value):
            # Written straight to the process environment and never restored
            os.environ[key] = value

    print("Running dynamic shapes compilation test...")

    # The model list depends on the installed PyTorch version
    test_models = get_test_models()
    print(f"Testing {len(test_models)} models: {test_models}")

    monkeypatch = MockMonkeypatch()

    # Stop at the first failing model: report it, then re-raise
    for model_name in test_models:
        try:
            print(f"\n{'=' * 60}")
            print(f"Testing model: {model_name}")
            print(f"{'=' * 60}")

            test_dynamic_shapes_compilation(monkeypatch, model_name)

            print(f"✅ Test passed for {model_name}")
        except Exception as e:
            print(f"❌ Test failed for {model_name}: {e}")
            raise

    print("\n🎉 All tests completed successfully!")

vllm/compilation/decorators.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from vllm.compilation.counter import compilation_counter
2020
from vllm.compilation.wrapper import TorchCompileGuardsStripWrapper
2121
from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config
22+
from vllm.config.compilation import DynamicShapesType
2223
from vllm.logger import init_logger
2324
from vllm.sequence import IntermediateTensors
2425
from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -83,6 +84,7 @@ def support_torch_compile(
8384
*,
8485
dynamic_arg_dims: dict[str, int | list[int]] | None = None,
8586
enable_if: Callable[[VllmConfig], bool] | None = None,
87+
shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
8688
) -> Callable[[_T], _T] | _T:
8789
"""
8890
A decorator to add support for compiling the forward method of a class.
@@ -172,7 +174,9 @@ def cls_decorator_helper(cls: _T) -> _T:
172174
raise ValueError(
173175
f"Argument {k} not found in the forward method of {cls}"
174176
)
175-
return _support_torch_compile(cls, inferred_dynamic_arg_dims, enable_if)
177+
return _support_torch_compile(
178+
cls, inferred_dynamic_arg_dims, enable_if, shape_invariants
179+
)
176180

177181
if cls is not None:
178182
# use `support_torch_compile` as a decorator without arguments
@@ -213,6 +217,7 @@ def _support_torch_compile(
213217
cls: _T,
214218
dynamic_arg_dims: dict[str, int | list[int]],
215219
enable_if: Callable[[VllmConfig], bool] | None = None,
220+
shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
216221
) -> _T:
217222
"""
218223
A decorator to add support for compiling the forward method of a class.
@@ -233,11 +238,12 @@ def _support_torch_compile(
233238
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
234239
old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
235240
self.vllm_config = vllm_config
241+
self.compilation_config = self.vllm_config.compilation_config
236242
enable_compile = enable_if is None or enable_if(vllm_config)
237243
# for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner
238244
# will handle the compilation, so we don't need to do anything here.
239245
self.do_not_compile = (
240-
vllm_config.compilation_config.mode
246+
self.compilation_config.mode
241247
in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE]
242248
or not supports_dynamo()
243249
or _should_ignore_torch_compile(self.__class__)
@@ -246,29 +252,38 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
246252
if self.do_not_compile:
247253
return
248254

255+
self._check_shape_invariants = shape_invariants
256+
249257
compilation_counter.num_models_seen += 1
250258
self.compiled = False
251259
TorchCompileGuardsStripWrapper.__init__(self)
252260

253261
cls.__init__ = __init__
254262

255-
def _mark_dynamic_inputs(mod, *args, **kwargs):
263+
def _mark_dynamic_inputs(mod, dynamic_shapes_type, *args, **kwargs):
264+
def mark_dynamic(arg, dims):
265+
if dynamic_shapes_type == DynamicShapesType.UNBACKED:
266+
torch._dynamo.decorators.mark_unbacked(arg, dims)
267+
else:
268+
torch._dynamo.mark_dynamic(arg, dims)
269+
256270
sig = inspect.signature(mod.__class__.forward)
257271
bound_args = sig.bind(mod, *args, **kwargs)
258272
bound_args.apply_defaults()
259273
for k, dims in dynamic_arg_dims.items():
260274
arg = bound_args.arguments.get(k)
275+
261276
if arg is not None:
262277
dims = [dims] if isinstance(dims, int) else dims
263278
if isinstance(arg, torch.Tensor):
264279
# In case dims is specified with negative indexing
265280
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
266-
torch._dynamo.mark_dynamic(arg, dims)
281+
mark_dynamic(arg, dims)
267282
elif isinstance(arg, IntermediateTensors):
268283
for tensor in arg.tensors.values():
269284
# In case dims is specified with negative indexing
270285
dims = [tensor.ndim + dim if dim < 0 else dim for dim in dims]
271-
torch._dynamo.mark_dynamic(tensor, dims)
286+
mark_dynamic(tensor, dims)
272287
else:
273288
raise ValueError(
274289
"Unsupported dynamic dimensions"
@@ -286,6 +301,7 @@ def __call__(self, *args, **kwargs):
286301
if getattr(self, "aot_compiled_fn", None) is not None:
287302
return self.aot_compiled_fn(self, *args, **kwargs)
288303

304+
ds_type = self.compilation_config.dynamic_shapes_config.dynamic_shapes_type
289305
cache_dir = None
290306
aot_compilation_path = None
291307
if envs.VLLM_USE_AOT_COMPILE:
@@ -300,6 +316,14 @@ def __call__(self, *args, **kwargs):
300316
serialized backend artifacts), then we need to generate a new AOT
301317
compile artifact from scratch.
302318
"""
319+
# Validate that AOT compile is not used with unbacked dynamic
320+
# shapes. aot_compile re-allocates backed symbols post dynamo!
321+
if ds_type == DynamicShapesType.UNBACKED:
322+
raise ValueError(
323+
"AOT compilation is not compatible with UNBACKED dynamic shapes. "
324+
"Please use BACKED or BACKED_SIZE_OBLIVIOUS dynamic shapes type "
325+
"when VLLM_USE_AOT_COMPILE is enabled."
326+
)
303327
from .caching import compilation_config_hash_factors
304328

305329
factors: list[str] = compilation_config_hash_factors(self.vllm_config)
@@ -348,7 +372,12 @@ def __call__(self, *args, **kwargs):
348372
# This is the path for the first compilation.
349373

350374
# the first compilation needs to have dynamic shapes marked
351-
_mark_dynamic_inputs(self, *args, **kwargs)
375+
_mark_dynamic_inputs(
376+
self,
377+
ds_type,
378+
*args,
379+
**kwargs,
380+
)
352381

353382
# here, it is the starting point of the `torch.compile` process
354383
start_monitoring_torch_compile(self.vllm_config)
@@ -365,9 +394,7 @@ def __call__(self, *args, **kwargs):
365394
# properly when any of these files change.
366395

367396
# 1. the file containing the top-level forward function
368-
self.vllm_config.compilation_config.traced_files.add(
369-
original_code_object.co_filename
370-
)
397+
self.compilation_config.traced_files.add(original_code_object.co_filename)
371398

372399
# 2. every time Dynamo sees a function call, it will inline
373400
# the function by calling InliningInstructionTranslator.inline_call_
@@ -377,7 +404,7 @@ def __call__(self, *args, **kwargs):
377404

378405
def patched_inline_call(self_):
379406
code = self_.f_code
380-
self.vllm_config.compilation_config.traced_files.add(code.co_filename)
407+
self.compilation_config.traced_files.add(code.co_filename)
381408
return inline_call(self_)
382409

383410
# Disable the C++ compilation of symbolic shape guards. C++-fication
@@ -393,12 +420,18 @@ def patched_inline_call(self_):
393420
# if the config doesn't exist
394421
logger.debug("enable_cpp_symbolic_shape_guards config not available")
395422

423+
# Prepare backed_size_oblivious config patch if needed
424+
fx_config_patches = {}
425+
if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
426+
fx_config_patches["backed_size_oblivious"] = True
427+
396428
with (
397429
patch.object(
398430
InliningInstructionTranslator, "inline_call_", patched_inline_call
399431
),
400432
torch._dynamo.config.patch(**dynamo_config_patches),
401433
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
434+
torch.fx.experimental._config.patch(**fx_config_patches),
402435
_torch27_patch_tensor_subclasses(),
403436
):
404437
if envs.VLLM_USE_AOT_COMPILE:

vllm/compilation/wrapper.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ class TorchCompileGuardsStripWrapper:
2222
since we drop all guards.
2323
"""
2424

25+
def check_invariantes_and_forward(self, *args, **kwargs):
    """Run the shape-invariant check, then delegate to ``forward``.

    The instance must already have a ``_check_shape_invariants`` callable.
    It is called with the same positional and keyword arguments as
    ``forward`` and its return value is ignored.

    Returns:
        Whatever ``self.forward(*args, **kwargs)`` returns.
    """
    assert hasattr(self, "_check_shape_invariants")
    self._check_shape_invariants(*args, **kwargs)

    output = self.forward(*args, **kwargs)
    return output
30+
2531
def __init__(self):
2632
self.compiled = False
2733

@@ -50,7 +56,7 @@ def __init__(self):
5056
logger.warning(msg)
5157

5258
self._compiled_callable = torch.compile(
53-
self.forward,
59+
self.check_invariantes_and_forward,
5460
fullgraph=True,
5561
backend=backend,
5662
options=options,

0 commit comments

Comments
 (0)