@@ -2,46 +2,16 @@
 import pytest
 
 from pytensor.compile.function import function
-from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
-from pytensor.graph.op import get_test_value
-from pytensor.graph.rewriting.db import RewriteDatabaseQuery
-from pytensor.link.jax import JAXLinker
-from pytensor.tensor import blas as pt_blas
 from pytensor.tensor import nlinalg as pt_nlinalg
-from pytensor.tensor.math import Argmax, Max, maximum
-from pytensor.tensor.math import max as pt_max
-from pytensor.tensor.type import dvector, matrix, scalar, tensor3, vector
+from pytensor.tensor.type import matrix
 from tests.link.jax.test_basic import compare_jax_and_py
 
 
 jax = pytest.importorskip("jax")
 
 
-def test_jax_BatchedDot():
-    # tensor3 . tensor3
-    a = tensor3("a")
-    a.tag.test_value = (
-        np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
-    )
-    b = tensor3("b")
-    b.tag.test_value = (
-        np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
-    )
-    out = pt_blas.BatchedDot()(a, b)
-    fgraph = FunctionGraph([a, b], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
-    # A dimension mismatch should raise a TypeError for compatibility
-    inputs = [get_test_value(a)[:-1], get_test_value(b)]
-    opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
-    jax_mode = Mode(JAXLinker(), opts)
-    pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
-    with pytest.raises(TypeError):
-        pytensor_jax_fn(*inputs)
-
-
 def test_jax_basic_multiout():
     rng = np.random.default_rng(213234)
 
@@ -79,45 +49,6 @@ def assert_fn(x, y): |
     compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
 
 
-def test_jax_max_and_argmax():
-    # Test that a single output of a multi-output `Op` can be used as input to
-    # another `Op`
-    x = dvector()
-    mx = Max([0])(x)
-    amx = Argmax([0])(x)
-    out = mx * amx
-    out_fg = FunctionGraph([x], [out])
-    compare_jax_and_py(out_fg, [np.r_[1, 2]])
-
-
-def test_tensor_basics():
-    y = vector("y")
-    y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
-    x = vector("x")
-    x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX)
-    A = matrix("A")
-    A.tag.test_value = np.empty((2, 2), dtype=config.floatX)
-    alpha = scalar("alpha")
-    alpha.tag.test_value = np.array(3.0, dtype=config.floatX)
-    beta = scalar("beta")
-    beta.tag.test_value = np.array(5.0, dtype=config.floatX)
-
-    # This should be converted into a `Gemv` `Op` when the non-JAX compatible
-    # optimizations are turned on; however, when using JAX mode, it should
-    # leave the expression alone.
-    out = y.dot(alpha * A).dot(x) + beta * y
-    fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
-    out = maximum(y, x)
-    fgraph = FunctionGraph([y, x], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
-    out = pt_max(y)
-    fgraph = FunctionGraph([y], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
-
 def test_pinv():
     x = matrix("x")
     x_inv = pt_nlinalg.pinv(x)