11from collections .abc import Callable , Sequence
2- from typing import Any , cast
2+ from typing import Any , Literal , cast , overload
33
44import numpy as np
55from numpy import broadcast_shapes , empty
3232from pytensor .tensor .variable import TensorVariable
3333
3434
35+ def _squeeze_left (x , stop_at_dim : int | None = None ):
36+ """Squeeze any leading dims of `x` until a real dim or `stop_at_dim` (if not None) is reached."""
37+ x_dims = x .type .broadcastable
38+ squeeze_ndim = len (x_dims ) if all (x_dims ) else x_dims .index (False )
39+ if stop_at_dim is not None :
40+ squeeze_ndim = min (squeeze_ndim , stop_at_dim )
41+ if squeeze_ndim == 0 :
42+ return x
43+ return x .squeeze (axis = tuple (range (squeeze_ndim )))
44+
45+
3546def _vectorize_node_perform (
3647 core_node : Apply ,
3748 batch_bcast_patterns : Sequence [tuple [bool , ...]],
@@ -143,8 +154,6 @@ def _check_runtime_broadcast_core(numerical_inputs, batch_bcast_patterns, batch_
143154class Blockwise (COp ):
144155 """Generalizes a core `Op` to work with batched dimensions.
145156
146- TODO: Dispatch JAX (should be easy with the vectorize macro)
147- TODO: Dispatch Numba
148157 TODO: C implementation?
149158 TODO: Fuse Blockwise?
150159 """
@@ -202,21 +211,52 @@ def __init__(
202211
203212 super ().__init__ (** kwargs )
204213
205- def _create_dummy_core_node (self , inputs : Sequence [TensorVariable ]) -> Apply :
206- core_input_types = []
214+ @overload
215+ def _create_dummy_core_node (
216+ self ,
217+ inputs : Sequence [TensorVariable ],
218+ * ,
219+ propagate_unbatched_core_inputs : bool = False ,
220+ return_dummy_inputs : Literal [False ] = ...,
221+ ) -> Apply : ...
222+
223+ @overload
224+ def _create_dummy_core_node (
225+ self ,
226+ inputs : Sequence [TensorVariable ],
227+ * ,
228+ propagate_unbatched_core_inputs : bool = False ,
229+ return_dummy_inputs : Literal [True ] = ...,
230+ ) -> tuple [Apply , list [TensorVariable ]]: ...
231+
232+ def _create_dummy_core_node (
233+ self ,
234+ inputs : Sequence [TensorVariable ],
235+ * ,
236+ propagate_unbatched_core_inputs : bool = False ,
237+ return_dummy_inputs : bool = False ,
238+ ) -> Apply | tuple [Apply , list [TensorVariable ]]:
239+ core_inputs = []
240+ core_dummy_inputs = []
207241 for i , (inp , sig ) in enumerate (zip (inputs , self .inputs_sig , strict = True )):
208242 if inp .type .ndim < len (sig ):
209243 raise ValueError (
210244 f"Input { i } { inp } has insufficient core dimensions for signature { self .signature } "
211245 )
212246 # ndim_supp = 0 case
213- if not sig :
214- core_shape = ()
247+ inp_ndim = inp .type .ndim
248+ batch_ndim = inp_ndim - len (sig )
249+ core_shape = inp .type .shape [batch_ndim :]
250+ if propagate_unbatched_core_inputs and all (
251+ inp .type .broadcastable [:batch_ndim ]
252+ ):
253+ core_inputs .append (_squeeze_left (inp , batch_ndim ))
215254 else :
216- core_shape = inp .type .shape [- len (sig ) :]
217- core_input_types .append (tensor (dtype = inp .type .dtype , shape = core_shape ))
255+ dummy_inp = tensor (dtype = inp .type .dtype , shape = core_shape )
256+ core_inputs .append (dummy_inp )
257+ core_dummy_inputs .append (dummy_inp )
218258
219- core_node = self .core_op .make_node (* core_input_types )
259+ core_node = self .core_op .make_node (* core_inputs )
220260
221261 if len (core_node .outputs ) != len (self .outputs_sig ):
222262 raise ValueError (
@@ -230,6 +270,9 @@ def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
230270 f"Output { i } of { self .core_op } has wrong number of core dimensions for signature { self .signature } : { core_out .type .ndim } "
231271 )
232272
273+ if return_dummy_inputs :
274+ return core_node , core_dummy_inputs
275+
233276 return core_node
234277
235278 def make_node (self , * inputs ):
@@ -298,11 +341,17 @@ def infer_shape(
298341
299342 batch_shape = broadcast_shape (* batch_shapes , arrays_are_shapes = True )
300343
301- # Try to extract the core shapes from the core_op
302- core_op_infer_shape = getattr (self .core_op , "infer_shape" , None )
303- if core_op_infer_shape is not None :
304- dummy_core_node = self ._create_dummy_core_node (node .inputs )
305- dummy_core_inputs = tuple (explicit_graph_inputs (dummy_core_node .inputs ))
344+ def extract_core_shape_from_infer_shape ():
345+ # Try to extract the core shapes from the core_op
346+ core_op_infer_shape = getattr (self .core_op , "infer_shape" , None )
347+ if core_op_infer_shape is None :
348+ return [[None ] * out .ndim for out in node .outputs ]
349+
350+ dummy_core_node , dummy_core_inputs = self ._create_dummy_core_node (
351+ node .inputs ,
352+ return_dummy_inputs = True ,
353+ propagate_unbatched_core_inputs = True ,
354+ )
306355 dummy_fgraph = FunctionGraph (outputs = dummy_core_node .outputs , clone = False )
307356 core_input_shapes = [
308357 input_shape [batch_ndims :] for input_shape in input_shapes
@@ -311,6 +360,25 @@ def infer_shape(
311360 dummy_fgraph , dummy_core_node , core_input_shapes
312361 )
313362
363+ # Set to None those core_shapes that depend on dummy_core_inputs,
364+ # meaning their value may not be constant across batch dims of the Blockwise
365+ if not dummy_core_inputs :
366+ # All inputs are unbatched, so the core_shape can be used as is
367+ return core_output_shapes
368+ else :
369+ set_dummy_core_inputs = set (dummy_core_inputs )
370+ safe_core_output_shapes = [list (shape ) for shape in core_output_shapes ]
371+ for core_out_shape in safe_core_output_shapes :
372+ for o , core_out_dim in enumerate (core_out_shape ):
373+ if set_dummy_core_inputs & set (
374+ explicit_graph_inputs ([core_out_dim ])
375+ ):
376+ core_out_shape [o ] = None
377+
378+ return safe_core_output_shapes
379+
380+ safe_core_out_shape = None
381+
314382 out_shapes = []
315383 for o , (output , sig ) in enumerate (
316384 zip (node .outputs , self .outputs_sig , strict = True )
@@ -321,19 +389,15 @@ def infer_shape(
321389 if dim_name in core_dims :
322390 core_out_shape .append (core_dims [dim_name ])
323391 else :
324- if core_op_infer_shape is not None :
325- # If the input values are needed to compute the dimension length, we can't use the infer_shape
326- # of the core_node as the value is not constant across batch dims of the Blockwise
327- core_out_dim = core_output_shapes [o ][i ]
328- if not (
329- set (dummy_core_inputs )
330- & set (explicit_graph_inputs ([core_out_dim ]))
331- ):
332- core_out_shape .append (core_out_dim )
333- continue
334-
335- # Fallback shape requires evaluating the Blockwise Op
336- core_out_shape .append (Shape_i (batch_ndims + i )(output ))
392+ if safe_core_out_shape is None :
393+ # Extract the core shape from the core_op infer_shape on demand
394+ # For many Ops we never need to do this, because all info is in their signature
395+ safe_core_out_shape = extract_core_shape_from_infer_shape ()
396+ if (core_out_dim := safe_core_out_shape [o ][i ]) is not None :
397+ core_out_shape .append (core_out_dim )
398+ else :
399+ # Fallback shape requires evaluating the Blockwise Op
400+ core_out_shape .append (Shape_i (batch_ndims + i )(output ))
337401 out_shapes .append ((* batch_shape , * core_out_shape ))
338402
339403 return out_shapes
@@ -448,7 +512,10 @@ def gufunc(
448512 )
449513 return core_func (* inputs )
450514 else :
451- core_node = self ._create_dummy_core_node (node .inputs ) # type: ignore
515+ core_node = self ._create_dummy_core_node (
516+ cast (list [TensorVariable ], node .inputs ),
517+ propagate_unbatched_core_inputs = True ,
518+ )
452519 gufunc = _vectorize_node_perform (
453520 core_node ,
454521 batch_bcast_patterns = batch_bcast_patterns ,
0 commit comments