|
12 | 12 | from pytensor import function |
13 | 13 | from pytensor.compile import DeepCopyOp, shared |
14 | 14 | from pytensor.compile.io import In |
| 15 | +from pytensor.compile.mode import Mode |
15 | 16 | from pytensor.configdefaults import config |
| 17 | +from pytensor.gradient import grad |
16 | 18 | from pytensor.graph.op import get_test_value |
17 | 19 | from pytensor.graph.rewriting.utils import is_same_graph |
18 | 20 | from pytensor.printing import pprint |
|
22 | 24 | from pytensor.tensor.elemwise import DimShuffle |
23 | 25 | from pytensor.tensor.math import exp, isinf |
24 | 26 | from pytensor.tensor.math import sum as pt_sum |
| 27 | +from pytensor.tensor.shape import specify_shape |
25 | 28 | from pytensor.tensor.subtensor import ( |
26 | 29 | AdvancedIncSubtensor, |
27 | 30 | AdvancedIncSubtensor1, |
@@ -1660,6 +1663,25 @@ def just_numeric_args(a, b): |
1660 | 1663 | ), |
1661 | 1664 | ) |
1662 | 1665 |
|
| 1666 | + def test_grad_broadcastable_specialization(self): |
| 1667 | + # Make sure gradient does not fail when gx has a more precise static_shape after indexing. |
| 1668 | + # This is a regression test for a bug reported in |
| 1669 | + # https://discourse.pymc.io/t/marginalized-mixture-wont-begin-sampling-throws-assertion-error/15969 |
| 1670 | + |
| 1671 | + x = vector("x") # Unknown write time shape = (2,) |
| 1672 | + out = x.zeros_like() |
| 1673 | + |
| 1674 | + # Update a subtensor of unknown write time shape = (1,) |
| 1675 | + out = out[1:].set(exp(x[1:])) |
| 1676 | + out = specify_shape(out, 2) |
| 1677 | + gx = grad(out.sum(), x) |
| 1678 | + |
| 1679 | + mode = Mode(linker="py", optimizer=None) |
| 1680 | + np.testing.assert_allclose( |
| 1681 | + gx.eval({x: [1, 1]}, mode=mode), |
| 1682 | + [0, np.e], |
| 1683 | + ) |
| 1684 | + |
1663 | 1685 |
|
1664 | 1686 | class TestIncSubtensor1: |
1665 | 1687 | def setup_method(self): |
|
0 commit comments