File tree Expand file tree Collapse file tree 1 file changed +5
-5
lines changed
pytensor/tensor/rewriting Expand file tree Collapse file tree 1 file changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -613,10 +613,6 @@ def local_subtensor_make_vector(fgraph, node):
613613 something more general for constant ``*Subtensor*`` graphs (or perhaps
614614 include this kind of work in the constant folding).
615615 """
616-
617- if not isinstance (node .op , Subtensor | AdvancedSubtensor1 ):
618- return False
619-
620616 x = node .inputs [0 ]
621617
622618 if not (x .owner and isinstance (x .owner .op , MakeVector )):
@@ -666,7 +662,11 @@ def local_subtensor_make_vector(fgraph, node):
666662 const_slice = get_constant_idx (
667663 node .op .idx_list , node .inputs , allow_partial = False
668664 )[0 ]
669- ret = make_vector_op (* x .owner .inputs [const_slice ])
665+ sliced_inputs = x .owner .inputs [const_slice ]
666+ if len (sliced_inputs ) == 1 :
667+ ret = expand_dims (sliced_inputs [0 ], axis = 0 )
668+ else :
669+ ret = make_vector_op (* sliced_inputs )
670670 copy_stack_trace (node .outputs , ret )
671671 return [ret ]
672672 except NotScalarConstantError :
You can’t perform that action at this time.
0 commit comments