|
| 1 | +from python import Python |
| 2 | +from layout import Layout, LayoutTensor, RuntimeLayout |
| 3 | + |
| 4 | +from testing import assert_equal, assert_raises, assert_true |
| 5 | + |
| 6 | +from bridge.numpy import ndarray_to_tensor, tensor_to_ndarray |
| 7 | + |
| 8 | + |
| 9 | +def test_ndarray_to_tensor(): |
| 10 | + """Test numpy array conversion to layouttensor for various tensor shapes.""" |
| 11 | + var np = Python.import_module("numpy") |
| 12 | + |
| 13 | + # 1) Test vectors |
| 14 | + var in_vector = np.arange(4.0) |
| 15 | + var out_vector = ndarray_to_tensor[order=1](in_vector) |
| 16 | + assert_equal(out_vector[1], 1.0) |
| 17 | + assert_equal(out_vector[3], 3.0) |
| 18 | + |
| 19 | + # 2) Test matrices |
| 20 | + var in_matrix = np.arange(4.0 * 3.0).reshape(3, 4) |
| 21 | + var out_matrix = ndarray_to_tensor[order=2](in_matrix) |
| 22 | + assert_equal(out_matrix[0, 0], 0.0) |
| 23 | + assert_equal(out_matrix[1, 1], 5.0) |
| 24 | + assert_equal(out_matrix[1, 3], 7.0) |
| 25 | + assert_equal(out_matrix[2, 1], 9.0) |
| 26 | + |
| 27 | + # Check that non-contiguous arrays raise exceptions. |
| 28 | + with assert_raises(): |
| 29 | + var in_matrix_col_major = np.asfortranarray(in_matrix) |
| 30 | + _ = ndarray_to_tensor[order=2](in_matrix_col_major) |
| 31 | + |
| 32 | + # 3) Test three-index tensors |
| 33 | + var in_tensor = np.arange(4.0 * 3.0).reshape(3, 1, 4) |
| 34 | + var out_tensor = ndarray_to_tensor[order=3](in_tensor) |
| 35 | + assert_equal(out_tensor[0, 0, 0], 0.0) |
| 36 | + assert_equal(out_tensor[1, 0, 1], 5.0) |
| 37 | + assert_equal(out_tensor[1, 0, 3], 7.0) |
| 38 | + assert_equal(out_tensor[2, 0, 1], 9.0) |
| 39 | + |
| 40 | + # 3) Test four-index tensors |
| 41 | + var in_4tensor = np.arange(4.0 * 3.0).reshape(2, 3, 1, 2) |
| 42 | + var out_4tensor = ndarray_to_tensor[order=4](in_4tensor) |
| 43 | + assert_equal(out_4tensor[0, 0, 0, 0], 0.0) |
| 44 | + assert_equal(out_4tensor[0, 1, 0, 1], 3.0) |
| 45 | + assert_equal(out_4tensor[1, 0, 0, 1], 7.0) |
| 46 | + assert_equal(out_4tensor[0, 2, 0, 0], 4.0) |
| 47 | + |
| 48 | + |
| 49 | +def test_memory_leaks(): |
| 50 | + """Test that we can safely remove the reference to the numpy array.""" |
| 51 | + var np = Python.import_module("numpy") |
| 52 | + var np_array = np.arange(6.0).reshape(3, 2) |
| 53 | + var tensor = ndarray_to_tensor[order=2](np_array) |
| 54 | + np_array.__del__() |
| 55 | + assert_equal(tensor[1, 0], 2.0) |
| 56 | + assert_equal(tensor[1, 1], 3.0) |
| 57 | + assert_equal(tensor[2, 1], 5.0) |
| 58 | + |
| 59 | + |
| 60 | +# def test_tensor_numpy_identity_transformation(): |
| 61 | +# """Test that `tensor_to_ndarray` is inverse of `ndarray_to_tensor`.""" |
| 62 | +# var values = InlineArray[Float64, 6](0.0, 1.0, 2.0, 3.0, 4.0, 5.0) |
| 63 | +# var ptr = values.unsafe_ptr() |
| 64 | +# var tensor = LayoutTensor[ |
| 65 | +# DType.float64, |
| 66 | +# Layout.row_major(2, 3), |
| 67 | +# # MutableAnyOrigin, |
| 68 | +# # __origin_of(ptr[]), |
| 69 | +# # __origin_of(ptr), |
| 70 | +# # origin = __origin_of(values), |
| 71 | +# # MutableAnyOrigin, |
| 72 | +# ](values) |
| 73 | + |
| 74 | +# np_array = tensor_to_ndarray(tensor) |
| 75 | +# out_layouttensor = ndarray_to_tensor[order=2](in_array) |
| 76 | + |
| 77 | + |
| 78 | +def test_numpy_tensor_identity_transformation(): |
| 79 | + """Test that `ndarray_to_tensor` is the inverse of `tensor_to_ndarray`.""" |
| 80 | + var np = Python.import_module("numpy") |
| 81 | + |
| 82 | + # 1) Test vectors |
| 83 | + # TODO: Add support for vectors! |
| 84 | + |
| 85 | + # 2) Test matrices |
| 86 | + var in_matrix = np.arange(4.0 * 3.0).reshape(3, 4) |
| 87 | + var layout_matrix = ndarray_to_tensor[order=2](in_matrix) |
| 88 | + var out_matrix = tensor_to_ndarray(layout_matrix) |
| 89 | + np.testing.assert_array_equal(in_matrix, out_matrix) |
| 90 | + |
| 91 | + # 3) Test three-index tensors |
| 92 | + var in_tensor = np.arange(4.0 * 3.0).reshape(3, 1, 4) |
| 93 | + var layout_tensor = ndarray_to_tensor[order=3](in_tensor) |
| 94 | + var out_tensor = tensor_to_ndarray(layout_tensor) |
| 95 | + np.testing.assert_array_equal(in_tensor, out_tensor) |
| 96 | + |
| 97 | + # 3) Test four-index tensors |
| 98 | + var in_4tensor = np.arange(4.0 * 3.0).reshape(2, 3, 1, 2) |
| 99 | + var layout_4tensor = ndarray_to_tensor[order=4](in_4tensor) |
| 100 | + var out_4tensor = tensor_to_ndarray(layout_4tensor) |
| 101 | + np.testing.assert_array_equal(in_4tensor, out_4tensor) |
0 commit comments