@@ -461,6 +461,41 @@ def local_subtensor_of_expand_dims(fgraph, node):
461461 return [out ]
462462
463463
464+ @register_canonicalize
465+ @register_specialize
466+ @node_rewriter ([Subtensor ])
467+ def local_subtensor_of_squeeze (fgraph , node ):
468+ """Lift subtensor through a squeeze operation"""
469+ x , * idxs_vars = node .inputs
470+ if not (
471+ x .owner is not None
472+ and isinstance (x .owner .op , DimShuffle )
473+ and x .owner .op .is_squeeze
474+ ):
475+ return None
476+
477+ [x_before_squeeze ] = x .owner .inputs
478+ idxs = indices_from_subtensor (idxs_vars , node .op .idx_list )
479+ dropped_dims = x .owner .op .drop
480+
481+ # Apply indices directly on x
482+ # Add empty slices on the axis that squeeze would have removed
483+ new_idxs = np .insert (np .array (idxs , dtype = object ), dropped_dims , slice (None ))
484+ x_indexed = x_before_squeeze [tuple (new_idxs )]
485+
486+ # Reapply squeeze
487+ # Indexing may have squeezed some dimensions, so we need to recalculate dropped_dims
488+ new_dropped_dims = np .array (dropped_dims )
489+ for i , new_idx in reversed (tuple (enumerate (new_idxs ))):
490+ if not isinstance (new_idx , slice ):
491+ # If it's not a slice, it's an integer which drops the dimension
492+ new_dropped_dims [new_dropped_dims > i ] -= 1
493+ new_x = x_indexed .squeeze (tuple (new_dropped_dims ))
494+
495+ copy_stack_trace (x , new_x )
496+ return [new_x ]
497+
498+
464499@register_canonicalize
465500@register_specialize
466501@node_rewriter ([Subtensor ])
0 commit comments