@@ -62,9 +62,6 @@ class NoOpEliminationPass(VllmInductorPass):
     scaled_mm: "f16[s0, 4096]" = ...
     at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
     out: "f16[s0, 4096]" = at[1]
-
-    TODO(luka): This is currently tested in test_fusion,
-    but separate tests could be good.
     """

     def __call__(self, graph: torch.fx.Graph):
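For context on what the pass removes: a reshape to the tensor's existing shape moves no data, so the node can be rewired to its input and erased. A minimal standalone sketch of that property (plain torch, separate from the diff):

    import torch

    x = torch.randn(16, 4096)
    y = x.reshape(16, 4096)  # reshape to the same shape: no copy is made
    assert y.shape == x.shape
    assert y.data_ptr() == x.data_ptr()  # same underlying storage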
@@ -96,17 +93,19 @@ def __call__(self, graph: torch.fx.Graph):
                     # Invalid reshape args, skip
                     continue

-                if self.all_dims_equivalent(shape, input_shape):
+                if self.reshape_all_dims_equivalent(shape, input_shape):
                     node.replace_all_uses_with(input)
                     graph.erase_node(node)
                     count += 1

             elif is_func(node, torch.ops.aten.slice.Tensor):
+                # Python slicing semantics are different from reshape:
+                # don't treat -1 as an inferred dimension.
                 input, dim_index, start, end = node.args[:4]
                 input_shape = input.meta["val"].shape
-                i_dim = input_shape[dim_index]
+                output_shape = node.meta["val"].shape

-                if start == 0 and self.dims_equivalent(end, i_dim):
+                if output_shape == input_shape:
                     node.replace_all_uses_with(input)
                     graph.erase_node(node)
                     count += 1
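The comment added above is the key subtlety: in Python slicing, -1 means "up to the last element", whereas reshape treats -1 as an inferred dimension, so the reshape equivalence helper cannot be reused here. Comparing the traced output shape against the input shape sidesteps the difference entirely. A small sketch of the distinction (plain torch, illustrative only):

    import torch

    x = torch.randn(4, 8)
    assert x[:, 0:8].shape == x.shape         # slice covering the whole dim: a no-op
    assert x[:, 0:-1].shape == (4, 7)         # -1 in a slice drops the last element
    assert x.reshape(-1, 8).shape == x.shape  # -1 in reshape infers the dimension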
@@ -116,14 +115,7 @@ def __call__(self, graph: torch.fx.Graph):
                 base_shape = base.meta["val"].shape
                 view_shape = view.meta["val"].shape

-                view_dim = view_shape[dim_index]
-
-                # Check that view fully covers base and the full view is used
-                # (if the view fully covered the base after slicing but was not
-                # fully used, we could replace slice_scatter with a simple slice
-                # but that's a niche case).
-                if (base_shape == view_shape and start == 0
-                        and self.dims_equivalent(end, view_dim)):
+                if base_shape == view_shape:
                     node.replace_all_uses_with(view)
                     graph.erase_node(node)
                     count += 1
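The simplified check works because equal shapes imply full coverage: a slice with start > 0, an early end, or step > 1 would shrink the view relative to the base, so base_shape == view_shape can only hold when slice_scatter overwrites the entire base. In that case the result is just the view, as this standalone sketch shows:

    import torch

    base = torch.zeros(4, 8)
    view = torch.ones(4, 8)
    # view covers all of base, so the scatter is a full overwrite
    out = torch.slice_scatter(base, view, dim=0, start=0, end=4)
    assert torch.equal(out, view)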
@@ -132,13 +124,9 @@ def __call__(self, graph: torch.fx.Graph):
         self.dump_graph(graph, "after_noop_elimination")
         self.end_and_log()

-    def all_dims_equivalent(self, dims: Iterable[Union[int, torch.fx.Node]],
-                            i_dims: Iterable[Union[int, SymInt]]):
-        return all(
-            self.dims_equivalent(s, i_s) for s, i_s in zip(dims, i_dims))
-
-    def dims_equivalent(self, dim: Union[int, torch.fx.Node],
-                        i_dim: Union[int, SymInt]) -> bool:
+    # ---------------------- Reshape helpers ----------------------
+    def reshape_dims_equivalent(self, dim: Union[int, torch.fx.Node],
+                                i_dim: Union[int, SymInt]) -> bool:
         """
         This function checks if two dimensions are equivalent.
         :param dim: The dimension arg to reshape/slice
@@ -156,10 +144,18 @@ def dims_equivalent(self, dim: Union[int, torch.fx.Node],
         In case 3, the reshape dimension is a torch.fx.Node,
         and its value is a SymInt. That value is equal to the
         input dimension.
-
         """
         # Case 1 and 2
         if dim == i_dim or dim == -1:
             return True
         # Case 3
         return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim
+
+    def reshape_all_dims_equivalent(
+            self,
+            dims: Iterable[Union[int, torch.fx.Node]],
+            i_dims: Iterable[Union[int, SymInt]],
+    ) -> bool:
+        return all(
+            self.reshape_dims_equivalent(s, i_s)
+            for s, i_s in zip(dims, i_dims))
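To make the renamed helper's three docstring cases concrete, here is a standalone reduction of reshape_dims_equivalent (the fx.Node case is only described in comments, since it requires a traced graph; all names and values here are illustrative):

    import torch

    def dims_equivalent(dim: int, i_dim: int) -> bool:
        # Cases 1 and 2 from the docstring: literal match, or -1 (inferred)
        return dim == i_dim or dim == -1

    shape = torch.randn(4, 4096).shape
    assert dims_equivalent(4096, shape[1])  # case 1: dimensions are equal
    assert dims_equivalent(-1, shape[0])    # case 2: -1 matches any input dim
    # Case 3 (not shown): dim is a torch.fx.Node whose meta["val"] SymInt
    # equals the input dimension, e.g. a symbolic batch size like s0.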