 
 from pytensor import Variable
 from pytensor.compile import optdb
-from pytensor.graph import Constant, FunctionGraph, node_rewriter
+from pytensor.graph import Constant, FunctionGraph, node_rewriter, vectorize_graph
 from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
 from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
 from pytensor.scalar import basic as ps
@@ -119,21 +119,48 @@ def local_subtensor_of_dot(fgraph, node):
     the remaining entries of ``idxs`` (if any), modified to skip the
     second-to-last dimension of ``B`` (because dot sums over this dimension).
     """
-    if not isinstance(node.op, Subtensor):
-        return
-    if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Dot)):
+    x, *idx_vars = node.inputs
+    if not (
+        x.owner is not None
+        and (
+            isinstance(x.owner.op, Dot)
+            or (
+                isinstance(x.owner.op, Blockwise)
+                and isinstance(x.owner.op.core_op, Dot)
+            )
+        )
+    ):
         return
     # If there is other node that use the outputs of the dot
     # We don't want to compute twice the sub part.
-    if len(fgraph.clients[node.inputs[0]]) > 1:
+    if len(fgraph.clients[x]) > 1:
         return
 
-    a = node.inputs[0].owner.inputs[0]
-    b = node.inputs[0].owner.inputs[1]
+    a = x.owner.inputs[0]
+    b = x.owner.inputs[1]
+    idx_list = indices_from_subtensor(idx_vars, node.op.idx_list)
 
-    idx_list = get_idx_list(node.inputs, node.op.idx_list)
+    if not idx_list:
+        # Nothing to do, `local_useless_slice` will handle this case
+        return None
 
-    num_a_indices = min(a.ndim - 1, len(idx_list))
+    batch_ndim = (
+        x.owner.op.batch_ndim(x.owner) if isinstance(x.owner.op, Blockwise) else 0
+    )
+
+    if batch_ndim:
+        batch_idx_list, idx_list = idx_list[:batch_ndim], idx_list[batch_ndim:]
+        if not idx_list:
+            # Indexing only over batch dimensions of Blockwise, nothing to do here
+            # This will be handled by `local_subtensor_of_batch_dims`
+            return None
+        # We perform the rest of the rewrite on dummy a, b that correspond to the core case
+        a = a.type.clone(shape=a.type.shape[batch_ndim:])()
+        b = b.type.clone(shape=b.type.shape[batch_ndim:])()
+
+    a_ndim = a.ndim
+    b_ndim = b.ndim
+    num_a_indices = min(a_ndim - 1, len(idx_list))
     a_indices = idx_list[:num_a_indices]
     b_indices = idx_list[num_a_indices:]
 
@@ -142,26 +169,22 @@ def local_subtensor_of_dot(fgraph, node):
     # This wasn't necessary for a, because we just omitted the last index.
     # We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:]
     # (dot also handles b.ndim < 2 as a special case)
-    if b.ndim > 1 and len(b_indices) >= b.ndim - 1:
+    if b_ndim > 1 and len(b_indices) >= b_ndim - 1:
         b_indices = (
-            b_indices[: b.ndim - 2]
+            b_indices[: b_ndim - 2]
             + (slice(None, None, None),)
-            + b_indices[b.ndim - 2 :]
+            + b_indices[b_ndim - 2 :]
         )
 
-    a_sub = a.__getitem__(tuple(a_indices))
-    b_sub = b.__getitem__(tuple(b_indices)) if b_indices else b
+    a_sub = a[tuple(a_indices)]
+    b_sub = b[tuple(b_indices)] if b_indices else b
+    r = dot(a_sub, b_sub)
 
-    # Copy over previous output stacktrace to a_sub and b_sub,
-    # because an error in the subtensor operation (e.g. an index error)
-    # on either a or b must correspond to an error in the
-    # subtensor operation on their dot product.
-    copy_stack_trace(node.outputs[0], [a_sub, b_sub])
+    if batch_ndim:
+        # Replace dummy inputs by the original batch ones
+        r = vectorize_graph(r, replace={a: x.owner.inputs[0], b: x.owner.inputs[1]})
+        r = r[tuple(batch_idx_list)]
 
-    # Copy over previous output stacktrace and previous dot product stacktrace,
-    # because an error here may correspond to an either in either the original
-    # dot product, or in the dot product after the subtensor operation.
-    r = dot(a_sub, b_sub)
     copy_stack_trace([node.outputs[0], node.inputs[0]], r)
 
     return [r]
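The core rewrite relies on the identity the docstring describes: indices applied to `dot(A, B)` can be moved onto the operands, with the indices that fall on `B` shifted past its second-to-last axis (the one `dot` sums over). A minimal NumPy sanity check of that identity, with illustrative shapes that are not taken from the diff:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(4, 3))
b = rng.normal(size=(3, 5))

# Leading indices index into `a`; the remaining index goes to `b`,
# placed after a full slice over b's second-to-last (summed) axis.
np.testing.assert_allclose((a @ b)[1:3, 2], a[1:3] @ b[:, 2])
```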
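For a `Blockwise(Dot)`, the new branch runs the same core rewrite on dummy inputs whose types have the batch dimensions stripped, then rebuilds the batched graph with `vectorize_graph` and applies any batch-dimension indices on top. A rough sketch of that pattern, assuming hypothetical variable names and shapes (only the `vectorize_graph(..., replace=...)` call mirrors the diff):

```python
import pytensor.tensor as pt
from pytensor.graph import vectorize_graph

# Dummy "core" inputs, standing in for a and b with batch dims removed
a_core = pt.matrix("a_core")  # core shape (m, k)
b_core = pt.matrix("b_core")  # core shape (k, n)

# The core-case rewrite: push the subtensor onto the dot operands
core_out = pt.dot(a_core[1:3], b_core)

# The original batched inputs of the Blockwise(Dot) node
a_batched = pt.tensor3("a_batched")  # shape (batch, m, k)
b_batched = pt.tensor3("b_batched")  # shape (batch, k, n)

# vectorize_graph replays the core graph on the batched inputs,
# broadcasting every operation over the leading batch dimension
out = vectorize_graph(core_out, replace={a_core: a_batched, b_core: b_batched})
```

When `batch_ndim` is 0 (a plain `Dot`), this branch is skipped and the rewritten dot is returned directly.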