|
27 | 27 | from pytensor.compile.sharedvalue import shared |
28 | 28 | from pytensor.configdefaults import config |
29 | 29 | from pytensor.gradient import NullTypeGradError, Rop, disconnected_grad, grad, hessian |
| 30 | +from pytensor.graph import vectorize_graph |
30 | 31 | from pytensor.graph.basic import Apply, ancestors, equal_computations |
31 | 32 | from pytensor.graph.fg import FunctionGraph |
32 | 33 | from pytensor.graph.op import Op |
@@ -1178,6 +1179,17 @@ def get_sum_of_grad(input0, input1): |
1178 | 1179 |
|
1179 | 1180 | utt.verify_grad(get_sum_of_grad, inputs_test_values, rng=rng) |
1180 | 1181 |
|
| 1182 | + def test_blockwise_scan(self): |
| 1183 | + x = pt.tensor("x", shape=()) |
| 1184 | + out, _ = scan(lambda x: x + 1, outputs_info=[x], n_steps=10) |
| 1185 | + x_vec = pt.tensor("x_vec", shape=(None,)) |
| 1186 | + out_vec = vectorize_graph(out, {x: x_vec}) |
| 1187 | + |
| 1188 | + fn = function([x_vec], out_vec) |
| 1189 | + o1 = fn([1, 2, 3]) |
| 1190 | + o2 = np.arange(2, 12) + np.arange(3).reshape(-1, 1) |
| 1191 | + assert np.allclose(o1, o2) |
| 1192 | + |
1181 | 1193 | def test_connection_pattern(self): |
1182 | 1194 | """Test `Scan.connection_pattern` in the presence of recurrent outputs with multiple taps.""" |
1183 | 1195 |
|
|
0 commit comments