Commit cc295da

wip
Signed-off-by: Laith Sakka <lsakka@meta.com>
1 parent bcc0f99 commit cc295da

7 files changed: +328 -18 lines changed
Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import gc

import pytest
import torch
from torch.torch_version import TorchVersion

from vllm import LLM, SamplingParams
from vllm.config.compilation import DynamicShapesType


def cleanup_gpu_memory():
    """Clean up GPU memory after each test."""
    gc.collect()  # Clear Python objects
    torch.cuda.empty_cache()  # Clear PyTorch GPU memory cache
    torch.cuda.synchronize()  # Wait for all GPU operations to complete


def get_test_models():
    """Get the list of models to test based on the PyTorch version."""
    result = ["microsoft/DialoGPT-small", "gpt2", "facebook/opt-125m"]
    # Parse the PyTorch version, stripping local build and pre-release
    # suffixes (e.g. "+cu121", "a0", "b1", "rc1").
    version_parts = torch.__version__.split('+')[0].split('a')[0]
    clean_version = version_parts.split('b')[0].split('rc')[0]
    if TorchVersion(clean_version) >= TorchVersion("2.10"):
        # These models require fixes only available in PyTorch 2.10+.
        result.append("Qwen/Qwen2-1.5B-Instruct")
        result.append("Qwen/Qwen2-7B-Instruct")
        result.append("openlm-research/open_llama_13b")

    return result


@pytest.mark.parametrize("model_name", get_test_models())
def test_dynamic_shapes_compilation(monkeypatch, model_name):
    """Test that every dynamic shapes type compiles and generates text."""
    print(f"\nTesting model: {model_name}")

    # monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    monkeypatch.setenv("TOKENIZERS_PARALLELISM", "true")

    prompt = "Hello, my name is"
    results = {}

    print("Testing EAGER (no compilation) baseline...")
    cleanup_gpu_memory()

    eager_model = LLM(
        model=model_name,
        compilation_config={
            "level": 0,  # NO_COMPILATION - eager mode
        },
        # gpu_memory_utilization=0.2,
    )

    # Generate text with deterministic sampling parameters
    sampling_params = SamplingParams(
        max_tokens=10,
        temperature=0.0,  # Deterministic generation
        seed=42,  # Fixed seed for consistency
    )
    eager_output = eager_model.generate(prompt,
                                        sampling_params=sampling_params)
    results["EAGER"] = eager_output[0].outputs[0].text

    # Cleanup model
    del eager_model
    cleanup_gpu_memory()

    # Test all dynamic shapes types with compilation
    for shapes_type in [
            DynamicShapesType.BACKED, DynamicShapesType.UNBACKED,
            DynamicShapesType.BACKED_SIZE_OBLIVIOUS
    ]:
        print(f"Testing {shapes_type.name} dynamic shapes...")

        # Initialize the model with the specific dynamic shapes configuration
        model = LLM(
            model=model_name,
            compilation_config={
                "level": 3,  # PIECEWISE compilation
                "dynamic_shapes_config": {
                    "dynamic_shapes_type": shapes_type.value,
                    "eval_dynamo_ds_guards": False,
                },
            },
            # gpu_memory_utilization=0.2,
        )

        output = model.generate(prompt, sampling_params=sampling_params)

        # Store results for comparison
        results[shapes_type.name] = output[0].outputs[0].text

        # Cleanup model
        del model
        cleanup_gpu_memory()

    # Verify all results are non-empty strings
    for shape_type, result in results.items():
        assert isinstance(result, str), f"{shape_type} should return a string"
        assert len(
            result.strip()) > 0, f"{shape_type} should generate non-empty text"

    # Print results
    for shape_type, result in results.items():
        print(f"{shape_type}: '{result}'")


if __name__ == "__main__":
    # Run the test directly as a Python script.
    import os

    print("Running dynamic shapes compilation test...")

    # Get test models based on PyTorch version
    test_models = get_test_models()
    print(f"Testing {len(test_models)} models: {test_models}")

    # Create a mock monkeypatch object for environment variables
    class MockMonkeypatch:

        def setenv(self, key, value):
            os.environ[key] = value

    monkeypatch = MockMonkeypatch()

    # Run the test for each model
    for model_name in test_models:
        try:
            print(f"\n{'='*60}")
            print(f"Testing model: {model_name}")
            print(f"{'='*60}")

            test_dynamic_shapes_compilation(monkeypatch, model_name)

            print(f"✅ Test passed for {model_name}")

        except Exception as e:
            print(f"❌ Test failed for {model_name}: {e}")
            raise

    print("\n🎉 All tests completed successfully!")

vllm/compilation/decorators.py

Lines changed: 22 additions & 6 deletions
@@ -12,6 +12,7 @@
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.wrapper import TorchCompileGuardsStripWrapper
 from vllm.config import CompilationLevel, VllmConfig
+from vllm.config.compilation import DynamicShapesType
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
 from vllm.utils import supports_dynamo
@@ -78,6 +79,7 @@ def support_torch_compile(
     *,
     dynamic_arg_dims: Optional[dict[str, Union[int, list[int]]]] = None,
     enable_if: Optional[Callable[[VllmConfig], bool]] = None,
+    shape_invariants: Callable[..., None] = lambda *args, **kwargs: None
 ) -> Union[Callable[[_T], _T], _T]:
     """
     A decorator to add support for compiling the forward method of a class.
@@ -164,7 +166,7 @@ def cls_decorator_helper(cls: _T) -> _T:
             raise ValueError(
                 f"Argument {k} not found in the forward method of {cls}")
         return _support_torch_compile(cls, inferred_dynamic_arg_dims,
-                                      enable_if)
+                                      enable_if, shape_invariants)

     if cls is not None:
         # use `support_torch_compile` as a decorator without arguments
@@ -178,7 +180,8 @@ def _support_torch_compile(
     cls: _T,
     dynamic_arg_dims: dict[str, Union[int, list[int]]],
     enable_if: Optional[Callable[[VllmConfig], bool]] = None,
-) -> _T:
+    shape_invariants: Callable[...,
+                               None] = lambda *args, **kwargs: None) -> _T:
     """
     A decorator to add support for compiling the forward method of a class.
     """
@@ -209,31 +212,41 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
         if self.do_not_compile:
             return

+        self._check_shape_invariants = shape_invariants
+
         compilation_counter.num_models_seen += 1
         TorchCompileGuardsStripWrapper.__init__(self)

     cls.__init__ = __init__

-    def _mark_dynamic_inputs(mod, *args, **kwargs):
+    def _mark_dynamic_inputs(mod, dynamic_shapes_type, *args, **kwargs):
+
+        def mark_dynamic(arg, dims):
+            if dynamic_shapes_type == DynamicShapesType.UNBACKED:
+                torch._dynamo.decorators.mark_unbacked(arg, dims)
+            else:
+                torch._dynamo.mark_dynamic(arg, dims)
+
         sig = inspect.signature(mod.__class__.forward)
         bound_args = sig.bind(mod, *args, **kwargs)
         bound_args.apply_defaults()
         for k, dims in dynamic_arg_dims.items():
             arg = bound_args.arguments.get(k)
+
             if arg is not None:
                 dims = [dims] if isinstance(dims, int) else dims
                 if isinstance(arg, torch.Tensor):
                     # In case dims is specified with negative indexing
                     dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
-                    torch._dynamo.mark_dynamic(arg, dims)
+                    mark_dynamic(arg, dims)
                 elif isinstance(arg, IntermediateTensors):
                     for tensor in arg.tensors.values():
                         # In case dims is specified with negative indexing
                         dims = [
                             tensor.ndim + dim if dim < 0 else dim
                             for dim in dims
                         ]
-                        torch._dynamo.mark_dynamic(tensor, dims)
+                        mark_dynamic(tensor, dims)
                 else:
                     raise ValueError(
                         "Unsupported dynamic dimensions"
@@ -251,8 +264,11 @@ def __call__(self, *args, **kwargs):
             return TorchCompileGuardsStripWrapper.__call__(
                 self, *args, **kwargs)

+        _mark_dynamic_inputs(
+            self, self.vllm_config.compilation_config.dynamic_shapes_config.
+            dynamic_shapes_type, *args, **kwargs)
+
         # This is the path for the first compilation.
-        _mark_dynamic_inputs(self, *args, **kwargs)

         # the first compilation needs to have dynamic shapes marked
         start_monitoring_torch_compile(self.vllm_config)
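
The decorator now threads an optional shape_invariants callable through _support_torch_compile and stores it on the instance as _check_shape_invariants; the wrapper (next file) invokes it with the same arguments that reach forward. A hypothetical usage sketch, not part of this commit; the model, argument names, and assertions are illustrative only.

# Hypothetical usage sketch of the new shape_invariants hook; illustrative only.
import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


def _token_shape_invariants(input_ids: torch.Tensor,
                            positions: torch.Tensor) -> None:
    # Receives the same arguments as forward (minus the module itself) and
    # raises eagerly instead of relying on the stripped dynamo guards.
    assert input_ids.ndim == 1, "expected flattened token ids"
    assert positions.shape[0] == input_ids.shape[0]


@support_torch_compile(shape_invariants=_token_shape_invariants)
class ToyModel(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.embed = nn.Embedding(32000, 64)

    def forward(self, input_ids: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        return self.embed(input_ids) + positions.unsqueeze(-1).float()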

vllm/compilation/wrapper.py

Lines changed: 10 additions & 1 deletion
@@ -22,6 +22,11 @@ class TorchCompileGuardsStripWrapper:
     (Since we drop all guards)
     """

+    def check_invariantes_and_forward(self, *args, **kwargs):
+        self._check_shape_invariants(*args, **kwargs)
+
+        return self.forward(*args, **kwargs)
+
     def __init__(self):
         self.compiled = False

@@ -42,7 +47,7 @@ def __init__(self):
         options["guard_filter_fn"] = lambda x: [False for _ in x]

         self._compiled_callable = torch.compile(
-            self.forward,
+            self.check_invariantes_and_forward,
             fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
             backend=backend,
             options=options,
@@ -54,9 +59,13 @@ def __call__(self, *args, **kwargs):
         method, for directly dispatching to the compiled code.
         """
         if not self.compiled:
+            # We check eagerly on the first compile as well.
+            self.check_invariantes_and_forward(*args, **kwargs)
+
             # Make sure a compilation is triggered by clearing dynamo cache.
             torch._dynamo.eval_frame.remove_from_cache(
                 self.original_code_object())
+
             self.compiled = True

         # Disable the C++ compilation of symbolic shape guards. C++-fication
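
Since the wrapper keeps stripping every dynamo guard via guard_filter_fn, the explicit call inside check_invariantes_and_forward becomes the only runtime shape check. Below is a standalone sketch of the same pattern outside vLLM, assuming a recent PyTorch where torch.compile accepts guard_filter_fn in options (as the context lines above already do); the function and shapes are illustrative.

# Standalone illustration of "strip all guards, check invariants explicitly".
# Assumes a PyTorch version that supports the guard_filter_fn compile option.
import torch


def check_invariants(x: torch.Tensor) -> None:
    assert x.ndim == 2, "expected a 2-D input"


def forward(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x @ x.T)


def check_invariants_and_forward(x: torch.Tensor) -> torch.Tensor:
    check_invariants(x)
    return forward(x)


compiled = torch.compile(
    check_invariants_and_forward,
    # Drop every guard dynamo would otherwise install, mirroring
    # TorchCompileGuardsStripWrapper.
    options={"guard_filter_fn": lambda guards: [False for _ in guards]},
)

x = torch.randn(8, 16)
torch._dynamo.mark_dynamic(x, 0)  # the batch-like dim is dynamic
print(compiled(x).shape)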

vllm/config/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1859,8 +1859,8 @@ class PoolerConfig:
     """
     Maximum input length allowed for embedding generation. When set, allows
     inputs longer than max_embed_len to be accepted for embedding models.
-    When an input exceeds max_embed_len, it will be handled according to
-    the original max_model_len validation logic.
+    When an input exceeds max_embed_len, it will be handled according to
+    the original max_model_len validation logic.
     Defaults to None (i.e. set to max_model_len).
     """
