Skip to content

Commit d679ef7

Browse files
committed
Simpler Elemwise.infer_shape
1 parent 70be87e commit d679ef7

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

pytensor/tensor/elemwise.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pytensor.tensor.basic
99
from pytensor.configdefaults import config
1010
from pytensor.gradient import DisconnectedType
11-
from pytensor.graph.basic import Apply
11+
from pytensor.graph.basic import Apply, Constant
1212
from pytensor.graph.null_type import NullType
1313
from pytensor.graph.replace import _vectorize_node, _vectorize_not_needed
1414
from pytensor.graph.utils import MethodNotDefined
@@ -797,10 +797,24 @@ def _check_runtime_broadcast(node, inputs):
797797
)
798798

799799
def infer_shape(self, fgraph, node, i_shapes) -> list[tuple[TensorVariable, ...]]:
800-
from pytensor.tensor.extra_ops import broadcast_shape
800+
out_shape = list(node.outputs[0].type.shape)
801+
if missing_dims := [i for i, s in enumerate(out_shape) if s is None]:
802+
for inp_idx, inp in enumerate(node.inputs):
803+
inp_st_shape = inp.type.shape
804+
for d in missing_dims:
805+
if inp_st_shape[d] == 1:
806+
continue # Nothing to learn from this input
807+
if inp_st_shape[d] is not None:
808+
out_shape[d] = inp_st_shape[d]
809+
missing_dims.remove(d)
810+
else:
811+
out_shape[d] = new_dim = i_shapes[inp_idx][d]
812+
if isinstance(new_dim, Constant):
813+
missing_dims.remove(d)
814+
if not missing_dims:
815+
break
801816

802-
out_shape = broadcast_shape(*i_shapes, arrays_are_shapes=True)
803-
return [tuple(as_tensor_variable(s) for s in out_shape)] * len(node.outputs)
817+
return [tuple(out_shape) for _ in node.outputs]
804818

805819
def _c_all(self, node, nodename, inames, onames, sub):
806820
# Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code`

0 commit comments

Comments
 (0)