@@ -948,6 +948,9 @@ def perform(self, node, inputs, out_):
948948 out [0 ] = np .asarray (x .__getitem__ (cdata ))
949949
950950 def infer_shape (self , fgraph , node , shapes ):
951+ def _is_constant (const , x ):
952+ return isinstance (const , Constant ) and const .data .item () == x
953+
951954 xshp = shapes [0 ]
952955 assert len (xshp ) == node .inputs [0 ].ndim
953956 outshp = []
@@ -961,10 +964,17 @@ def infer_shape(self, fgraph, node, shapes):
961964 # If it is the default (None, None, None) slice, or a variant,
962965 # the shape will be xl
963966 if (
964- (idx .start in [None , 0 ])
965- and (idx .stop in [None , sys .maxsize ])
966- and (idx .step is None or idx .step == 1 )
967+ (idx .start is None or _is_constant (idx .start , 0 ))
968+ and (idx .stop is None or _is_constant (idx .stop , sys .maxsize ))
969+ and (idx .step is None or _is_constant (idx .step , 1 ))
970+ ):
971+ outshp .append (xl )
972+ elif (
973+ (idx .start is None )
974+ and (idx .stop is None )
975+ and _is_constant (idx .step , - 1 )
967976 ):
977+ # Reverse slice
968978 outshp .append (xl )
969979 else :
970980 cnf = get_canonical_form_slice (idx , xl )[0 ]
0 commit comments