
Commit bcc0f99

remove the bytecode hook and replace TorchCompileWrapperWithCustomDispatcher with TorchCompileGuardsStripWrapper
Signed-off-by: Laith Sakka <lsakka@meta.com>
1 parent 8b32464 commit bcc0f99

File tree

4 files changed: +217 -245 lines


tests/compile/test_wrapper.py

Lines changed: 87 additions & 40 deletions
@@ -5,60 +5,107 @@
 
 import torch
 
-from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import CompilationLevel
+from vllm.compilation.wrapper import TorchCompileGuardsStripWrapper
+from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
+                         set_current_vllm_config)
 
 
 class MyMod(torch.nn.Module):
 
     def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
-        if cache is not None:
-            return x + cache
-        return x * 2
+        if x.size()[0] >= 4:
+            return x * 2
+        else:
+            return x * 100
 
 
-class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
+class MyWrapper(TorchCompileGuardsStripWrapper):
 
     def __init__(self, model):
         self.model = model
-        compiled_callable = torch.compile(self.forward, backend="eager")
-        super().__init__(compiled_callable,
-                         compilation_level=CompilationLevel.DYNAMO_ONCE)
+        super().__init__()
 
-    def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
+    def forward(self, x: torch.Tensor):  # type: ignore[override]
         # this is the function to be compiled
-        return self.model(x, cache)
-
-    def __call__(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
-        # let torch.compile compile twice
-        if len(self.compiled_codes) == 2:
-            dispatch_id = 0 if cache is None else 1
-            with self.dispatch_to_code(dispatch_id):
-                return self.forward(x, cache)
-        else:
-            return self.compiled_callable(x, cache)
+        return self.model(x)
 
 
 def test_torch_compile_wrapper():
-    mod = MyMod()
-    wrappers = []
-    for i in range(3):
+    """Test basic functionality of TorchCompileGuardsStripWrapper."""
+    # Create a proper vLLM config instead of mocking
+    vllm_config = VllmConfig()
+    vllm_config.compilation_config = CompilationConfig()
+    vllm_config.compilation_config.level = CompilationLevel.DYNAMO_ONCE
+    vllm_config.compilation_config.backend = "inductor"
+
+    with set_current_vllm_config(vllm_config):
         torch._dynamo.reset()
+        mod = MyMod()
         wrapper = MyWrapper(mod)
-        wrappers.append(wrapper)
-    x = torch.tensor([1])
-    wrapper(x, None)  # profile run, compile
-    # create a cache tensor
-    cache = torch.tensor([2])
-    wrapper(x, cache)  # warm up with cache, recompile
-
-    # for new input, dispatch to the compiled code directly
-    new_x = torch.tensor([3])
-    assert wrapper(new_x,
-                   None).item() == 6  # dispatch to the first compiled code
-    assert wrapper(
-        new_x, cache).item() == 5  # dispatch to the second compiled code
-
-    for wrapper in wrappers:
-        # make sure they have independent compiled codes
-        assert len(wrapper.compiled_codes) == 2
+
+        # First call should trigger compilation
+        x = torch.tensor([1, 2, 3, 4])
+        torch._dynamo.mark_dynamic(x, 0)
+
+        result1 = wrapper(x)
+        expected1 = torch.tensor([2, 4, 6, 8])
+        assert torch.allclose(
+            result1, expected1), f"Expected {expected1}, got {result1}"
+
+        # Second call should use the compiled code
+        x2 = torch.tensor([1, 2, 3])
+        result2 = wrapper(x2)
+        expected2 = torch.tensor([2, 4, 6])
+        assert torch.allclose(
+            result2, expected2), f"Expected {expected2}, got {result2}"
+
+        # Without the wrapper, the result would be different.
+        result3 = mod(x2)
+        expected3 = torch.tensor([100, 200, 300])
+
+        assert torch.allclose(
+            result3, expected3), f"Expected {expected3}, got {result3}"
+
+        # Verify compilation flag is set
+        assert wrapper.compiled is True
+
+    # With DYNAMO_AS_IS we do not remove guards.
+    vllm_config.compilation_config.level = CompilationLevel.DYNAMO_AS_IS
+    torch._dynamo.reset()
+    with set_current_vllm_config(vllm_config):
+        torch._dynamo.reset()
+        mod = MyMod()
+        wrapper = MyWrapper(mod)
+
+        # First call should trigger compilation
+        x = torch.tensor([1, 2, 3, 4])
+        torch._dynamo.mark_dynamic(x, 0)
+
+        result1 = wrapper(x)
+        expected1 = torch.tensor([2, 4, 6, 8])
+        assert torch.allclose(
+            result1, expected1), f"Expected {expected1}, got {result1}"
+
+        # Second call should trigger another compilation
+        x2 = torch.tensor([1, 2, 3])
+        result2 = wrapper(x2)
+        expected2 = torch.tensor([100, 200, 300])
+        assert torch.allclose(
+            result2, expected2), f"Expected {expected2}, got {result2}"
+
+    # NO_COMPILATION level is not supported.
+    vllm_config.compilation_config.level = CompilationLevel.NO_COMPILATION
+    torch._dynamo.reset()
+    with set_current_vllm_config(vllm_config):
+        torch._dynamo.reset()
+        mod = MyMod()
+
+        try:
+            wrapper = MyWrapper(mod)
+        except Exception:
+            return
+        raise AssertionError("expected an exception to be raised")
+
+
+if __name__ == "__main__":
+    test_torch_compile_wrapper()
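
For context on what the updated test asserts: with plain torch.compile and no guard stripping, the shape guard created by the `x.size()[0] >= 4` branch forces a recompile for the smaller input, so the second call takes the `x * 100` branch. That is exactly what the DYNAMO_AS_IS portion of the test checks, and what the DYNAMO_ONCE wrapper avoids by reusing the first compiled code. A minimal standalone sketch of that baseline behaviour (not part of this commit; it uses only stock PyTorch APIs):

import torch


class MyMod(torch.nn.Module):

    def forward(self, x: torch.Tensor):
        # same shape-dependent branch as the test module above
        if x.size()[0] >= 4:
            return x * 2
        return x * 100


compiled = torch.compile(MyMod(), backend="eager")

x = torch.tensor([1, 2, 3, 4])
torch._dynamo.mark_dynamic(x, 0)
# first call: compiled with a guard requiring x.size(0) >= 4
assert torch.equal(compiled(x), torch.tensor([2, 4, 6, 8]))

x2 = torch.tensor([1, 2, 3])
# the guard fails, Dynamo recompiles, and the other branch is taken
assert torch.equal(compiled(x2), torch.tensor([100, 200, 300]))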

vllm/compilation/decorators.py

Lines changed: 67 additions & 91 deletions
@@ -10,7 +10,7 @@
 from torch._dynamo.symbolic_convert import InliningInstructionTranslator
 
 from vllm.compilation.counter import compilation_counter
-from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
+from vllm.compilation.wrapper import TorchCompileGuardsStripWrapper
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
@@ -32,11 +32,11 @@ def ignore_torch_compile(cls: _T) -> _T:
     a support_torch_compile decorator, but we don't want to
     compile the class `cls` that inherits the parent class.
     This only ignores compiling the forward of the class the
-    decorator is applied to. 
+    decorator is applied to.
 
     If the parent has ignore_torch_compile but the child has
     support_torch_compile, the child will still be compiled.
-    
+
     If the class has one or more submodules
     that have support_torch_compile decorator applied, compile will
     not be ignored for those submodules.
@@ -182,14 +182,14 @@ def _support_torch_compile(
     """
     A decorator to add support for compiling the forward method of a class.
     """
-    if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
+    if TorchCompileGuardsStripWrapper in cls.__bases__:
         # support decorating multiple times
         return cls
 
     # take care of method resolution order
     # make sure super().__init__ is called on the base class
-    # other than TorchCompileWrapperWithCustomDispatcher
-    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
+    # other than TorchCompileGuardsStripWrapper
+    cls.__bases__ = cls.__bases__ + (TorchCompileGuardsStripWrapper, )
 
     old_init = cls.__init__
 
@@ -210,107 +210,83 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
             return
 
         compilation_counter.num_models_seen += 1
-        TorchCompileWrapperWithCustomDispatcher.__init__(
-            self, compilation_level=vllm_config.compilation_config.level)
+        TorchCompileGuardsStripWrapper.__init__(self)
 
     cls.__init__ = __init__
 
+    def _mark_dynamic_inputs(mod, *args, **kwargs):
+        sig = inspect.signature(mod.__class__.forward)
+        bound_args = sig.bind(mod, *args, **kwargs)
+        bound_args.apply_defaults()
+        for k, dims in dynamic_arg_dims.items():
+            arg = bound_args.arguments.get(k)
+            if arg is not None:
+                dims = [dims] if isinstance(dims, int) else dims
+                if isinstance(arg, torch.Tensor):
+                    # In case dims is specified with negative indexing
+                    dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
+                    torch._dynamo.mark_dynamic(arg, dims)
+                elif isinstance(arg, IntermediateTensors):
+                    for tensor in arg.tensors.values():
+                        # In case dims is specified with negative indexing
+                        dims = [
+                            tensor.ndim + dim if dim < 0 else dim
+                            for dim in dims
+                        ]
+                        torch._dynamo.mark_dynamic(tensor, dims)
+                else:
+                    raise ValueError(
+                        "Unsupported dynamic dimensions"
+                        f" {dims} for argument {k} with type {type(arg)}.")
+
     def __call__(self, *args, **kwargs):
         # torch.compiler.is_compiling() means we are inside the compilation
         # e.g. TPU has the compilation logic in model runner, so we don't
         # need to compile the model inside.
         if self.do_not_compile or torch.compiler.is_compiling():
             return self.forward(*args, **kwargs)
 
+        # This attribute is added by TorchCompileGuardsStripWrapper
+        if self.compiled:
+            return TorchCompileGuardsStripWrapper.__call__(
+                self, *args, **kwargs)
+
+        # This is the path for the first compilation.
+        _mark_dynamic_inputs(self, *args, **kwargs)
+
         # the first compilation needs to have dynamic shapes marked
-        if len(self.compiled_codes) < 1:
-            sig = inspect.signature(self.__class__.forward)
-            bound_args = sig.bind(self, *args, **kwargs)
-            bound_args.apply_defaults()
-            for k, dims in dynamic_arg_dims.items():
-                arg = bound_args.arguments.get(k)
-                if arg is not None:
-                    dims = [dims] if isinstance(dims, int) else dims
-                    if isinstance(arg, torch.Tensor):
-                        # In case dims is specified with negative indexing
-                        dims = [
-                            arg.ndim + dim if dim < 0 else dim for dim in dims
-                        ]
-                        torch._dynamo.mark_dynamic(arg, dims)
-                    elif isinstance(arg, IntermediateTensors):
-                        for tensor in arg.tensors.values():
-                            # In case dims is specified with negative indexing
-                            dims = [
-                                tensor.ndim + dim if dim < 0 else dim
-                                for dim in dims
-                            ]
-                            torch._dynamo.mark_dynamic(tensor, dims)
-                    else:
-                        raise ValueError(
-                            "Unsupported dynamic dimensions"
-                            f" {dims} for argument {k} with type {type(arg)}.")
-            # here, it is the starting point of the `torch.compile` process
-            start_monitoring_torch_compile(self.vllm_config)
-            logger.debug("Start compiling function %s",
-                         self.original_code_object)
+        start_monitoring_torch_compile(self.vllm_config)
+        logger.debug("Start compiling function %s",
+                     self.original_code_object())
 
         # if we don't use custom dispatcher, we can directly call the
         # compiled function and let torch.compile handle the dispatching,
         # with the overhead of guard evaluation and recompilation.
-        if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
-            # it seems Dynamo reuse the compilation across instances,
-            # while we need to make sure the compiled code is not reused.
-            # we need to control all the compilation of the model.
-            torch._dynamo.eval_frame.remove_from_cache(
-                self.original_code_object)
-
-            # collect all relevant files traced by Dynamo,
-            # so that the compilation cache can trigger re-compilation
-            # properly when any of these files change.
-
-            # 1. the file containing the top-level forward function
+
+        # collect all relevant files traced by Dynamo,
+        # so that the compilation cache can trigger re-compilation
+        # properly when any of these files change.
+
+        # 1. the file containing the top-level forward function
+        self.vllm_config.compilation_config.traced_files.add(
+            self.original_code_object().co_filename)
+
+        # 2. every time Dynamo sees a function call, it will inline
+        # the function by calling InliningInstructionTranslator.inline_call
+        # we hijack this function to know all the functions called
+        # during Dynamo tracing, and their corresponding files
+        inline_call = InliningInstructionTranslator.inline_call
+
+        def patched_inline_call(parent, func, args, kwargs):
+            code = func.get_code()
             self.vllm_config.compilation_config.traced_files.add(
-                self.original_code_object.co_filename)
-
-            # 2. every time Dynamo sees a function call, it will inline
-            # the function by calling InliningInstructionTranslator.inline_call
-            # we hijack this function to know all the functions called
-            # during Dynamo tracing, and their corresponding files
-            inline_call = InliningInstructionTranslator.inline_call
-
-            def patched_inline_call(parent, func, args, kwargs):
-                code = func.get_code()
-                self.vllm_config.compilation_config.traced_files.add(
-                    code.co_filename)
-                return inline_call(parent, func, args, kwargs)
-
-            # Disable the C++ compilation of symbolic shape guards. C++-fication
-            # of symbolic shape guards can improve guard overhead. But, since
-            # vllm skip guards anyways, setting this flag to False can improve
-            # compile time.
-            dynamo_config_patches = {}
-            try:
-                _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
-                dynamo_config_patches[
-                    "enable_cpp_symbolic_shape_guards"] = False
-            except AttributeError:
-                # Note: this config is not available in torch 2.6, we can skip
-                # if the config doesn't exist
-                logger.debug(
-                    "enable_cpp_symbolic_shape_guards config not available")
-
-            with patch.object(InliningInstructionTranslator, 'inline_call',
-                              patched_inline_call), torch._dynamo.config.patch(
-                                  **dynamo_config_patches):
-                output = self.compiled_callable(*args, **kwargs)
-            return output
-
-        # usually, capturing the model once is enough, and then we can
-        # dispatch to the compiled code directly, without going through
-        # the Dynamo guard mechanism.
-        with self.dispatch_to_code(0):
-            model_output = self.forward(*args, **kwargs)
-        return model_output
+                code.co_filename)
+            return inline_call(parent, func, args, kwargs)
+
+        with patch.object(InliningInstructionTranslator, "inline_call",
+                          patched_inline_call):
+            return TorchCompileGuardsStripWrapper.__call__(
+                self, *args, **kwargs)
 
     cls.__call__ = __call__
     return cls
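
The `patched_inline_call` hook above is what records every source file Dynamo touches while tracing, so the compilation cache can be invalidated when any traced file changes. A standalone sketch of the same trick follows; it is an illustration only, not code from this commit, and it relies on the same private `InliningInstructionTranslator` internals the decorator uses, which may change across torch versions.

from unittest.mock import patch

import torch
from torch._dynamo.symbolic_convert import InliningInstructionTranslator

traced_files: set[str] = set()
original_inline_call = InliningInstructionTranslator.inline_call


def recording_inline_call(parent, func, args, kwargs):
    # record the source file of every function Dynamo inlines
    traced_files.add(func.get_code().co_filename)
    return original_inline_call(parent, func, args, kwargs)


def helper(x):
    return x + 1


def fn(x):
    return helper(x) * 2


with patch.object(InliningInstructionTranslator, "inline_call",
                  recording_inline_call):
    torch.compile(fn, backend="eager")(torch.randn(4))

# traced_files now contains the file where `helper` is defined
print(sorted(traced_files))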
