|
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
3 | 3 |
|
4 | 4 |
|
| 5 | +import os |
| 6 | + |
| 7 | +import pytest |
5 | 8 | import torch |
6 | 9 |
|
7 | | -from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher |
8 | | -from vllm.config import CompilationMode |
| 10 | +from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper |
| 11 | +from vllm.config import ( |
| 12 | + CompilationConfig, |
| 13 | + CompilationMode, |
| 14 | + VllmConfig, |
| 15 | + set_current_vllm_config, |
| 16 | +) |
9 | 17 |
|
10 | 18 |
|
11 | 19 | class MyMod(torch.nn.Module): |
12 | 20 | def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None): |
13 | | - if cache is not None: |
14 | | - return x + cache |
15 | | - return x * 2 |
| 21 | + if x.size()[0] >= 4: |
| 22 | + return x * 2 |
| 23 | + else: |
| 24 | + return x * 100 |
16 | 25 |
|
17 | 26 |
|
18 | | -class MyWrapper(TorchCompileWrapperWithCustomDispatcher): |
| 27 | +class MyWrapper(TorchCompileWithNoGuardsWrapper): |
19 | 28 | def __init__(self, model): |
20 | 29 | self.model = model |
21 | | - compiled_callable = torch.compile(self.forward, backend="eager") |
22 | | - super().__init__( |
23 | | - compiled_callable, compilation_mode=CompilationMode.DYNAMO_TRACE_ONCE |
24 | | - ) |
| 30 | + super().__init__() |
25 | 31 |
|
26 | | - def forward(self, x: torch.Tensor, cache: torch.Tensor | None = None): |
| 32 | + def forward(self, x: torch.Tensor): # type: ignore[override] |
27 | 33 | # this is the function to be compiled |
28 | | - return self.model(x, cache) |
29 | | - |
30 | | - def __call__(self, x: torch.Tensor, cache: torch.Tensor | None = None): |
31 | | - # let torch.compile compile twice |
32 | | - if len(self.compiled_codes) == 2: |
33 | | - dispatch_id = 0 if cache is None else 1 |
34 | | - with self.dispatch_to_code(dispatch_id): |
35 | | - return self.forward(x, cache) |
36 | | - else: |
37 | | - return self.compiled_callable(x, cache) |
| 34 | + return self.model(x) |
| 35 | + |
38 | 36 |
|
| 37 | +@pytest.mark.parametrize("use_bytecode_hook", [True, False]) |
| 38 | +def test_torch_compile_wrapper(use_bytecode_hook, monkeypatch): |
| 39 | + """Test basic functionality of TorchCompileWithNoGuardsWrapper.""" |
| 40 | + # Set the environment variable for this test |
| 41 | + monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0") |
39 | 42 |
|
40 | | -def test_torch_compile_wrapper(): |
41 | | - mod = MyMod() |
42 | | - wrappers = [] |
43 | | - for i in range(3): |
| 43 | + # Create a proper vLLM config instead of mocking |
| 44 | + vllm_config = VllmConfig() |
| 45 | + vllm_config.compilation_config = CompilationConfig() |
| 46 | + vllm_config.compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE |
| 47 | + vllm_config.compilation_config.backend = "inductor" |
| 48 | + |
| 49 | + # Test DYNAMO_TRACE_ONCE |
| 50 | + with set_current_vllm_config(vllm_config): |
44 | 51 | torch._dynamo.reset() |
| 52 | + mod = MyMod() |
| 53 | + wrapper = MyWrapper(mod) |
| 54 | + |
| 55 | + # First call should trigger compilation |
| 56 | + x = torch.tensor([1, 2, 3, 4]) |
| 57 | + torch._dynamo.mark_dynamic(x, 0) |
| 58 | + |
| 59 | + result1 = wrapper(x) |
| 60 | + expected1 = torch.tensor([2, 4, 6, 8]) |
| 61 | + assert torch.allclose(result1, expected1), ( |
| 62 | + f"Expected {expected1}, got {result1}" |
| 63 | + ) |
| 64 | + |
| 65 | + # Second call should use compiled code |
| 66 | + x2 = torch.tensor([1, 2, 3]) |
| 67 | + result2 = wrapper(x2) |
| 68 | + expected2 = torch.tensor([2, 4, 6]) |
| 69 | + assert torch.allclose(result2, expected2), ( |
| 70 | + f"Expected {expected2}, got {result2}" |
| 71 | + ) |
| 72 | + |
| 73 | + # without the wrapper result would be different. |
| 74 | + result3 = mod(x2) |
| 75 | + expected3 = torch.tensor([100, 200, 300]) |
| 76 | + |
| 77 | + assert torch.allclose(result3, expected3), ( |
| 78 | + f"Expected {result3}, got {expected3}" |
| 79 | + ) |
| 80 | + |
| 81 | + # with STOCK_TORCH_COMPILE we do not remove guards. |
| 82 | + vllm_config.compilation_config.mode = CompilationMode.STOCK_TORCH_COMPILE |
| 83 | + torch._dynamo.reset() |
| 84 | + with set_current_vllm_config(vllm_config): |
| 85 | + mod = MyMod() |
45 | 86 | wrapper = MyWrapper(mod) |
46 | | - wrappers.append(wrapper) |
47 | | - x = torch.tensor([1]) |
48 | | - wrapper(x, None) # profile run, compile |
49 | | - # create a cache tensor |
50 | | - cache = torch.tensor([2]) |
51 | | - wrapper(x, cache) # warm up with cache, recompile |
52 | | - |
53 | | - # for new input, dispatch to the compiled code directly |
54 | | - new_x = torch.tensor([3]) |
55 | | - assert wrapper(new_x, None).item() == 6 # dispatch to the first compiled code |
56 | | - assert wrapper(new_x, cache).item() == 5 # dispatch to the second compiled code |
57 | | - |
58 | | - for wrapper in wrappers: |
59 | | - # make sure they have independent compiled codes |
60 | | - assert len(wrapper.compiled_codes) == 2 |
| 87 | + |
| 88 | + # First call should trigger compilation |
| 89 | + x = torch.tensor([1, 2, 3, 4]) |
| 90 | + torch._dynamo.mark_dynamic(x, 0) |
| 91 | + |
| 92 | + result1 = wrapper(x) |
| 93 | + expected1 = torch.tensor([2, 4, 6, 8]) |
| 94 | + assert torch.allclose(result1, expected1), ( |
| 95 | + f"Expected {expected1}, got {result1}" |
| 96 | + ) |
| 97 | + |
| 98 | + # Second call should triger another compilation |
| 99 | + x2 = torch.tensor([1, 2, 3]) |
| 100 | + result2 = wrapper(x2) |
| 101 | + expected2 = torch.tensor([100, 200, 300]) |
| 102 | + assert torch.allclose(result2, expected2), ( |
| 103 | + f"Expected {expected2}, got {result2}" |
| 104 | + ) |
| 105 | + |
| 106 | + # NO_COMPILATION level not supported. |
| 107 | + vllm_config.compilation_config.mode = None |
| 108 | + torch._dynamo.reset() |
| 109 | + with set_current_vllm_config(vllm_config): |
| 110 | + torch._dynamo.reset() |
| 111 | + mod = MyMod() |
| 112 | + |
| 113 | + try: |
| 114 | + wrapper = MyWrapper(mod) |
| 115 | + except Exception: |
| 116 | + return |
| 117 | + raise AssertionError("expected an exception to be raised") |
| 118 | + |
| 119 | + |
| 120 | +if __name__ == "__main__": |
| 121 | + # Run with both parameter values |
| 122 | + |
| 123 | + class MockMonkeypatch: |
| 124 | + def setenv(self, name, value): |
| 125 | + os.environ[name] = value |
| 126 | + |
| 127 | + mp = MockMonkeypatch() |
| 128 | + |
| 129 | + print("Testing with VLLM_USE_BYTECODE_HOOK=False") |
| 130 | + test_torch_compile_wrapper(False, mp) |
| 131 | + |
| 132 | + print("Testing with VLLM_USE_BYTECODE_HOOK=True") |
| 133 | + test_torch_compile_wrapper(True, mp) |
| 134 | + |
| 135 | + print("All tests passed!") |
0 commit comments