|
@@ -8,7 +8,7 @@
 import pytensor.tensor.basic
 from pytensor.configdefaults import config
 from pytensor.gradient import DisconnectedType
-from pytensor.graph.basic import Apply
+from pytensor.graph.basic import Apply, Constant
 from pytensor.graph.null_type import NullType
 from pytensor.graph.replace import _vectorize_node, _vectorize_not_needed
 from pytensor.graph.utils import MethodNotDefined
@@ -797,10 +797,24 @@ def _check_runtime_broadcast(node, inputs): |
                 )
 
     def infer_shape(self, fgraph, node, i_shapes) -> list[tuple[TensorVariable, ...]]:
-        from pytensor.tensor.extra_ops import broadcast_shape
+        out_shape = list(node.outputs[0].type.shape)  # static shape; None marks unknown dims
+        if missing_dims := [i for i, s in enumerate(out_shape) if s is None]:
+            for inp_idx, inp in enumerate(node.inputs):
+                inp_st_shape = inp.type.shape
+                for d in missing_dims.copy():  # copy: `missing_dims` is mutated in the loop
+                    if inp_st_shape[d] == 1:
+                        continue  # Nothing to learn from this input
+                    if inp_st_shape[d] is not None:
+                        out_shape[d] = inp_st_shape[d]
+                        missing_dims.remove(d)
+                    else:
+                        out_shape[d] = new_dim = i_shapes[inp_idx][d]
+                        if isinstance(new_dim, Constant):
+                            missing_dims.remove(d)
+                if not missing_dims:
+                    break
 
-        out_shape = broadcast_shape(*i_shapes, arrays_are_shapes=True)
-        return [tuple(as_tensor_variable(s) for s in out_shape)] * len(node.outputs)
+        return [tuple(out_shape) for _ in node.outputs]
 
     def _c_all(self, node, nodename, inames, onames, sub):
         # Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code`
|
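For reference, a minimal sketch of the behavior this change targets; the shapes, names, and the constant-folding claim are illustrative assumptions, not taken from the diff. When every output dim is statically known on the variables' types, the new `infer_shape` can read the lengths straight off the output type rather than building a symbolic `broadcast_shape` graph:

import numpy as np
import pytensor
import pytensor.tensor as pt

# Hypothetical example: x is fully static; y broadcasts along the first dim,
# so the output type of the Elemwise add is statically known to be (5, 3).
x = pt.tensor("x", shape=(5, 3))
y = pt.tensor("y", shape=(1, 3))
out = x + y

fn = pytensor.function([x, y], out.shape)
print(fn(np.zeros((5, 3)), np.zeros((1, 3))))  # -> [5 3]

When an output dim is not statically known, the loop instead takes the symbolic length from the first input whose static shape along that dim is not 1, stopping early once all missing dims are resolved.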