@@ -370,74 +370,73 @@ def local_subtensor_merge(fgraph, node):
370370 """
371371 from pytensor .scan .op import Scan
372372
373- if isinstance (node .op , Subtensor ):
374- u = node .inputs [0 ]
375- if u .owner and isinstance (u .owner .op , Subtensor ):
376- # We can merge :)
377- # x actual tensor on which we are picking slices
378- x = u .owner .inputs [0 ]
379- # slices of the first applied subtensor
380- slices1 = get_idx_list (u .owner .inputs , u .owner .op .idx_list )
381- slices2 = get_idx_list (node .inputs , node .op .idx_list )
382-
383- # Don't try to do the optimization on do-while scan outputs,
384- # as it will create a dependency on the shape of the outputs
385- if (
386- x .owner is not None
387- and isinstance (x .owner .op , Scan )
388- and x .owner .op .info .as_while
389- ):
390- return None
373+ u = node .inputs [0 ]
374+ if not (u .owner is not None and isinstance (u .owner .op , Subtensor )):
375+ return None
391376
392- # Get the shapes of the vectors !
393- try :
394- # try not to introduce new shape into the graph
395- xshape = fgraph .shape_feature .shape_of [x ]
396- ushape = fgraph .shape_feature .shape_of [u ]
397- except AttributeError :
398- # Following the suggested use of shape_feature which should
399- # consider the case when the compilation mode doesn't
400- # include the ShapeFeature
401- xshape = x .shape
402- ushape = u .shape
403-
404- merged_slices = []
405- pos_2 = 0
406- pos_1 = 0
407- while (pos_1 < len (slices1 )) and (pos_2 < len (slices2 )):
408- slice1 = slices1 [pos_1 ]
409- if isinstance (slice1 , slice ):
410- merged_slices .append (
411- merge_two_slices (
412- fgraph , slice1 , xshape [pos_1 ], slices2 [pos_2 ], ushape [pos_2 ]
413- )
414- )
415- pos_2 += 1
416- else :
417- merged_slices .append (slice1 )
418- pos_1 += 1
419-
420- if pos_2 < len (slices2 ):
421- merged_slices += slices2 [pos_2 :]
422- else :
423- merged_slices += slices1 [pos_1 :]
377+ # We can merge :)
378+ # x actual tensor on which we are picking slices
379+ x = u .owner .inputs [0 ]
380+ # slices of the first applied subtensor
381+ slices1 = get_idx_list (u .owner .inputs , u .owner .op .idx_list )
382+ slices2 = get_idx_list (node .inputs , node .op .idx_list )
424383
425- merged_slices = tuple (as_index_constant (s ) for s in merged_slices )
426- subtens = Subtensor (merged_slices )
384+ # Don't try to do the optimization on do-while scan outputs,
385+ # as it will create a dependency on the shape of the outputs
386+ if (
387+ x .owner is not None
388+ and isinstance (x .owner .op , Scan )
389+ and x .owner .op .info .as_while
390+ ):
391+ return None
427392
428- sl_ins = get_slice_elements (
429- merged_slices , lambda x : isinstance (x , Variable )
393+ # Get the shapes of the vectors !
394+ try :
395+ # try not to introduce new shape into the graph
396+ xshape = fgraph .shape_feature .shape_of [x ]
397+ ushape = fgraph .shape_feature .shape_of [u ]
398+ except AttributeError :
399+ # Following the suggested use of shape_feature which should
400+ # consider the case when the compilation mode doesn't
401+ # include the ShapeFeature
402+ xshape = x .shape
403+ ushape = u .shape
404+
405+ merged_slices = []
406+ pos_2 = 0
407+ pos_1 = 0
408+ while (pos_1 < len (slices1 )) and (pos_2 < len (slices2 )):
409+ slice1 = slices1 [pos_1 ]
410+ if isinstance (slice1 , slice ):
411+ merged_slices .append (
412+ merge_two_slices (
413+ fgraph , slice1 , xshape [pos_1 ], slices2 [pos_2 ], ushape [pos_2 ]
414+ )
430415 )
431- # Do not call make_node for test_value
432- out = subtens (x , * sl_ins )
416+ pos_2 += 1
417+ else :
418+ merged_slices .append (slice1 )
419+ pos_1 += 1
433420
434- # Copy over previous output stacktrace
435- # and stacktrace from previous slicing operation.
436- # Why? Because, the merged slicing operation could have failed
437- # because of either of the two original slicing operations
438- orig_out = node .outputs [0 ]
439- copy_stack_trace ([orig_out , node .inputs [0 ]], out )
440- return [out ]
421+ if pos_2 < len (slices2 ):
422+ merged_slices += slices2 [pos_2 :]
423+ else :
424+ merged_slices += slices1 [pos_1 :]
425+
426+ merged_slices = tuple (as_index_constant (s ) for s in merged_slices )
427+ subtens = Subtensor (merged_slices )
428+
429+ sl_ins = get_slice_elements (merged_slices , lambda x : isinstance (x , Variable ))
430+ # Do not call make_node for test_value
431+ out = subtens (x , * sl_ins )
432+
433+ # Copy over previous output stacktrace
434+ # and stacktrace from previous slicing operation.
435+ # Why? Because, the merged slicing operation could have failed
436+ # because of either of the two original slicing operations
437+ orig_out = node .outputs [0 ]
438+ copy_stack_trace ([orig_out , node .inputs [0 ]], out )
439+ return [out ]
441440
442441
443442@register_specialize
@@ -826,6 +825,12 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2):
826825 if not isinstance (slice1 , slice ):
827826 raise ValueError ("slice1 should be of type `slice`" )
828827
828+ # Simple case where one of the slices is useless
829+ if is_full_slice (slice1 ):
830+ return slice2
831+ elif is_full_slice (slice2 ):
832+ return slice1
833+
829834 sl1 , reverse1 = get_canonical_form_slice (slice1 , len1 )
830835 sl2 , reverse2 = get_canonical_form_slice (slice2 , len2 )
831836
0 commit comments