Commit 84b93a5

remove the bytecode hook and replace TorchCompileWrapperWithCustomDispatcher with TorchCompileGuardsStripWrapper
Signed-off-by: Laith Sakka <lsakka@meta.com>
1 parent b230286 commit 84b93a5
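
For context, the updated tests below exercise a wrapper that compiles a module's forward once and then reuses that compiled code for every input size instead of dispatching through patched bytecode. The sketch that follows does not use the vLLM wrapper at all; it only demonstrates, with stock torch.compile and the same shape-dependent branch used in tests/compile/test_wrapper.py, the guard-and-recompile behavior that a guard-stripping wrapper is meant to avoid. The class and variable names are illustrative.

# Illustrative only: stock torch.compile installs a size guard on the branch
# below, so crossing the threshold triggers a recompile. A guard-stripping
# wrapper would instead keep reusing the first compiled graph.
import torch


class ShapeBranch(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dynamo guards on this size check and specializes on the taken branch.
        if x.size(0) >= 4:
            return x * 2
        return x * 100


def main() -> None:
    mod = ShapeBranch()
    compiled = torch.compile(mod, backend="eager")

    x = torch.tensor([1, 2, 3, 4])
    torch._dynamo.mark_dynamic(x, 0)  # trace the first dimension as dynamic
    print(compiled(x))  # first call compiles the ">= 4" branch: [2, 4, 6, 8]

    # A size-3 input fails the installed guard, so stock torch.compile
    # recompiles and takes the other branch: [100, 200, 300].
    print(compiled(torch.tensor([1, 2, 3])))


if __name__ == "__main__":
    main()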

File tree

10 files changed, +422 -236 lines changed


tests/compile/piecewise/test_multiple_graphs.py

Lines changed: 10 additions & 1 deletion
@@ -22,6 +22,8 @@
 from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 
+from ...utils import create_new_process_for_each_test
+
 # This import automatically registers `torch.ops.silly.attention`
 from .. import silly_attention  # noqa: F401

@@ -193,7 +195,14 @@ def run_model(
 
 
 @pytest.mark.parametrize("use_inductor_graph_partition", [False, True])
-def test_multi_graph_piecewise_compile(use_inductor_graph_partition: bool):
+@pytest.mark.parametrize("use_bytecode_hook", [True, False])
+@create_new_process_for_each_test("spawn")
+def test_multi_graph_piecewise_compile(
+    use_inductor_graph_partition: bool, use_bytecode_hook: bool, monkeypatch
+):
+    # Set the environment variable for this test
+    monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")
+
     if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
         pytest.skip("inductor graph partition is only available in PyTorch 2.9+")

tests/compile/piecewise/test_simple.py

Lines changed: 3 additions & 0 deletions
@@ -21,6 +21,8 @@
 from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 
+from ...utils import create_new_process_for_each_test
+
 # This import automatically registers `torch.ops.silly.attention`
 from ..silly_attention import get_global_counter, reset_global_counter

@@ -124,6 +126,7 @@ def _run_simple_model(
 
 @pytest.mark.parametrize("use_inductor", [True, False])
 @torch.inference_mode()
+@create_new_process_for_each_test("spawn")
 def test_simple_piecewise_compile(use_inductor):
     _run_simple_model(
         splitting_ops=["silly::attention"],

tests/compile/piecewise/test_toy_llama.py

Lines changed: 8 additions & 1 deletion
@@ -29,6 +29,8 @@
 from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 
+from ...utils import create_new_process_for_each_test
+
 # This import automatically registers `torch.ops.silly.attention`
 from .. import silly_attention  # noqa: F401

@@ -334,6 +336,7 @@ def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor:
         ("inductor", True),  # Inductor, Inductor partition
     ],
 )
+@create_new_process_for_each_test("spawn")
 def test_toy_llama(
     backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path
 ):

@@ -513,4 +516,8 @@ def benchmark():
 
 
 if __name__ == "__main__":
-    benchmark()
+    # Protect against subprocess reimport when using spawn_new_process_for_each_test
+    import os
+
+    if os.environ.get("RUNNING_IN_SUBPROCESS") != "1":
+        benchmark()

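The new __main__ guard above follows a common pattern for scripts that may also be re-executed inside spawned worker processes: the expensive entry point runs only when an environment variable does not mark the process as a child. A minimal sketch of the pattern; the variable name comes from the diff, while the assumption that the spawning test helper sets it to "1" in its children is mine.

# Sketch of the entry-point guard added above. RUNNING_IN_SUBPROCESS is taken
# from the diff; that the spawning helper marks children with it is an
# assumption made for illustration.
import os


def benchmark() -> None:
    # Stand-in for the real benchmark in test_toy_llama.py.
    print("benchmarking in pid", os.getpid())


if __name__ == "__main__":
    # Run the benchmark only in the original process, never in a marked child.
    if os.environ.get("RUNNING_IN_SUBPROCESS") != "1":
        benchmark()
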
tests/compile/test_wrapper.py

Lines changed: 115 additions & 40 deletions
@@ -2,59 +2,134 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 
+import os
+
+import pytest
 import torch
 
-from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import CompilationMode
+from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
+from vllm.config import (
+    CompilationConfig,
+    CompilationMode,
+    VllmConfig,
+    set_current_vllm_config,
+)
 
 
 class MyMod(torch.nn.Module):
     def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None):
-        if cache is not None:
-            return x + cache
-        return x * 2
+        if x.size()[0] >= 4:
+            return x * 2
+        else:
+            return x * 100
 
 
-class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
+class MyWrapper(TorchCompileWithNoGuardsWrapper):
     def __init__(self, model):
         self.model = model
-        compiled_callable = torch.compile(self.forward, backend="eager")
-        super().__init__(
-            compiled_callable, compilation_mode=CompilationMode.DYNAMO_TRACE_ONCE
-        )
+        super().__init__()
 
-    def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None):
+    def forward(self, x: torch.Tensor):  # type: ignore[override]
         # this is the function to be compiled
-        return self.model(x, cache)
-
-    def __call__(self, x: torch.Tensor, cache: torch.Tensor | None = None):
-        # let torch.compile compile twice
-        if len(self.compiled_codes) == 2:
-            dispatch_id = 0 if cache is None else 1
-            with self.dispatch_to_code(dispatch_id):
-                return self.forward(x, cache)
-        else:
-            return self.compiled_callable(x, cache)
+        return self.model(x)
+
 
+@pytest.mark.parametrize("use_bytecode_hook", [True, False])
+def test_torch_compile_wrapper(use_bytecode_hook, monkeypatch):
+    """Test basic functionality of TorchCompileWithNoGuardsWrapper."""
+    # Set the environment variable for this test
+    monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")
 
-def test_torch_compile_wrapper():
-    mod = MyMod()
-    wrappers = []
-    for i in range(3):
+    # Create a proper vLLM config instead of mocking
+    vllm_config = VllmConfig()
+    vllm_config.compilation_config = CompilationConfig()
+    vllm_config.compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE
+    vllm_config.compilation_config.backend = "inductor"
+
+    # Test DYNAMO_TRACE_ONCE
+    with set_current_vllm_config(vllm_config):
         torch._dynamo.reset()
+        mod = MyMod()
+        wrapper = MyWrapper(mod)
+
+        # First call should trigger compilation
+        x = torch.tensor([1, 2, 3, 4])
+        torch._dynamo.mark_dynamic(x, 0)
+
+        result1 = wrapper(x)
+        expected1 = torch.tensor([2, 4, 6, 8])
+        assert torch.allclose(result1, expected1), (
+            f"Expected {expected1}, got {result1}"
+        )
+
+        # Second call should use the compiled code
+        x2 = torch.tensor([1, 2, 3])
+        result2 = wrapper(x2)
+        expected2 = torch.tensor([2, 4, 6])
+        assert torch.allclose(result2, expected2), (
+            f"Expected {expected2}, got {result2}"
+        )
+
+        # without the wrapper the result would be different.
+        result3 = mod(x2)
+        expected3 = torch.tensor([100, 200, 300])
+
+        assert torch.allclose(result3, expected3), (
+            f"Expected {expected3}, got {result3}"
+        )
+
+    # with STOCK_TORCH_COMPILE we do not remove guards.
+    vllm_config.compilation_config.mode = CompilationMode.STOCK_TORCH_COMPILE
+    torch._dynamo.reset()
+    with set_current_vllm_config(vllm_config):
+        mod = MyMod()
         wrapper = MyWrapper(mod)
-        wrappers.append(wrapper)
-        x = torch.tensor([1])
-        wrapper(x, None)  # profile run, compile
-        # create a cache tensor
-        cache = torch.tensor([2])
-        wrapper(x, cache)  # warm up with cache, recompile
-
-        # for new input, dispatch to the compiled code directly
-        new_x = torch.tensor([3])
-        assert wrapper(new_x, None).item() == 6  # dispatch to the first compiled code
-        assert wrapper(new_x, cache).item() == 5  # dispatch to the second compiled code
-
-    for wrapper in wrappers:
-        # make sure they have independent compiled codes
-        assert len(wrapper.compiled_codes) == 2
+
+
+        # First call should trigger compilation
+        x = torch.tensor([1, 2, 3, 4])
+        torch._dynamo.mark_dynamic(x, 0)
+
+        result1 = wrapper(x)
+        expected1 = torch.tensor([2, 4, 6, 8])
+        assert torch.allclose(result1, expected1), (
+            f"Expected {expected1}, got {result1}"
+        )
+
+        # Second call should trigger another compilation
+        x2 = torch.tensor([1, 2, 3])
+        result2 = wrapper(x2)
+        expected2 = torch.tensor([100, 200, 300])
+        assert torch.allclose(result2, expected2), (
+            f"Expected {expected2}, got {result2}"
+        )
+
+    # NO_COMPILATION level not supported.
+    vllm_config.compilation_config.mode = None
+    torch._dynamo.reset()
+    with set_current_vllm_config(vllm_config):
+        torch._dynamo.reset()
+        mod = MyMod()
+
+        try:
+            wrapper = MyWrapper(mod)
+        except Exception:
+            return
+        raise AssertionError("expected an exception to be raised")
+
+
+if __name__ == "__main__":
+    # Run with both parameter values
+
+    class MockMonkeypatch:
+        def setenv(self, name, value):
+            os.environ[name] = value
+
+    mp = MockMonkeypatch()
+
+    print("Testing with VLLM_USE_BYTECODE_HOOK=False")
+    test_torch_compile_wrapper(False, mp)
+
+    print("Testing with VLLM_USE_BYTECODE_HOOK=True")
+    test_torch_compile_wrapper(True, mp)
+
+    print("All tests passed!")

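The rewritten test pins down the intended semantics: under CompilationMode.DYNAMO_TRACE_ONCE the wrapper compiles forward once and reuses that code for every input size, under STOCK_TORCH_COMPILE guards are kept so a new size recompiles, and constructing the wrapper without a supported compilation mode raises. Based only on what the test exercises, adopting the wrapper looks roughly like the sketch below; the imported names come from the diff, while TinyModel, TinyWrapper, and the bare config wiring are illustrative assumptions, not the production integration.

# Usage pattern inferred from tests/compile/test_wrapper.py above; a sketch,
# not the production integration. Only the imported names are taken from the
# diff; everything else is an assumption for illustration.
import torch

from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    VllmConfig,
    set_current_vllm_config,
)


class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1


class TinyWrapper(TorchCompileWithNoGuardsWrapper):
    def __init__(self, model: torch.nn.Module):
        self.model = model
        # Per the test, the compilation mode comes from the active vLLM config.
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The function that gets compiled once; later calls reuse the compiled code.
        return self.model(x)


if __name__ == "__main__":
    config = VllmConfig()
    config.compilation_config = CompilationConfig()
    config.compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE

    with set_current_vllm_config(config):
        wrapper = TinyWrapper(TinyModel())
        print(wrapper(torch.tensor([1, 2, 3])))
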
tests/models/multimodal/generation/test_qwen2_5_vl.py

Lines changed: 10 additions & 0 deletions
@@ -34,6 +34,7 @@ def qwen2_5_vl_chat_template(*query):
 @pytest.mark.parametrize("num_frames", [16])
 @pytest.mark.parametrize("dtype", [target_dtype])
 @pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("use_bytecode_hook", [True, False])
 def test_qwen2_5_vl_evs_functionality(
     vllm_runner,
     video_assets,

@@ -42,10 +43,14 @@ def test_qwen2_5_vl_evs_functionality(
     num_frames: int,
     dtype: str,
     max_tokens: int,
+    use_bytecode_hook: bool,
+    monkeypatch,
 ) -> None:
     """Test EVS (Efficient Video Sampling) functionality with different
     pruning rates.
     """
+    # Set the environment variable for this test
+    monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")
 
     # Sample frames from video assets
     sampled_vids = [

@@ -86,6 +91,7 @@ def test_qwen2_5_vl_evs_functionality(
 @pytest.mark.parametrize("num_frames", [16])
 @pytest.mark.parametrize("dtype", [target_dtype])
 @pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("use_bytecode_hook", [True, False])
 def test_qwen2_5_vl_evs_batched_videos(
     vllm_runner,
     video_assets,

@@ -94,6 +100,8 @@ def test_qwen2_5_vl_evs_batched_videos(
     num_frames: int,
     dtype: str,
     max_tokens: int,
+    use_bytecode_hook: bool,
+    monkeypatch,
 ) -> None:
     """Test EVS functionality with batched videos.

@@ -102,6 +110,8 @@ def test_qwen2_5_vl_evs_batched_videos(
     2. Both pruning configurations work with multiple videos
     3. The model doesn't crash when processing multiple videos simultaneously
     """
+    # Set the environment variable for this test
+    monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0")
     # Sample frames from video assets
     sampled_vids = [
         sample_frames_from_video(asset.np_ndarrays, num_frames)

tests/v1/e2e/test_spec_decode.py

Lines changed: 8 additions & 0 deletions
@@ -75,6 +75,14 @@ def model_name():
     return "meta-llama/Llama-3.1-8B-Instruct"
 
 
+@pytest.fixture(autouse=True)
+def reset_torch_dynamo():
+    """Reset torch dynamo cache before each test"""
+    yield
+    # Cleanup after test
+    torch._dynamo.reset()
+
+
 @pytest.mark.parametrize(
     "speculative_config",
     [

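The autouse fixture added above relies on standard pytest yield-fixture ordering: code before the yield runs before each test in the module and code after it runs as teardown, so the dynamo cache is cleared between tests. A self-contained illustration of that ordering (generic pytest, nothing vLLM-specific):

# Minimal pytest example showing the autouse/yield teardown ordering the fixture
# above relies on; save as test_fixture_order.py and run with pytest.
import pytest

events: list[str] = []


@pytest.fixture(autouse=True)
def record_lifecycle():
    events.append("setup")
    yield  # the test body runs here
    events.append("teardown")  # runs after each test, like torch._dynamo.reset() above


def test_first():
    assert events[-1] == "setup"


def test_second():
    # teardown from test_first ran before this test's setup
    assert events[:3] == ["setup", "teardown", "setup"]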