|
1 | 1 | from collections.abc import Collection |
2 | | -from functools import reduce |
3 | 2 | from typing import Iterable, Set, Tuple, Union |
4 | 3 |
|
5 | 4 | import numpy as np |
6 | | -import numpy.core.numeric |
7 | 5 | from numpy.core.multiarray import normalize_axis_index |
8 | 6 |
|
9 | 7 | import pytensor |
|
14 | 12 | disconnected_type, |
15 | 13 | grad_undefined, |
16 | 14 | ) |
17 | | -from pytensor.graph.basic import Apply, Constant, Variable, equal_computations |
| 15 | +from pytensor.graph.basic import Apply, Constant, Variable |
18 | 16 | from pytensor.graph.op import Op |
19 | 17 | from pytensor.link.c.op import COp |
20 | 18 | from pytensor.link.c.params_type import ParamsType |
|
23 | 21 | from pytensor.raise_op import Assert |
24 | 22 | from pytensor.scalar import int32 as int_t |
25 | 23 | from pytensor.scalar import upcast |
26 | | -from pytensor.scalar.basic import Composite |
27 | 24 | from pytensor.tensor import basic as at |
28 | 25 | from pytensor.tensor import get_vector_length |
29 | 26 | from pytensor.tensor.exceptions import NotScalarConstantError |
30 | 27 | from pytensor.tensor.math import abs as at_abs |
31 | | -from pytensor.tensor.math import all as at_all |
| 28 | +from pytensor.tensor.math import all as pt_all |
| 29 | +from pytensor.tensor.math import eq as pt_eq |
32 | 30 | from pytensor.tensor.math import ge, lt, maximum, minimum, prod |
33 | 31 | from pytensor.tensor.math import sum as at_sum |
34 | 32 | from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor |
@@ -536,7 +534,7 @@ def bincount(x, weights=None, minlength=None, assert_nonneg=False): |
536 | 534 |
|
537 | 535 | if assert_nonneg: |
538 | 536 | assert_op = Assert("Input to bincount has negative values!") |
539 | | - x = assert_op(x, at_all(x >= 0)) |
| 537 | + x = assert_op(x, pt_all(x >= 0)) |
540 | 538 |
|
541 | 539 | max_value = at.cast(x.max() + 1, "int64") |
542 | 540 |
|
@@ -1436,6 +1434,13 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"): |
1436 | 1434 | return RavelMultiIndex(mode=mode, order=order)(*args) |
1437 | 1435 |
|
1438 | 1436 |
|
| 1437 | +_broadcast_assert = Assert( |
| 1438 | + "Could not broadcast dimensions. Broadcasting is only allowed along " |
| 1439 | + "axes that have a statically known length 1. Use `specify_shape` to " |
| 1440 | + "inform PyTensor of a known shape." |
| 1441 | +) |
| 1442 | + |
| 1443 | + |
1439 | 1444 | def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]: |
1440 | 1445 | """Compute the shape resulting from broadcasting arrays. |
1441 | 1446 |
|
@@ -1510,119 +1515,45 @@ def broadcast_shape_iter( |
1510 | 1515 | result_dims = [] |
1511 | 1516 |
|
1512 | 1517 | for dim_shapes in zip(*array_shapes): |
1513 | | - # Get the shapes in this dimension that are not definitively |
1514 | | - # broadcastable (i.e. not symbolically known to be broadcastable) |
1515 | | - maybe_non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at] |
| 1518 | + # Get the shapes in this dimension that are not broadcastable |
| 1519 | + # (i.e. not symbolically known to be broadcastable) |
| 1520 | + non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at] |
1516 | 1521 |
|
1517 | | - if len(maybe_non_bcast_shapes) == 0: |
| 1522 | + if len(non_bcast_shapes) == 0: |
1518 | 1523 | # Every shape was broadcastable in this dimension |
1519 | 1524 | result_dims.append(one_at) |
1520 | | - elif len(maybe_non_bcast_shapes) == 1: |
| 1525 | + elif len(non_bcast_shapes) == 1: |
1521 | 1526 | # Only one shape might not be broadcastable in this dimension |
1522 | | - result_dims.extend(maybe_non_bcast_shapes) |
| 1527 | + result_dims.extend(non_bcast_shapes) |
1523 | 1528 | else: |
1524 | 1529 | # More than one shape might not be broadcastable in this dimension |
1525 | | - |
1526 | 1530 | nonconst_nb_shapes: Set[int] = set() |
1527 | 1531 | const_nb_shapes: Set[Variable] = set() |
1528 | | - for shape in maybe_non_bcast_shapes: |
| 1532 | + for shape in non_bcast_shapes: |
1529 | 1533 | if isinstance(shape, Constant): |
1530 | 1534 | const_nb_shapes.add(shape.value.item()) |
1531 | 1535 | else: |
1532 | 1536 | nonconst_nb_shapes.add(shape) |
1533 | 1537 |
|
1534 | 1538 | if len(const_nb_shapes) > 1: |
1535 | | - raise ValueError("Could not broadcast dimensions") |
1536 | | - elif len(const_nb_shapes) == 1: |
1537 | | - (const_nb_shape,) = const_nb_shapes |
1538 | | - |
1539 | | - assert const_nb_shape != 1 |
1540 | | - |
1541 | | - const_nt_shape_var = pytensor.scalar.ScalarConstant( |
1542 | | - pytensor.scalar.int64, const_nb_shape |
| 1539 | + raise ValueError( |
| 1540 | + f"Could not broadcast dimensions. Incompatible shapes were {array_shapes}." |
1543 | 1541 | ) |
1544 | 1542 |
|
1545 | | - if len(nonconst_nb_shapes) > 0: |
1546 | | - # All the potential non-broadcast shapes need to either |
1547 | | - # be broadcastable or equal to the one non-broadcastable |
1548 | | - # constant `const_nt_shape_var`. |
1549 | | - assert_dim = Assert("Could not broadcast dimensions") |
1550 | | - |
1551 | | - scalar_nonconst_nb_shapes = [ |
1552 | | - at.scalar_from_tensor(s) |
1553 | | - if isinstance(s.type, TensorType) |
1554 | | - else s |
1555 | | - for s in nonconst_nb_shapes |
1556 | | - ] |
1557 | | - |
1558 | | - dummy_nonconst_nb_shapes = [ |
1559 | | - aes.get_scalar_type(dtype=v.dtype)() |
1560 | | - for v in scalar_nonconst_nb_shapes |
1561 | | - ] |
1562 | | - assert_cond = reduce( |
1563 | | - aes.and_, |
1564 | | - ( |
1565 | | - aes.or_( |
1566 | | - aes.eq(nbv, one_at), aes.eq(nbv, const_nt_shape_var) |
1567 | | - ) |
1568 | | - for nbv in dummy_nonconst_nb_shapes |
1569 | | - ), |
1570 | | - ) |
1571 | | - assert_cond_op = Composite(dummy_nonconst_nb_shapes, [assert_cond]) |
1572 | | - |
1573 | | - bcast_dim = assert_dim( |
1574 | | - const_nt_shape_var, assert_cond_op(*scalar_nonconst_nb_shapes) |
1575 | | - ) |
1576 | | - else: |
1577 | | - bcast_dim = const_nt_shape_var |
| 1543 | + if len(const_nb_shapes) == 1: |
| 1544 | + (first_length,) = const_nb_shapes |
| 1545 | + other_lengths = nonconst_nb_shapes |
| 1546 | + first_length = aes.as_scalar(first_length) |
1578 | 1547 | else: |
1579 | | - # There are no constant, non-broadcastable shapes in this |
1580 | | - # dimension. |
1581 | | - |
1582 | | - all_dims_equal = all( |
1583 | | - # TODO FIXME: This is a largely deficient, and expensive, means |
1584 | | - # of comparing graphs (and especially shapes) |
1585 | | - equal_computations([maybe_non_bcast_shapes[0]], [dim]) |
1586 | | - for dim in maybe_non_bcast_shapes[1:] |
1587 | | - ) |
| 1548 | + first_length, *other_lengths = nonconst_nb_shapes |
1588 | 1549 |
|
1589 | | - if all_dims_equal: |
1590 | | - result_dims.append(maybe_non_bcast_shapes[0]) |
1591 | | - continue |
1592 | | - |
1593 | | - scalar_maybe_non_bcast_shapes = [ |
1594 | | - at.scalar_from_tensor(s) if isinstance(s.type, TensorType) else s |
1595 | | - for s in maybe_non_bcast_shapes |
1596 | | - ] |
1597 | | - dummy_maybe_non_bcast_shapes = [ |
1598 | | - aes.get_scalar_type(dtype=v.dtype)() |
1599 | | - for v in scalar_maybe_non_bcast_shapes |
1600 | | - ] |
1601 | | - non_bcast_vec = [ |
1602 | | - aes.switch(aes.eq(nbv, 1), -one_at, nbv) |
1603 | | - for nbv in dummy_maybe_non_bcast_shapes |
1604 | | - ] |
1605 | | - dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec)) |
1606 | | - dim_max_op = Composite(dummy_maybe_non_bcast_shapes, [dim_max]) |
1607 | | - |
1608 | | - dummy_dim_max = dim_max_op(*dummy_maybe_non_bcast_shapes) |
1609 | | - |
1610 | | - assert_dim = Assert("Could not broadcast dimensions") |
1611 | | - assert_cond = reduce( |
1612 | | - aes.and_, |
1613 | | - ( |
1614 | | - aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dummy_dim_max)) |
1615 | | - for nbv in non_bcast_vec |
1616 | | - ), |
1617 | | - ) |
1618 | | - assert_cond_op = Composite(dummy_maybe_non_bcast_shapes, [assert_cond]) |
1619 | | - |
1620 | | - bcast_dim = assert_dim( |
1621 | | - dim_max_op(*scalar_maybe_non_bcast_shapes), |
1622 | | - assert_cond_op(*scalar_maybe_non_bcast_shapes), |
1623 | | - ) |
| 1550 | + if len(other_lengths) == 0: |
| 1551 | + result_dims.append(first_length) |
| 1552 | + continue |
1624 | 1553 |
|
1625 | | - result_dims.append(bcast_dim) |
| 1554 | + # Add assert that all remaining shapes are equal |
| 1555 | + condition = pt_all([pt_eq(first_length, other) for other in other_lengths]) |
| 1556 | + result_dims.append(_broadcast_assert(first_length, condition)) |
1626 | 1557 |
|
1627 | 1558 | return tuple(result_dims) |
1628 | 1559 |
|
|
0 commit comments