     as_tensor_variable,
     cast,
     constant,
+    expand_dims,
     get_underlying_scalar_constant_value,
     moveaxis,
     ones_like,
     register_infer_shape,
     switch,
     zeros_like,
 )
-from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.extra_ops import broadcast_arrays
     Sum,
     _conj,
     _dot,
-    _inner_prod,
-    _matrix_matrix_matmul,
-    _matrix_vec_prod,
-    _vec_matrix_prod,
+    _matmul,
     add,
     digamma,
     dot,
@@ -182,7 +179,7 @@ def local_lift_transpose_through_dot(fgraph, node):
     if not (
         is_matrix_transpose(node.outputs[0])
         and node.inputs[0].owner
-        and ((dot_op := node.inputs[0].owner.op) in (_dot, _matrix_matrix_matmul))
+        and ((dot_op := node.inputs[0].owner.op) in (_dot, _matmul))
     ):
         return False
 
@@ -197,60 +194,157 @@ def local_lift_transpose_through_dot(fgraph, node):
     return ret
 
 
-@register_stabilize
-@register_specialize
-@node_rewriter(tracks=[Blockwise])
-def local_batched_matmul_to_core_matmul(fgraph, node):
-    """Rewrite matmul where only one of the inputs has batch dimensions to a reshaped core matmul.
+def _batched_matmul_to_core_matmul(fgraph, node, allow_reshape: bool):
+    """Move batch dimensions of matmul operands into the core matmul.
 
-    Example, if x has batch dimensions, but y not:
+    For example, if x has batch dimensions that don't overlap with batch dimensions of y:
     x @ y -> (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])
 
-    It also works when y has batch dimensions, but x not.
-    """
+    It also works for batch dimensions of y that don't overlap with batch dimensions of x.
 
-    # Check whether we have a matmul operation in this node
-    if not (
-        isinstance(node.op.core_op, Dot)
-        and len(node.op.inputs_sig[0]) == 2
-        and len(node.op.inputs_sig[1]) == 2
-    ):
-        return None
+    The rewrite only needs a reshape when merging dimensions; if that would be required and `allow_reshape=False`, the rewrite does not apply.
+    """
 
     x, y = node.inputs
     batch_ndim = node.op.batch_ndim(node)
 
-    # Check if x has batch dimensions, but y not (or only broadcastable dimensions)
-    if any(not b_dim for b_dim in x.type.broadcastable[:-2]) and all(
-        y.type.broadcastable[:-2]
-    ):
-        x_stacked = x.reshape((-1, x.shape[-1]))
-        out_stacked = x_stacked @ y.squeeze(tuple(range(batch_ndim)))
-        out = out_stacked.reshape((*x.shape[:-1], y.shape[-1]))
-        return [out]
-
-    # Otherwise, check if y has batch dimension, but x not
-    elif any(not b_dim for b_dim in y.type.broadcastable[:-2]) and all(
-        x.type.broadcastable[:-2]
-    ):
-        # For the y batch case we need to first move the batch axes and then reshape
-        # y.shape == (*b, k, n)
-        y_tr = moveaxis(y, -2, 0)  # (k, *b, n)
-        y_stacked = y_tr.reshape((y.shape[-2], -1))  # (k, *b * n)
-        out_stacked = x.squeeze(tuple(range(batch_ndim))) @ y_stacked  # (m, *b * n)
-        out_stacked_tr = out_stacked.reshape(
-            (x.shape[-2], *y.shape[:-2], y.shape[-1])
-        )  # (m, *b, n)
-        out = moveaxis(out_stacked_tr, 0, -2)  # (*b, m, n)
-        return [out]
-
-    # Both x and y have batch dimensions, nothing to do here
-    return None
+    x_axis_to_merge = [
+        i
+        for i, (bcast_x, bcast_y) in enumerate(
+            zip(x.type.broadcastable[:-2], y.type.broadcastable[:-2])
+        )
+        if bcast_y and not bcast_x
+    ]
+
+    y_axis_to_merge = [
+        i
+        for i, (bcast_x, bcast_y) in enumerate(
+            zip(x.type.broadcastable[:-2], y.type.broadcastable[:-2])
+        )
+        if bcast_x and not bcast_y
+    ]
+
+    if not (x_axis_to_merge or y_axis_to_merge):
+        return None
+
+    x_shape = tuple(x.shape)
+    y_shape = tuple(y.shape)
+    x_is_row = x.type.broadcastable[-2]
+    y_is_col = y.type.broadcastable[-1]
+    n_x_axis_to_merge = len(x_axis_to_merge)
+    n_y_axis_to_merge = len(y_axis_to_merge)
+    n_axis_to_merge = n_x_axis_to_merge + n_y_axis_to_merge
+
+    x_stacked, y_stacked = x, y
+    dims_were_merged = False
+
+    if n_x_axis_to_merge:
+        # ravel batch dimensions of x on the core (m) axis
+        x_axis_destination = tuple(range(-n_x_axis_to_merge - 2, -2))
+        x_stacked = moveaxis(x, x_axis_to_merge, x_axis_destination)
+        if x_is_row:
+            # x was a row matrix, squeeze it to clean up the graph
+            x_stacked = x_stacked.squeeze(-2)
+        if n_x_axis_to_merge > 1 or not x_is_row:
+            if not allow_reshape:
+                # TODO: We could allow the y rewrite to go on
+                # Or just move one axis (the largest) if x is row
+                return None
+
+            # Ravel moved batch dims together with (m) if needed
+            x_stacked_shape = tuple(x_stacked.shape)
+            x_stacked = x_stacked.reshape(
+                (*x_stacked_shape[: batch_ndim - n_x_axis_to_merge], -1, x_shape[-1])
+            )
+            dims_were_merged = True
+
+    if n_y_axis_to_merge:
+        # ravel batch dimensions of y on the core (n) axis
+        y_axis_destination = tuple(range(-n_y_axis_to_merge - 1, -1))
+        y_stacked = moveaxis(y, y_axis_to_merge, y_axis_destination)
+        if y_is_col:
+            # y was a column matrix, squeeze it to clean up the graph
+            y_stacked = y_stacked.squeeze(-1)
+        if n_y_axis_to_merge > 1 or not y_is_col:
+            if not allow_reshape:
+                # TODO: We could allow the x rewrite to go on
+                # Or just move one axis (the largest) if y is col
+                return None
+            # Ravel moved batch dims together with (n) if needed
+            y_stacked_shape = tuple(y_stacked.shape)
+            y_stacked = y_stacked.reshape(
+                (*y_stacked_shape[: batch_ndim - n_y_axis_to_merge], y_shape[-2], -1)
+            )
+            dims_were_merged = True
+
+    # Squeeze x_dims corresponding to merged dimensions of y
+    x_axis_to_squeeze = np.array(y_axis_to_merge)
+    for i in reversed(x_axis_to_merge):
+        # The corresponding dimensions of y may have shifted when we merged dimensions of x
+        x_axis_to_squeeze[x_axis_to_squeeze > i] -= 1
+    x_stacked = x_stacked.squeeze(tuple(x_axis_to_squeeze))
+
+    # Same for y
+    y_axis_to_squeeze = np.array(x_axis_to_merge)
+    for i in reversed(y_axis_to_merge):
+        y_axis_to_squeeze[y_axis_to_squeeze > i] -= 1
+    y_stacked = y_stacked.squeeze(tuple(y_axis_to_squeeze))
+
+    out_stacked = x_stacked @ y_stacked
+
+    # Split back any merged dimensions
+    if dims_were_merged:
+        x_merged_shapes = [x_shape[i] for i in x_axis_to_merge]
+        if not x_is_row:
+            # Otherwise we handle that later with expand_dims, which is cleaner
+            x_merged_shapes.append(x_shape[-2])
+        y_merged_shapes = [y_shape[i] for i in y_axis_to_merge]
+        if not y_is_col:
+            # Otherwise we handle that later with expand_dims, which is cleaner
+            y_merged_shapes.append(y_shape[-1])
+        out_stacked_shape = tuple(out_stacked.shape)
+        out_unstacked = out_stacked.reshape(
+            (
+                *out_stacked_shape[: batch_ndim - n_axis_to_merge],
+                *x_merged_shapes,
+                *y_merged_shapes,
+            )
+        )
+    else:
+        out_unstacked = out_stacked
+
+    # Add back dummy row, col axis
+    # We do this separately to avoid the reshape as much as we can
+    if y_is_col and (n_y_axis_to_merge or dims_were_merged):
+        out_unstacked = expand_dims(out_unstacked, -1)
+    if x_is_row and (n_x_axis_to_merge or dims_were_merged):
+        out_unstacked = expand_dims(out_unstacked, -n_y_axis_to_merge - 2)
+
+    # Move batch axes back to their original locations
+    source = range(-n_axis_to_merge - 2, 0)
+    destination = (*x_axis_to_merge, -2, *y_axis_to_merge, -1)
+    out = moveaxis(out_unstacked, source, destination)
+    return [out]
+
+
+@register_canonicalize
+@node_rewriter(tracks=[_matmul])
+def local_batched_matmul_to_core_matmul(fgraph, node):
+    # Move batch dimensions of matmul into core row / column matrices (no reshape needed)
+    return _batched_matmul_to_core_matmul(fgraph, node, allow_reshape=False)
+
+
+@register_specialize
+@node_rewriter(tracks=[_matmul])
+def local_batched_matmul_to_core_matmul_with_reshape(fgraph, node):
+    # Allow stacking batch dimensions of matmul with core dimensions, with a reshape operation
+    # We only apply this in specialize, because graphs with reshape are hard to work with
+    return _batched_matmul_to_core_matmul(fgraph, node, allow_reshape=True)
 
 
 @register_canonicalize
 @register_specialize
-@node_rewriter([_inner_prod, _matrix_vec_prod, _vec_matrix_prod, _matrix_matrix_matmul])
+@node_rewriter([_matmul])
 def local_blockwise_dot_to_mul(fgraph, node):
     """Rewrite blockwise dots that correspond to multiplication without summation.
 
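A minimal NumPy sketch of the stacking identity the docstring above states, for the case where only `x` carries batch dimensions; the shapes are arbitrary illustrative values, not taken from the diff:

```python
import numpy as np

rng = np.random.default_rng(0)

x = rng.normal(size=(3, 5, 4, 2))  # batch dims (3, 5), core matrix (4, 2)
y = rng.normal(size=(2, 6))        # core matrix only, no batch dims

# Broadcasting (batched) matmul: shape (3, 5, 4, 6)
batched = x @ y

# Equivalent core matmul after folding the batch dims of x into its row axis
stacked = (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])

np.testing.assert_allclose(batched, stacked)
```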