File tree Expand file tree Collapse file tree 1 file changed +5
-0
lines changed
pytensor/tensor/rewriting Expand file tree Collapse file tree 1 file changed +5
-0
lines changed Original file line number Diff line number Diff line change @@ -916,6 +916,10 @@ def specialize_matmul_to_batched_dot(fgraph, node):
916916 """
917917 x , y = node .inputs
918918
919+ if x .type .ndim < 3 :
920+ # This doesn't actually have a batch dimension
921+ return None
922+
919923 # BatchedDot does not allow implicit broadcasting of the batch dimensions
920924 # We do not want to explicitly broadcast as it may result in huge arrays
921925 if x .type .broadcastable [:- 2 ] != y .type .broadcastable [:- 2 ]:
@@ -926,6 +930,7 @@ def specialize_matmul_to_batched_dot(fgraph, node):
926930 if len (x_shape ) > 3 :
927931 # If we have more than one batch dim, ravel it
928932 x = x .reshape ((- 1 , x_shape [- 2 ], x_shape [- 1 ]))
933+ if len (y_shape ) > 3 :
929934 y = y .reshape ((- 1 , y_shape [- 2 ], y_shape [- 1 ]))
930935
931936 new_out = _batched_dot (x , y )
You can’t perform that action at this time.
0 commit comments