|
11 | 11 | import pytensor.tensor.math as tm |
12 | 12 | from pytensor import compile, config, function, shared |
13 | 13 | from pytensor.compile.io import In, Out |
14 | | -from pytensor.compile.mode import get_default_mode |
| 14 | +from pytensor.compile.mode import Mode, get_default_mode |
15 | 15 | from pytensor.compile.ops import DeepCopyOp |
16 | 16 | from pytensor.gradient import grad, hessian |
17 | 17 | from pytensor.graph.basic import Apply |
@@ -2002,45 +2002,65 @@ def test_split_static_shape(self): |
2002 | 2002 | y = Split(2)(x, 0, [s, 5 - s])[0] |
2003 | 2003 | assert y.type.shape == (None,) |
2004 | 2004 |
|
2005 | | - |
2006 | | -def test_join_inplace(): |
2007 | | - # Test join to work inplace. |
2008 | | - # |
2009 | | - # This function tests the case when several elements are passed to the |
2010 | | - # join function but all except one of them are empty. In this case join |
2011 | | - # should work inplace and the output should be the view of the non-empty |
2012 | | - # element. |
2013 | | - s = lscalar() |
2014 | | - x = vector("x") |
2015 | | - z = at.zeros((s,)) |
2016 | | - |
2017 | | - join = Join(view=0) |
2018 | | - c = join(0, x, z, z) |
2019 | | - |
2020 | | - f = pytensor.function([In(x, borrow=True), s], Out(c, borrow=True)) |
2021 | | - |
2022 | | - data = np.array([3, 4, 5], dtype=config.floatX) |
2023 | | - |
2024 | | - if config.mode not in ["DebugMode", "DEBUG_MODE"]: |
2025 | | - assert f(data, 0) is data |
2026 | | - assert np.allclose(f(data, 0), [3, 4, 5]) |
2027 | | - |
2028 | | - |
2029 | | -def test_join_oneInput(): |
2030 | | - # Test join when only 1 input is given. |
2031 | | - # |
2032 | | - # This functions tests the case when concatenate is called |
2033 | | - # on an array of tensors but the array has only one element. |
2034 | | - # In this case, we would like to avoid the computational |
2035 | | - # overhead of concatenation of one element. |
2036 | | - x_0 = fmatrix() |
2037 | | - x_1 = fmatrix() |
2038 | | - x_2 = fvector() |
2039 | | - join_0 = at.concatenate([x_0], axis=1) |
2040 | | - join_1 = at.concatenate([x_0, x_1, shape_padright(x_2)], axis=1) |
2041 | | - |
2042 | | - assert join_0 is x_0 |
2043 | | - assert join_1 is not x_0 |
| 2005 | + def test_join_inplace(self): |
| 2006 | + # Test join to work inplace. |
| 2007 | + # |
| 2008 | + # This function tests the case when several elements are passed to the |
| 2009 | + # join function but all except one of them are empty. In this case join |
| 2010 | + # should work inplace and the output should be the view of the non-empty |
| 2011 | + # element. |
| 2012 | + s = lscalar() |
| 2013 | + x = vector("x") |
| 2014 | + z = at.zeros((s,)) |
| 2015 | + |
| 2016 | + join = Join(view=0) |
| 2017 | + c = join(0, x, z, z) |
| 2018 | + |
| 2019 | + f = pytensor.function([In(x, borrow=True), s], Out(c, borrow=True)) |
| 2020 | + |
| 2021 | + data = np.array([3, 4, 5], dtype=config.floatX) |
| 2022 | + |
| 2023 | + if config.mode not in ["DebugMode", "DEBUG_MODE"]: |
| 2024 | + assert f(data, 0) is data |
| 2025 | + assert np.allclose(f(data, 0), [3, 4, 5]) |
| 2026 | + |
| 2027 | + def test_join_oneInput(self): |
| 2028 | + # Test join when only 1 input is given. |
| 2029 | + # |
| 2030 | + # This functions tests the case when concatenate is called |
| 2031 | + # on an array of tensors but the array has only one element. |
| 2032 | + # In this case, we would like to avoid the computational |
| 2033 | + # overhead of concatenation of one element. |
| 2034 | + x_0 = fmatrix() |
| 2035 | + x_1 = fmatrix() |
| 2036 | + x_2 = fvector() |
| 2037 | + join_0 = at.concatenate([x_0], axis=1) |
| 2038 | + join_1 = at.concatenate([x_0, x_1, shape_padright(x_2)], axis=1) |
| 2039 | + |
| 2040 | + assert join_0 is x_0 |
| 2041 | + assert join_1 is not x_0 |
| 2042 | + |
| 2043 | + @pytest.mark.parametrize("linker", ("py", "c")) |
| 2044 | + def test_split_view(self, linker): |
| 2045 | + x = vector("x") |
| 2046 | + axis = 0 |
| 2047 | + op = Split(len_splits=3) |
| 2048 | + assert op.view_map == {0: [0], 1: [0], 2: [0]} |
| 2049 | + splits = op(x, axis, [0, 3, 2]) |
| 2050 | + |
| 2051 | + mode = Mode(linker) |
| 2052 | + f = pytensor.function( |
| 2053 | + [In(x, borrow=True)], [Out(s, borrow=True) for s in splits], mode=mode |
| 2054 | + ) |
| 2055 | + x_test = np.arange(5, dtype=config.floatX) |
| 2056 | + res = f(x_test) |
| 2057 | + for r, expected in zip(res, ([], [0, 1, 2], [3, 4])): |
| 2058 | + assert np.allclose(r, expected) |
| 2059 | + if linker == "py": |
| 2060 | + assert r.base is x_test |
| 2061 | + else: |
| 2062 | + # C impl always makes a copy |
| 2063 | + assert r.base is not x_test |
2044 | 2064 |
|
2045 | 2065 |
|
2046 | 2066 | def test_TensorFromScalar(): |
|
0 commit comments