|
30 | 30 | from pytensor.graph.basic import Apply, graph_inputs |
31 | 31 | from pytensor.graph.null_type import NullType |
32 | 32 | from pytensor.graph.op import Op |
| 33 | +from pytensor.scan.op import Scan |
33 | 34 | from pytensor.tensor.math import add, dot, exp, sigmoid, sqr, tanh |
34 | 35 | from pytensor.tensor.math import sum as pt_sum |
35 | 36 | from pytensor.tensor.random import RandomStream |
@@ -1036,6 +1037,17 @@ def test_jacobian_scalar(): |
1036 | 1037 | vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) |
1037 | 1038 | assert np.allclose(f(vx), 2) |
1038 | 1039 |
|
| 1040 | + # test when input is a shape (1,) vector -- should still be treated as a scalar |
| 1041 | + Jx = jacobian(y[None], x) |
| 1042 | + f = pytensor.function([x], Jx) |
| 1043 | + |
| 1044 | + # Ensure we hit the scalar grad case (doesn't use scan) |
| 1045 | + nodes = f.maker.fgraph.apply_nodes |
| 1046 | + assert not any(isinstance(node.op, Scan) for node in nodes) |
| 1047 | + |
| 1048 | + vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) |
| 1049 | + assert np.allclose(f(vx), 2) |
| 1050 | + |
1039 | 1051 | # test when the jacobian is called with a tuple as wrt |
1040 | 1052 | Jx = jacobian(y, (x,)) |
1041 | 1053 | assert isinstance(Jx, tuple) |
|
0 commit comments