 from pytensor.graph import ancestors
 from pytensor.graph.op import compute_test_value
 from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
-from pytensor.scalar import integer_types
-from pytensor.tensor import NoneConst
+from pytensor.tensor import NoneConst, TensorVariable
 from pytensor.tensor.basic import constant
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.extra_ops import broadcast_to
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.utils import broadcast_params
-from pytensor.tensor.shape import Shape, Shape_i, shape_padleft
+from pytensor.tensor.shape import Shape, Shape_i
 from pytensor.tensor.subtensor import (
     AdvancedSubtensor,
     AdvancedSubtensor1,
     Subtensor,
-    as_index_variable,
     get_idx_list,
 )
+from pytensor.tensor.type import integer_dtypes
 from pytensor.tensor.type_other import NoneTypeT, SliceType


@@ -127,22 +126,23 @@ def local_dimshuffle_rv_lift(fgraph, node):

     ds_op = node.op

-    if not isinstance(ds_op, DimShuffle):
+    # Dimshuffle which drop dimensions not supported yet
+    if ds_op.drop:
         return False

-    base_rv = node.inputs[0]
-    rv_node = base_rv.owner
+    rv_node = node.inputs[0].owner

     if not (rv_node and isinstance(rv_node.op, RandomVariable)):
         return False

-    # Dimshuffle which drop dimensions not supported yet
-    if ds_op.drop:
-        return False
-
     rv_op = rv_node.op
     rng, size, *dist_params = rv_node.inputs
-    rv = rv_node.default_output()
+    next_rng, rv = rv_node.outputs
+
+    # If no one else is using the underlying `RandomVariable`, then we can
+    # do this; otherwise, the graph would be internally inconsistent.
+    if is_rv_used_in_graph(rv, node, fgraph):
+        return False

     # Check that Dimshuffle does not affect support dims
     supp_dims = set(range(rv.ndim - rv_op.ndim_supp, rv.ndim))
@@ -153,31 +153,24 @@ def local_dimshuffle_rv_lift(fgraph, node):

     # If no one else is using the underlying RandomVariable, then we can
     # do this; otherwise, the graph would be internally inconsistent.
-    if is_rv_used_in_graph(base_rv, node, fgraph):
+    if is_rv_used_in_graph(rv, node, fgraph):
         return False

     batched_dims = rv.ndim - rv_op.ndim_supp
     batched_dims_ds_order = tuple(o for o in ds_op.new_order if o not in supp_dims)

     if isinstance(size.type, NoneTypeT):
-        # Make size explicit
-        shape = tuple(broadcast_params(dist_params, rv_op.ndims_params)[0].shape)
-        size = shape[:batched_dims]
-
-    # Update the size to reflect the DimShuffled dimensions
-    new_size = [
-        constant(1, dtype="int64") if o == "x" else size[o]
-        for o in batched_dims_ds_order
-    ]
+        new_size = size
+    else:
+        # Update the size to reflect the DimShuffled dimensions
+        new_size = [
+            constant(1, dtype="int64") if o == "x" else size[o]
+            for o in batched_dims_ds_order
+        ]

     # Updates the params to reflect the Dimshuffled dimensions
     new_dist_params = []
     for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
-        # Add broadcastable dimensions to the parameters that would have been expanded by the size
-        padleft = batched_dims - (param.ndim - param_ndim_supp)
-        if padleft > 0:
-            param = shape_padleft(param, padleft)
-
         # Add the parameter support dimension indexes to the batched dimensions Dimshuffle
         param_new_order = batched_dims_ds_order + tuple(
             range(batched_dims, batched_dims + param_ndim_supp)
@@ -189,10 +182,10 @@ def local_dimshuffle_rv_lift(fgraph, node):

     if config.compute_test_value != "off":
         compute_test_value(new_node)
-    out = new_node.outputs[1]
-    if base_rv.name:
-        out.name = f"{base_rv.name}_lifted"
-    return [out]
+    new_rv = new_node.default_output()
+    if rv.name:
+        new_rv.name = f"{rv.name}_lifted"
+    return [new_rv]


 @node_rewriter([Subtensor, AdvancedSubtensor1, AdvancedSubtensor])
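As a quick orientation between the two hunks, here is a minimal sketch of the kind of graph `local_dimshuffle_rv_lift` targets. The variable names and shapes are illustrative assumptions, not part of this commit; the rewrite itself is applied by PyTensor's rewrite pipeline (e.g. during function compilation):

    import pytensor.tensor as pt
    from pytensor.tensor.random.basic import normal

    mu = pt.matrix("mu")        # assumed batch shape (3, 2)
    sigma = pt.matrix("sigma")
    x = normal(mu, sigma)       # RandomVariable draw with shape (3, 2)
    y = x.T                     # DimShuffle with new_order=(1, 0)

    # After `local_dimshuffle_rv_lift` runs, `y` should be replaced by a draw
    # taken directly from the DimShuffled parameters, roughly normal(mu.T, sigma.T),
    # so no DimShuffle remains on top of the RandomVariable.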
@@ -206,47 +199,38 @@ def local_subtensor_rv_lift(fgraph, node):
     ``mvnormal(mu, cov, size=(2,))[0, 0]``.
     """

-    def is_nd_advanced_idx(idx, dtype):
+    def is_nd_advanced_idx(idx, dtype) -> bool:
+        if not isinstance(idx, TensorVariable):
+            return False
         if isinstance(dtype, str):
             return (getattr(idx.type, "dtype", None) == dtype) and (idx.type.ndim >= 1)
         else:
             return (getattr(idx.type, "dtype", None) in dtype) and (idx.type.ndim >= 1)

     subtensor_op = node.op

-    old_subtensor = node.outputs[0]
-    rv = node.inputs[0]
-    rv_node = rv.owner
+    [indexed_rv] = node.outputs
+    rv_node = node.inputs[0].owner

     if not (rv_node and isinstance(rv_node.op, RandomVariable)):
         return False

-    shape_feature = getattr(fgraph, "shape_feature", None)
-    if not shape_feature:
-        return None
-
-    # Use shape_feature to facilitate inferring final shape.
-    # Check that neither the RV nor the old Subtensor are in the shape graph.
-    output_shape = fgraph.shape_feature.shape_of.get(old_subtensor, None)
-    if output_shape is None or {old_subtensor, rv} & set(ancestors(output_shape)):
-        return None
-
     rv_op = rv_node.op
     rng, size, *dist_params = rv_node.inputs
+    rv = rv_node.default_output()
+
+    # If no one else is using the underlying `RandomVariable`, then we can
+    # do this; otherwise, the graph would be internally inconsistent.
+    if is_rv_used_in_graph(rv, node, fgraph):
+        return False

     # Parse indices
-    idx_list = getattr(subtensor_op, "idx_list", None)
-    if idx_list:
-        idx_vars = get_idx_list(node.inputs, idx_list)
-    else:
-        idx_vars = node.inputs[1:]
-    indices = tuple(as_index_variable(idx) for idx in idx_vars)
+    indices = get_idx_list(node.inputs, getattr(subtensor_op, "idx_list", None))

     # The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
     # Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
     # If we wanted to support that we could rewrite it as subtensor + dimshuffle
     # and make use of the dimshuffle lift rewrite
-    integer_dtypes = {type.dtype for type in integer_types}
     if any(
         is_nd_advanced_idx(idx, integer_dtypes) or NoneConst.equals(idx)
         for idx in indices
@@ -277,34 +261,35 @@ def is_nd_advanced_idx(idx, dtype):
     n_discarded_idxs = len(supp_indices)
     indices = indices[:-n_discarded_idxs]

-    # If no one else is using the underlying `RandomVariable`, then we can
-    # do this; otherwise, the graph would be internally inconsistent.
-    if is_rv_used_in_graph(rv, node, fgraph):
-        return False
-
     # Update the size to reflect the indexed dimensions
-    new_size = output_shape[: len(output_shape) - rv_op.ndim_supp]
+    if isinstance(size.type, NoneTypeT):
+        new_size = size
+    else:
+        shape_feature = getattr(fgraph, "shape_feature", None)
+        if not shape_feature:
+            return None
+
+        # Use shape_feature to facilitate inferring final shape.
+        # Check that neither the RV nor the old Subtensor are in the shape graph.
+        output_shape = fgraph.shape_feature.shape_of.get(indexed_rv, None)
+        if output_shape is None or {indexed_rv, rv} & set(ancestors(output_shape)):
+            return None
+
+        new_size = output_shape[: len(output_shape) - rv_op.ndim_supp]

     # Propagate indexing to the parameters' batch dims.
     # We try to avoid broadcasting the parameters together (and with size), by only indexing
     # non-broadcastable (non-degenerate) parameter dims. These parameters and the new size
     # should still correctly broadcast any degenerate parameter dims.
     new_dist_params = []
     for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
-        # We first expand any missing parameter dims (and later index them away or keep them with none-slicing)
-        batch_param_dims_missing = batch_ndims - (param.ndim - param_ndim_supp)
-        batch_param = (
-            shape_padleft(param, batch_param_dims_missing)
-            if batch_param_dims_missing
-            else param
-        )
-        # Check which dims are actually broadcasted
-        bcast_batch_param_dims = tuple(
+        # Check which dims are broadcasted by either size or other parameters
+        bcast_param_dims = tuple(
             dim
-            for dim, (param_dim, output_dim) in enumerate(
-                zip(batch_param.type.shape, rv.type.shape)
+            for dim, (param_dim_bcast, output_dim_bcast) in enumerate(
+                zip(param.type.broadcastable, rv.type.broadcastable)
             )
-            if (param_dim == 1) and (output_dim != 1)
+            if param_dim_bcast and not output_dim_bcast
         )
         batch_indices = []
         curr_dim = 0
@@ -315,23 +300,23 @@ def is_nd_advanced_idx(idx, dtype):
                 # If not, we use that directly, instead of the more inefficient `nonzero` form
                 bool_dims = range(curr_dim, curr_dim + idx.type.ndim)
                 # There's an overlap, we have to decompose the boolean mask as a `nonzero`
-                if set(bool_dims) & set(bcast_batch_param_dims):
+                if set(bool_dims) & set(bcast_param_dims):
                     int_indices = list(idx.nonzero())
                     # Indexing by 0 drops the degenerate dims
                     for bool_dim in bool_dims:
-                        if bool_dim in bcast_batch_param_dims:
+                        if bool_dim in bcast_param_dims:
                             int_indices[bool_dim - curr_dim] = 0
                     batch_indices.extend(int_indices)
-                # No overlap, use index as is
+                # No overlap, use boolean index as is
                 else:
                     batch_indices.append(idx)
                 curr_dim += len(bool_dims)
             # Basic-indexing (slice or integer)
             else:
                 # Broadcasted dim
-                if curr_dim in bcast_batch_param_dims:
+                if curr_dim in bcast_param_dims:
                     # Slice indexing, keep degenerate dim by none-slicing
-                    if isinstance(idx.type, SliceType):
+                    if isinstance(idx, slice) or isinstance(idx.type, SliceType):
                         batch_indices.append(slice(None))
                     # Integer indexing, drop degenerate dim by 0-indexing
                     else:
@@ -342,7 +327,7 @@ def is_nd_advanced_idx(idx, dtype):
                     batch_indices.append(idx)
                 curr_dim += 1

-        new_dist_params.append(batch_param[tuple(batch_indices)])
+        new_dist_params.append(param[tuple(batch_indices)])

     # Create new RV
     new_node = rv_op.make_node(rng, new_size, *new_dist_params)
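For completeness, a similarly hedged sketch of what `local_subtensor_rv_lift` is meant to achieve; the names and shapes below are illustrative assumptions, and the actual behaviour is defined by the code in the diff above:

    import pytensor.tensor as pt
    from pytensor.tensor.random.basic import normal

    mu = pt.matrix("mu")          # assumed batch shape (5, 3)
    sigma = pt.matrix("sigma")
    x = normal(mu, sigma)         # draws with shape (5, 3)
    y = x[0]                      # Subtensor on a batch dimension

    # After the lift, the indexing is pushed onto the batch dimensions of the
    # parameters, giving a graph roughly equivalent to normal(mu[0], sigma[0]),
    # so only the requested slice of the distribution is ever sampled.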