
Commit b31059f

remove the bytecode hook and replace TorchCompileWrapperWithCustomDispatcher with TorchCompileGuardsStripWrapper

Signed-off-by: Laith Sakka <lsakka@meta.com>

1 parent: 250fb1b

4 files changed: +241, -278 lines

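For orientation before the diffs: with the new wrapper, a model class simply subclasses TorchCompileGuardsStripWrapper, implements forward, and calls the instance; the old compiled_codes / dispatch_to_code bookkeeping from the bytecode-hook approach is gone. A minimal usage sketch distilled from the updated test below (it assumes a vLLM install and the exact config attribute names the test uses):

import torch

from vllm.compilation.wrapper import TorchCompileGuardsStripWrapper
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    VllmConfig,
    set_current_vllm_config,
)


class Wrapped(TorchCompileGuardsStripWrapper):
    def __init__(self, model):
        self.model = model
        super().__init__()  # no compiled callable is passed in anymore

    def forward(self, x: torch.Tensor):
        # this is the function the wrapper compiles
        return self.model(x)


cfg = VllmConfig()
cfg.compilation_config = CompilationConfig()
cfg.compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
cfg.compilation_config.backend = "inductor"

with set_current_vllm_config(cfg):
    wrapper = Wrapped(torch.nn.Identity())
    out = wrapper(torch.tensor([1.0, 2.0, 3.0, 4.0]))  # first call triggers compilation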

tests/compile/test_wrapper.py

Lines changed: 94 additions & 39 deletions
@@ -4,57 +4,112 @@
 
 import torch
 
-from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import CompilationMode
+from vllm.compilation.wrapper import TorchCompileGuardsStripWrapper
+from vllm.config import (
+    CompilationConfig,
+    CompilationMode,
+    VllmConfig,
+    set_current_vllm_config,
+)
 
 
 class MyMod(torch.nn.Module):
     def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None):
-        if cache is not None:
-            return x + cache
-        return x * 2
+        if x.size()[0] >= 4:
+            return x * 2
+        else:
+            return x * 100
 
 
-class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
+class MyWrapper(TorchCompileGuardsStripWrapper):
     def __init__(self, model):
         self.model = model
-        compiled_callable = torch.compile(self.forward, backend="eager")
-        super().__init__(
-            compiled_callable, compilation_mode=CompilationMode.DYNAMO_TRACE_ONCE
-        )
+        super().__init__()
 
-    def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None):
+    def forward(self, x: torch.Tensor):  # type: ignore[override]
         # this is the function to be compiled
-        return self.model(x, cache)
-
-    def __call__(self, x: torch.Tensor, cache: torch.Tensor | None = None):
-        # let torch.compile compile twice
-        if len(self.compiled_codes) == 2:
-            dispatch_id = 0 if cache is None else 1
-            with self.dispatch_to_code(dispatch_id):
-                return self.forward(x, cache)
-        else:
-            return self.compiled_callable(x, cache)
+        return self.model(x)
 
 
 def test_torch_compile_wrapper():
-    mod = MyMod()
-    wrappers = []
-    for i in range(3):
+    """Test basic functionality of TorchCompileGuardsStripWrapper."""
+    # Create a proper vLLM config instead of mocking
+    vllm_config = VllmConfig()
+    vllm_config.compilation_config = CompilationConfig()
+    vllm_config.compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
+    vllm_config.compilation_config.backend = "inductor"
+
+    # Test DYNAMO_TRACE_ONCE
+    with set_current_vllm_config(vllm_config):
+        torch._dynamo.reset()
+        mod = MyMod()
+        wrapper = MyWrapper(mod)
+
+        # First call should trigger compilation
+        x = torch.tensor([1, 2, 3, 4])
+        torch._dynamo.mark_dynamic(x, 0)
+
+        result1 = wrapper(x)
+        expected1 = torch.tensor([2, 4, 6, 8])
+        assert torch.allclose(result1, expected1), (
+            f"Expected {expected1}, got {result1}"
+        )
+
+        # Second call should use compiled code
+        x2 = torch.tensor([1, 2, 3])
+        result2 = wrapper(x2)
+        expected2 = torch.tensor([2, 4, 6])
+        assert torch.allclose(result2, expected2), (
+            f"Expected {expected2}, got {result2}"
+        )
+
+        # without the wrapper, the result would be different.
+        result3 = mod(x2)
+        expected3 = torch.tensor([100, 200, 300])
+
+        assert torch.allclose(result3, expected3), (
+            f"Expected {expected3}, got {result3}"
+        )
+
+    # with STOCK_TORCH_COMPILE we do not remove guards.
+    vllm_config.compilation_config.mode = CompilationMode.STOCK_TORCH_COMPILE
+    torch._dynamo.reset()
+    with set_current_vllm_config(vllm_config):
         torch._dynamo.reset()
+        mod = MyMod()
         wrapper = MyWrapper(mod)
-        wrappers.append(wrapper)
-        x = torch.tensor([1])
-        wrapper(x, None)  # profile run, compile
-        # create a cache tensor
-        cache = torch.tensor([2])
-        wrapper(x, cache)  # warm up with cache, recompile
-
-        # for new input, dispatch to the compiled code directly
-        new_x = torch.tensor([3])
-        assert wrapper(new_x, None).item() == 6  # dispatch to the first compiled code
-        assert wrapper(new_x, cache).item() == 5  # dispatch to the second compiled code
-
-    for wrapper in wrappers:
-        # make sure they have independent compiled codes
-        assert len(wrapper.compiled_codes) == 2
+
+        # First call should trigger compilation
+        x = torch.tensor([1, 2, 3, 4])
+        torch._dynamo.mark_dynamic(x, 0)
+
+        result1 = wrapper(x)
+        expected1 = torch.tensor([2, 4, 6, 8])
+        assert torch.allclose(result1, expected1), (
+            f"Expected {expected1}, got {result1}"
+        )
+
+        # Second call should trigger another compilation
+        x2 = torch.tensor([1, 2, 3])
+        result2 = wrapper(x2)
+        expected2 = torch.tensor([100, 200, 300])
+        assert torch.allclose(result2, expected2), (
+            f"Expected {expected2}, got {result2}"
+        )
+
+    # NO_COMPILATION level not supported.
+    vllm_config.compilation_config.mode = None
+    torch._dynamo.reset()
+    with set_current_vllm_config(vllm_config):
+        torch._dynamo.reset()
+        mod = MyMod()
+
+        try:
+            wrapper = MyWrapper(mod)
+        except Exception:
+            return
+        raise AssertionError("expected an exception to be raised")
+
+
+if __name__ == "__main__":
+    test_torch_compile_wrapper()
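The STOCK_TORCH_COMPILE half of the test above is essentially checking stock torch.compile guard behavior, which can be reproduced without vLLM at all. A minimal stand-alone sketch (same toy branch as MyMod, eager backend, no wrapper), shown only to make the contrast with the guard-stripping path explicit:

import torch


def f(x: torch.Tensor):
    if x.size()[0] >= 4:
        return x * 2
    return x * 100


compiled = torch.compile(f, backend="eager")

x = torch.tensor([1, 2, 3, 4])
torch._dynamo.mark_dynamic(x, 0)
print(compiled(x))  # tensor([2, 4, 6, 8]): traced with the size >= 4 branch
print(compiled(torch.tensor([1, 2, 3])))  # tensor([100, 200, 300]): the shape guard fails, Dynamo recompiles

Under DYNAMO_TRACE_ONCE the wrapper strips that shape guard, which is why the first half of the test expects the size-3 input to keep following the x * 2 path.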

vllm/compilation/decorators.py

Lines changed: 100 additions & 101 deletions
@@ -17,7 +17,7 @@
 
 import vllm.envs as envs
 from vllm.compilation.counter import compilation_counter
-from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
+from vllm.compilation.wrapper import TorchCompileGuardsStripWrapper
 from vllm.config import CompilationMode, VllmConfig, set_current_vllm_config
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
@@ -217,14 +217,14 @@ def _support_torch_compile(
     """
     A decorator to add support for compiling the forward method of a class.
     """
-    if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
+    if TorchCompileGuardsStripWrapper in cls.__bases__:
         # support decorating multiple times
         return cls
 
     # take care of method resolution order
     # make sure super().__init__ is called on the base class
-    # other than TorchCompileWrapperWithCustomDispatcher
-    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher,)
+    # other than TorchCompileGuardsStripWrapper
+    cls.__bases__ = cls.__bases__ + (TorchCompileGuardsStripWrapper,)
 
     old_init = cls.__init__
 
@@ -247,19 +247,42 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
             return
 
         compilation_counter.num_models_seen += 1
-        TorchCompileWrapperWithCustomDispatcher.__init__(
-            self, compilation_mode=vllm_config.compilation_config.mode
-        )
+        self.compiled = False
+        TorchCompileGuardsStripWrapper.__init__(self)
 
     cls.__init__ = __init__
 
+    def _mark_dynamic_inputs(mod, *args, **kwargs):
+        sig = inspect.signature(mod.__class__.forward)
+        bound_args = sig.bind(mod, *args, **kwargs)
+        bound_args.apply_defaults()
+        for k, dims in dynamic_arg_dims.items():
+            arg = bound_args.arguments.get(k)
+            if arg is not None:
+                dims = [dims] if isinstance(dims, int) else dims
+                if isinstance(arg, torch.Tensor):
+                    # In case dims is specified with negative indexing
+                    dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
+                    torch._dynamo.mark_dynamic(arg, dims)
+                elif isinstance(arg, IntermediateTensors):
+                    for tensor in arg.tensors.values():
+                        # In case dims is specified with negative indexing
+                        dims = [tensor.ndim + dim if dim < 0 else dim for dim in dims]
+                        torch._dynamo.mark_dynamic(tensor, dims)
+                else:
+                    raise ValueError(
+                        "Unsupported dynamic dimensions"
+                        f" {dims} for argument {k} with type {type(arg)}."
+                    )
+
     def __call__(self, *args, **kwargs):
         # torch.compiler.is_compiling() means we are inside the compilation
         # e.g. TPU has the compilation logic in model runner, so we don't
         # need to compile the model inside.
         if self.do_not_compile or torch.compiler.is_compiling():
             return self.forward(*args, **kwargs)
 
+        # if aot_compiled_fn is set, just call it.
         if getattr(self, "aot_compiled_fn", None) is not None:
             return self.aot_compiled_fn(self, *args, **kwargs)
 
@@ -318,102 +341,78 @@ def __call__(self, *args, **kwargs):
                 )
                 return self.aot_compiled_fn(self, *args, **kwargs)
 
+        if self.compiled:
+            assert not envs.VLLM_USE_AOT_COMPILE
+            return TorchCompileGuardsStripWrapper.__call__(self, *args, **kwargs)
+
+        # This is the path for the first compilation.
+
         # the first compilation needs to have dynamic shapes marked
-        if len(self.compiled_codes) < 1:
-            sig = inspect.signature(self.__class__.forward)
-            bound_args = sig.bind(self, *args, **kwargs)
-            bound_args.apply_defaults()
-            for k, dims in dynamic_arg_dims.items():
-                arg = bound_args.arguments.get(k)
-                if arg is not None:
-                    dims = [dims] if isinstance(dims, int) else dims
-                    if isinstance(arg, torch.Tensor):
-                        # In case dims is specified with negative indexing
-                        dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
-                        torch._dynamo.mark_dynamic(arg, dims)
-                    elif isinstance(arg, IntermediateTensors):
-                        for tensor in arg.tensors.values():
-                            # In case dims is specified with negative indexing
-                            dims = [
-                                tensor.ndim + dim if dim < 0 else dim for dim in dims
-                            ]
-                            torch._dynamo.mark_dynamic(tensor, dims)
-                    else:
-                        raise ValueError(
-                            "Unsupported dynamic dimensions"
-                            f" {dims} for argument {k} with type {type(arg)}."
-                        )
-            # here, it is the starting point of the `torch.compile` process
-            start_monitoring_torch_compile(self.vllm_config)
-            logger.debug("Start compiling function %s", self.original_code_object)
-
-        # if we don't use custom dispatcher, we can directly call the
-        # compiled function and let torch.compile handle the dispatching,
-        # with the overhead of guard evaluation and recompilation.
-        if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
-            # it seems Dynamo reuse the compilation across instances,
-            # while we need to make sure the compiled code is not reused.
-            # we need to control all the compilation of the model.
-            torch._dynamo.eval_frame.remove_from_cache(self.original_code_object)
-
-            # collect all relevant files traced by Dynamo,
-            # so that the compilation cache can trigger re-compilation
-            # properly when any of these files change.
-
-            # 1. the file containing the top-level forward function
-            self.vllm_config.compilation_config.traced_files.add(
-                self.original_code_object.co_filename
-            )
+        _mark_dynamic_inputs(self, *args, **kwargs)
 
-            # 2. every time Dynamo sees a function call, it will inline
-            # the function by calling InliningInstructionTranslator.inline_call_
-            # we hijack this function to know all the functions called
-            # during Dynamo tracing, and their corresponding files
-            inline_call = InliningInstructionTranslator.inline_call_
-
-            def patched_inline_call(self_):
-                code = self_.f_code
-                self.vllm_config.compilation_config.traced_files.add(code.co_filename)
-                return inline_call(self_)
-
-            # Disable the C++ compilation of symbolic shape guards. C++-fication
-            # of symbolic shape guards can improve guard overhead. But, since
-            # vllm skip guards anyways, setting this flag to False can improve
-            # compile time.
-            dynamo_config_patches = {}
-            try:
-                _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
-                dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False
-            except AttributeError:
-                # Note: this config is not available in torch 2.6, we can skip
-                # if the config doesn't exist
-                logger.debug("enable_cpp_symbolic_shape_guards config not available")
-
-            with (
-                patch.object(
-                    InliningInstructionTranslator, "inline_call_", patched_inline_call
-                ),
-                torch._dynamo.config.patch(**dynamo_config_patches),
-                maybe_use_cudagraph_partition_wrapper(self.vllm_config),
-                _torch27_patch_tensor_subclasses(),
-            ):
-                if envs.VLLM_USE_AOT_COMPILE:
-                    self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
-                    output = self.aot_compiled_fn(self, *args, **kwargs)
-                    assert aot_compilation_path is not None
-                    assert cache_dir is not None
-                    os.makedirs(cache_dir, exist_ok=True)
-                    self.aot_compiled_fn.save_compiled_function(aot_compilation_path)
-                else:
-                    output = self.compiled_callable(*args, **kwargs)
-            return output
-
-        # usually, capturing the model once is enough, and then we can
-        # dispatch to the compiled code directly, without going through
-        # the Dynamo guard mechanism.
-        with self.dispatch_to_code(0):
-            model_output = self.forward(*args, **kwargs)
-        return model_output
+        # here, it is the starting point of the `torch.compile` process
+        start_monitoring_torch_compile(self.vllm_config)
+        original_code_object = self.original_code_object()
+        logger.debug("Start compiling function %s", original_code_object)
+
+        # it seems Dynamo reuse the compilation across instances,
+        # while we need to make sure the compiled code is not reused.
+        # we need to control all the compilation of the model.
+        torch._dynamo.eval_frame.remove_from_cache(original_code_object)
+
+        # collect all relevant files traced by Dynamo,
+        # so that the compilation cache can trigger re-compilation
+        # properly when any of these files change.
+
+        # 1. the file containing the top-level forward function
+        self.vllm_config.compilation_config.traced_files.add(
+            original_code_object.co_filename
+        )
+
+        # 2. every time Dynamo sees a function call, it will inline
+        # the function by calling InliningInstructionTranslator.inline_call_
+        # we hijack this function to know all the functions called
+        # during Dynamo tracing, and their corresponding files
+        inline_call = InliningInstructionTranslator.inline_call_
+
+        def patched_inline_call(self_):
+            code = self_.f_code
+            self.vllm_config.compilation_config.traced_files.add(code.co_filename)
+            return inline_call(self_)
+
+        # Disable the C++ compilation of symbolic shape guards. C++-fication
+        # of symbolic shape guards can improve guard overhead. But, since
+        # vllm skip guards anyways, setting this flag to False can improve
+        # compile time.
+        dynamo_config_patches = {}
+        try:
+            _ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
+            dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False
+        except AttributeError:
+            # Note: this config is not available in torch 2.6, we can skip
+            # if the config doesn't exist
+            logger.debug("enable_cpp_symbolic_shape_guards config not available")
+
+        with (
+            patch.object(
+                InliningInstructionTranslator, "inline_call_", patched_inline_call
+            ),
+            torch._dynamo.config.patch(**dynamo_config_patches),
+            maybe_use_cudagraph_partition_wrapper(self.vllm_config),
+            _torch27_patch_tensor_subclasses(),
+        ):
+            if envs.VLLM_USE_AOT_COMPILE:
+                self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
+                output = self.aot_compiled_fn(self, *args, **kwargs)
+                assert aot_compilation_path is not None
+                assert cache_dir is not None
+                os.makedirs(cache_dir, exist_ok=True)
+                self.aot_compiled_fn.save_compiled_function(aot_compilation_path)
+            else:
+                output = TorchCompileGuardsStripWrapper.__call__(self, *args, **kwargs)
+
+        self.compiled = True
+        return output
 
     cls.__call__ = __call__
     return cls
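As a side note on the extracted _mark_dynamic_inputs helper above: the only non-obvious step is how it normalizes negative dim indices before handing them to torch._dynamo.mark_dynamic. A small stand-alone sketch of just that normalization (the tensor shape and dim values here are made up for illustration):

import torch


def normalize_dims(t: torch.Tensor, dims):
    # accept a single int or a list, and resolve negative indices the same
    # way the helper does (t.ndim + dim for dim < 0)
    dims = [dims] if isinstance(dims, int) else dims
    return [t.ndim + d if d < 0 else d for d in dims]


hidden = torch.randn(8, 16, 32)
print(normalize_dims(hidden, 0))        # [0]
print(normalize_dims(hidden, -1))       # [2]
print(normalize_dims(hidden, [0, -2]))  # [0, 1]

# the helper then marks those dims as dynamic so the first trace is shape-generic
torch._dynamo.mark_dynamic(hidden, normalize_dims(hidden, [0, -2]))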
