
Commit 3500329

remove the bytecode hook and replace TorchCompileWrapperWithCustomDispatcher with TorchCompileGuardsStripWrapper
Signed-off-by: Laith Sakka <lsakka@meta.com>
1 parent c4768dc commit 3500329

8 files changed: 343 additions & 296 deletions


tests/compile/piecewise/test_multiple_graphs.py

Lines changed: 3 additions & 0 deletions
@@ -22,6 +22,8 @@
 from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 
+from ...utils import create_new_process_for_each_test
+
 # This import automatically registers `torch.ops.silly.attention`
 from .. import silly_attention  # noqa: F401
 
@@ -193,6 +195,7 @@ def run_model(
 
 
 @pytest.mark.parametrize("use_inductor_graph_partition", [False, True])
+@create_new_process_for_each_test("spawn")
 def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
     if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
         pytest.skip("inductor graph partition is only available in PyTorch 2.9+")

tests/compile/piecewise/test_simple.py

Lines changed: 3 additions & 0 deletions
@@ -21,6 +21,8 @@
 from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 
+from ...utils import create_new_process_for_each_test
+
 # This import automatically registers `torch.ops.silly.attention`
 from ..silly_attention import get_global_counter, reset_global_counter
 
@@ -125,6 +127,7 @@ def _run_simple_model(
 
 @pytest.mark.parametrize("use_inductor", [True, False])
 @torch.inference_mode()
+@create_new_process_for_each_test("spawn")
 def test_simple_piecewise_compile(use_inductor):
     _run_simple_model(
         splitting_ops=["silly::attention"],

tests/compile/piecewise/test_toy_llama.py

Lines changed: 8 additions & 1 deletion
@@ -29,6 +29,8 @@
 from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 
+from ...utils import create_new_process_for_each_test
+
 # This import automatically registers `torch.ops.silly.attention`
 from .. import silly_attention  # noqa: F401
 
@@ -334,6 +336,7 @@ def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor:
         ("inductor", True),  # Inductor, Inductor partition
     ],
 )
+@create_new_process_for_each_test("spawn")
 def test_toy_llama(
     backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path
 ):
@@ -514,4 +517,8 @@ def benchmark():
 
 
 if __name__ == "__main__":
-    benchmark()
+    # Protect against subprocess reimport when using spawn_new_process_for_each_test
+    import os
+
+    if os.environ.get("RUNNING_IN_SUBPROCESS") != "1":
+        benchmark()
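
A note on the new `__main__` guard above: create_new_process_for_each_test("spawn") runs the decorated test in a freshly spawned interpreter, and the "spawn" start method re-imports the test module in the child, so unguarded module-level work such as benchmark() would run again there (the RUNNING_IN_SUBPROCESS variable checked in the diff is assumed to be set by the test harness). A minimal, self-contained sketch of the same hazard using plain multiprocessing rather than the vLLM helper:

import multiprocessing as mp
import os

# With the "spawn" start method the child process re-imports this module,
# so this line prints once in the parent and once in each spawned child.
print(f"module imported in pid {os.getpid()}")


def child_work():
    pass


if __name__ == "__main__":
    # Without this guard, the process-spawning code itself would also
    # re-execute when the module is re-imported in the child.
    ctx = mp.get_context("spawn")
    p = ctx.Process(target=child_work)
    p.start()
    p.join()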

tests/compile/test_wrapper.py

Lines changed: 94 additions & 39 deletions
@@ -4,57 +4,112 @@
 
 import torch
 
-from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import CompilationMode
+from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
+from vllm.config import (
+    CompilationConfig,
+    CompilationMode,
+    VllmConfig,
+    set_current_vllm_config,
+)
 
 
 class MyMod(torch.nn.Module):
     def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None):
-        if cache is not None:
-            return x + cache
-        return x * 2
+        if x.size()[0] >= 4:
+            return x * 2
+        else:
+            return x * 100
 
 
-class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
+class MyWrapper(TorchCompileWithNoGuardsWrapper):
     def __init__(self, model):
         self.model = model
-        compiled_callable = torch.compile(self.forward, backend="eager")
-        super().__init__(
-            compiled_callable, compilation_mode=CompilationMode.DYNAMO_TRACE_ONCE
-        )
+        super().__init__()
 
-    def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None):
+    def forward(self, x: torch.Tensor):  # type: ignore[override]
         # this is the function to be compiled
-        return self.model(x, cache)
-
-    def __call__(self, x: torch.Tensor, cache: torch.Tensor | None = None):
-        # let torch.compile compile twice
-        if len(self.compiled_codes) == 2:
-            dispatch_id = 0 if cache is None else 1
-            with self.dispatch_to_code(dispatch_id):
-                return self.forward(x, cache)
-        else:
-            return self.compiled_callable(x, cache)
+        return self.model(x)
 
 
 def test_torch_compile_wrapper():
-    mod = MyMod()
-    wrappers = []
-    for i in range(3):
+    """Test basic functionality of TorchCompileWithNoGuardsWrapper."""
+    # Create a proper vLLM config instead of mocking
+    vllm_config = VllmConfig()
+    vllm_config.compilation_config = CompilationConfig()
+    vllm_config.compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
+    vllm_config.compilation_config.backend = "inductor"
+
+    # Test DYNAMO_TRACE_ONCE
+    with set_current_vllm_config(vllm_config):
+        torch._dynamo.reset()
+        mod = MyMod()
+        wrapper = MyWrapper(mod)
+
+        # First call should trigger compilation
+        x = torch.tensor([1, 2, 3, 4])
+        torch._dynamo.mark_dynamic(x, 0)
+
+        result1 = wrapper(x)
+        expected1 = torch.tensor([2, 4, 6, 8])
+        assert torch.allclose(result1, expected1), (
+            f"Expected {expected1}, got {result1}"
+        )
+
+        # Second call should use compiled code
+        x2 = torch.tensor([1, 2, 3])
+        result2 = wrapper(x2)
+        expected2 = torch.tensor([2, 4, 6])
+        assert torch.allclose(result2, expected2), (
+            f"Expected {expected2}, got {result2}"
+        )
+
+        # without the wrapper result would be different.
+        result3 = mod(x2)
+        expected3 = torch.tensor([100, 200, 300])
+
+        assert torch.allclose(result3, expected3), (
+            f"Expected {result3}, got {expected3}"
+        )
+
+    # with STOCK_TORCH_COMPILE we do not remove guards.
+    vllm_config.compilation_config.mode = CompilationMode.STOCK_TORCH_COMPILE
+    torch._dynamo.reset()
+    with set_current_vllm_config(vllm_config):
         torch._dynamo.reset()
+        mod = MyMod()
         wrapper = MyWrapper(mod)
-        wrappers.append(wrapper)
-        x = torch.tensor([1])
-        wrapper(x, None)  # profile run, compile
-        # create a cache tensor
-        cache = torch.tensor([2])
-        wrapper(x, cache)  # warm up with cache, recompile
-
-        # for new input, dispatch to the compiled code directly
-        new_x = torch.tensor([3])
-        assert wrapper(new_x, None).item() == 6  # dispatch to the first compiled code
-        assert wrapper(new_x, cache).item() == 5  # dispatch to the second compiled code
-
-    for wrapper in wrappers:
-        # make sure they have independent compiled codes
-        assert len(wrapper.compiled_codes) == 2
+
+        # First call should trigger compilation
+        x = torch.tensor([1, 2, 3, 4])
+        torch._dynamo.mark_dynamic(x, 0)
+
+        result1 = wrapper(x)
+        expected1 = torch.tensor([2, 4, 6, 8])
+        assert torch.allclose(result1, expected1), (
+            f"Expected {expected1}, got {result1}"
+        )
+
+        # Second call should triger another compilation
+        x2 = torch.tensor([1, 2, 3])
+        result2 = wrapper(x2)
+        expected2 = torch.tensor([100, 200, 300])
+        assert torch.allclose(result2, expected2), (
+            f"Expected {expected2}, got {result2}"
+        )
+
+    # NO_COMPILATION level not supported.
+    vllm_config.compilation_config.mode = None
+    torch._dynamo.reset()
+    with set_current_vllm_config(vllm_config):
+        torch._dynamo.reset()
+        mod = MyMod()
+
+        try:
+            wrapper = MyWrapper(mod)
+        except Exception:
+            return
+        raise AssertionError("expected an exception to be raised")
+
+
+if __name__ == "__main__":
+    test_torch_compile_wrapper()
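
For context on what this test exercises: the branch on x.size()[0] makes Dynamo install a size guard, so stock torch.compile recompiles for a length-3 input and takes the x * 100 path, while TorchCompileWithNoGuardsWrapper in DYNAMO_TRACE_ONCE mode drops the guards and keeps reusing the first trace (x * 2). A minimal sketch of the guarded (stock) behavior using plain PyTorch only, no vLLM:

import torch


class SizeBranchMod(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        # This branch becomes a Dynamo guard on x.size(0).
        if x.size()[0] >= 4:
            return x * 2
        return x * 100


compiled = torch.compile(SizeBranchMod(), backend="eager")

x = torch.tensor([1, 2, 3, 4])
torch._dynamo.mark_dynamic(x, 0)
print(compiled(x))  # tensor([2, 4, 6, 8])

# The size guard fails for a length-3 input, so Dynamo recompiles and the
# x * 100 branch runs -- the behavior the STOCK_TORCH_COMPILE half of the
# test above asserts.
print(compiled(torch.tensor([1, 2, 3])))  # tensor([100, 200, 300])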

tests/v1/e2e/test_spec_decode.py

Lines changed: 8 additions & 0 deletions
@@ -75,6 +75,14 @@ def model_name():
     return "meta-llama/Llama-3.1-8B-Instruct"
 
 
+@pytest.fixture(autouse=True)
+def reset_torch_dynamo():
+    """Reset torch dynamo cache before each test"""
+    yield
+    # Cleanup after test
+    torch._dynamo.reset()
+
+
 @pytest.mark.parametrize(
     "speculative_config",
     [