Commit 5bcc153

[Compile] Fix noop_elimination pass and add tests for noop_elimination (vllm-project#24880)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
1 parent 45bfa49 commit 5bcc153

4 files changed: +130 −23 lines changed


.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions
@@ -394,6 +394,7 @@ steps:
     - pytest -v -s compile/test_async_tp.py
     - pytest -v -s compile/test_fusion_all_reduce.py
     - pytest -v -s compile/test_decorator.py
+    - pytest -v -s compile/test_noop_elimination.py
 
 - label: PyTorch Fullgraph Smoke Test # 15min
   timeout_in_minutes: 30
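The new suite is wired into the existing compile test group. An equivalent local invocation from the repository root would be the following (an assumption about local setup, not part of the commit; a CUDA-capable machine is needed since the tests set the default device to "cuda"):

    pytest -v -s tests/compile/test_noop_elimination.py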

tests/compile/backend.py

Lines changed: 5 additions & 1 deletion
@@ -64,4 +64,8 @@ def check_after_ops(self, ops: Sequence[OpOverload]):
             num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
             num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
             assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
-            assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
+            assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
+
+    def op_count(self, op: OpOverload, before=False) -> int:
+        graph = self.graph_pre_pass if before else self.graph_post_pass
+        return len(list(find_op_nodes(op, graph)))
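The new op_count helper counts occurrences of a given op in the post-pass graph, or in the pre-pass graph when before=True. A minimal usage sketch (names mirror the new test file below; the pre-pass assertion is an assumption about what tracing produces, not something the commit asserts):

    backend = TestBackend(noop_pass)
    compiled = torch.compile(model, backend=backend)
    compiled(x)
    # After the pass: only the fused reshape survives.
    assert backend.op_count(torch.ops.aten.reshape.default) == 1
    # Before the pass, the traced graph would typically contain more.
    assert backend.op_count(torch.ops.aten.reshape.default, before=True) >= 1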
tests/compile/test_noop_elimination.py

Lines changed: 106 additions & 0 deletions

@@ -0,0 +1,106 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+
+import vllm
+from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
+                         VllmConfig)
+
+from .backend import TestBackend
+
+
+@pytest.mark.parametrize("dtype",
+                         [torch.float16, torch.bfloat16, torch.float32])
+@pytest.mark.parametrize("num_tokens", [256, 1024])
+@pytest.mark.parametrize("hidden_size", [64, 4096])
+def test_noop_elimination(dtype, num_tokens, hidden_size):
+    torch.set_default_device("cuda")
+    torch.set_default_dtype(dtype)
+    torch.manual_seed(1)
+
+    class Model(torch.nn.Module):
+
+        def forward(self, x):
+            # Chain of reshapes
+            y = x.reshape(-1, 128, 32)
+            z = y.reshape(-1, 4096)
+            # No-op reshape
+            a = z.reshape(-1, 4096)
+            # Final reshape that should remain
+            b = a.reshape(-1, 128, 32)
+            # No-op slice
+            c = b[0:b.shape[0]]
+            # The pass should replace the result of this op with `c`.
+            d = torch.slice_scatter(
+                torch.ones_like(c),  # Dummy tensor to be scattered into
+                c,  # Source tensor
+                0,  # dim
+                0,  # start
+                c.shape[0],  # end
+            )
+            return d
+
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        level=CompilationLevel.PIECEWISE,
+        pass_config=PassConfig(enable_noop=True),
+    ))
+    with vllm.config.set_current_vllm_config(vllm_config):
+        noop_pass = NoOpEliminationPass(vllm_config)
+
+    backend = TestBackend(noop_pass)
+
+    model = Model()
+    # First dimension dynamic
+    x = torch.rand(num_tokens, hidden_size)
+    torch._dynamo.mark_dynamic(x, 0)
+
+    result = model(x)
+
+    model2 = torch.compile(model, backend=backend)
+    result2 = model2(x)
+
+    ATOL, RTOL = (2e-3, 2e-3)
+    torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
+
+    # The no-op reshape and slice should be eliminated.
+    # The chain of reshapes should be fused into a single reshape.
+    assert backend.op_count(torch.ops.aten.reshape.default) == 1
+    assert backend.op_count(torch.ops.aten.slice.Tensor) == 0
+    assert backend.op_count(torch.ops.aten.slice_scatter.default) == 0
+
+
+def test_non_noop_slice_preserved():
+    """Ensure that a slice with end=-1 (dropping last row) is NOT eliminated.
+
+    Regression test for a bug where end=-1 was treated like an inferred
+    dimension (reshape semantics), leading to incorrect elimination.
+    """
+    torch.set_default_device("cuda")
+    x = torch.randn(16, 16)
+
+    class SliceModel(torch.nn.Module):
+
+        def forward(self, x):
+            base = x.clone()
+            src = torch.ones(15, 16)
+            y = torch.slice_scatter(base, src, dim=0, start=0, end=-1)
+            return x[0:-1, :], y
+
+    vllm_config = VllmConfig(compilation_config=CompilationConfig(
+        level=CompilationLevel.PIECEWISE,
+        pass_config=PassConfig(enable_noop=True),
+    ))
+    with vllm.config.set_current_vllm_config(vllm_config):
+        noop_pass = NoOpEliminationPass(vllm_config)
+    backend = TestBackend(noop_pass)
+    model = SliceModel()
+    ref = model(x)
+    compiled = torch.compile(model, backend=backend)
+    out = compiled(x)
+    torch.testing.assert_close(ref, out)
+    # The slice should remain (not a no-op).
+    assert backend.op_count(torch.ops.aten.slice.Tensor) == 1
+    assert backend.op_count(torch.ops.aten.slice_scatter.default) == 1

vllm/compilation/noop_elimination.py

Lines changed: 18 additions & 22 deletions
@@ -62,9 +62,6 @@ class NoOpEliminationPass(VllmInductorPass):
     scaled_mm: "f16[s0, 4096]" = ...
     at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
     out: "f16[s0, 4096]" = at[1]
-
-    TODO(luka): This is currently tested in test_fusion,
-    but separate tests could be good.
     """
 
     def __call__(self, graph: torch.fx.Graph):

@@ -96,17 +93,19 @@ def __call__(self, graph: torch.fx.Graph):
                 # Invalid reshape args, skip
                 continue
 
-            if self.all_dims_equivalent(shape, input_shape):
+            if self.reshape_all_dims_equivalent(shape, input_shape):
                 node.replace_all_uses_with(input)
                 graph.erase_node(node)
                 count += 1
 
         elif is_func(node, torch.ops.aten.slice.Tensor):
+            # Python slicing semantics are different from reshape:
+            # don't treat -1 as an inferred dimension.
             input, dim_index, start, end = node.args[:4]
             input_shape = input.meta["val"].shape
-            i_dim = input_shape[dim_index]
+            output_shape = node.meta["val"].shape
 
-            if start == 0 and self.dims_equivalent(end, i_dim):
+            if output_shape == input_shape:
                 node.replace_all_uses_with(input)
                 graph.erase_node(node)
                 count += 1

@@ -116,14 +115,7 @@ def __call__(self, graph: torch.fx.Graph):
             base_shape = base.meta["val"].shape
             view_shape = view.meta["val"].shape
 
-            view_dim = view_shape[dim_index]
-
-            # Check that view fully covers base and the full view is used
-            # (if the view fully covered the base after slicing but was not
-            # fully used, we could replace slice_scatter with a simple slice
-            # but that's a niche case).
-            if (base_shape == view_shape and start == 0
-                    and self.dims_equivalent(end, view_dim)):
+            if base_shape == view_shape:
                 node.replace_all_uses_with(view)
                 graph.erase_node(node)
                 count += 1

@@ -132,13 +124,9 @@ def __call__(self, graph: torch.fx.Graph):
         self.dump_graph(graph, "after_noop_elimination")
         self.end_and_log()
 
-    def all_dims_equivalent(self, dims: Iterable[Union[int, torch.fx.Node]],
-                            i_dims: Iterable[Union[int, SymInt]]):
-        return all(
-            self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims))
-
-    def dims_equivalent(self, dim: Union[int, torch.fx.Node],
-                        i_dim: Union[int, SymInt]) -> bool:
+    # ---------------------- Reshape helpers ----------------------
+    def reshape_dims_equivalent(self, dim: Union[int, torch.fx.Node],
+                                i_dim: Union[int, SymInt]) -> bool:
         """
         This function checks if two dimensions are equivalent.
         :param dim: The dimension arg to reshape/slice

@@ -156,10 +144,18 @@ def dims_equivalent(self, dim: Union[int, torch.fx.Node],
         In case 3, the reshape dimension is a torch.fx.Node,
         and its value is a SymInt. That value is equal to the
         input dimension.
-
         """
         # Case 1 and 2
         if dim == i_dim or dim == -1:
             return True
         # Case 3
         return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim
+
+    def reshape_all_dims_equivalent(
+            self,
+            dims: Iterable[Union[int, torch.fx.Node]],
+            i_dims: Iterable[Union[int, SymInt]],
+    ) -> bool:
+        return all(
+            self.reshape_dims_equivalent(s, i_s)
+            for s, i_s in zip(dims, i_dims))
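The crux of the fix: reshape and slice assign different meanings to -1. In reshape, -1 marks an inferred dimension, so a shape containing -1 can still describe a no-op; in slice and slice_scatter, end=-1 means "stop before the last element", which is never a no-op along that dim. A standalone sketch of the distinction in plain PyTorch (illustrative only, independent of the pass itself):

    import torch

    x = torch.randn(16, 16)
    # reshape: -1 is inferred -> identical shape, a true no-op
    assert x.reshape(-1, 16).shape == (16, 16)
    # slice: end=-1 drops the last row -> NOT a no-op
    assert x[0:-1].shape == (15, 16)

Comparing the output's node.meta["val"].shape directly against the input's, as the new slice check does, sidesteps the ambiguity entirely instead of reusing the reshape helpers.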
