|
5 | 5 | import numpy as np |
6 | 6 | import pytest |
7 | 7 |
|
8 | | -from .. import ones, arange, reshape, asarray, result_type, all, equal |
| 8 | +from .. import ones, arange, reshape, asarray, result_type, all, equal, stack |
9 | 9 | from .._array_object import Array, CPU_DEVICE, Device |
10 | 10 | from .._dtypes import ( |
11 | 11 | _all_dtypes, |
@@ -101,41 +101,65 @@ def test_validate_index(): |
101 | 101 | assert_raises(IndexError, lambda: a[idx]) |
102 | 102 |
|
103 | 103 |
|
104 | | -def test_indexing_arrays(): |
| 104 | +@pytest.mark.parametrize("device", [None, "CPU_DEVICE", "device1", "device2"]) |
| 105 | +def test_indexing_arrays(device): |
105 | 106 | # indexing with 1D integer arrays and mixes of integers and 1D integer are allowed |
| 107 | + device = None if device is None else Device(device) |
106 | 108 |
|
107 | 109 | # 1D array |
108 | | - a = arange(5) |
109 | | - idx = asarray([1, 0, 1, 2, -1]) |
| 110 | + a = arange(5, device=device) |
| 111 | + idx = asarray([1, 0, 1, 2, -1], device=device) |
110 | 112 | a_idx = a[idx] |
111 | 113 |
|
112 | | - a_idx_loop = asarray([a[idx[i]] for i in range(idx.shape[0])]) |
| 114 | + a_idx_loop = stack([a[idx[i]] for i in range(idx.shape[0])]) |
113 | 115 | assert all(a_idx == a_idx_loop) |
| 116 | + assert a_idx.shape == idx.shape |
| 117 | + assert a.device == idx.device == a_idx.device |
114 | 118 |
|
115 | 119 | # setitem with arrays is not allowed |
116 | 120 | with assert_raises(IndexError): |
117 | 121 | a[idx] = 42 |
118 | 122 |
|
119 | 123 | # mixed array and integer indexing |
120 | | - a = reshape(arange(3*4), (3, 4)) |
121 | | - idx = asarray([1, 0, 1, 2, -1]) |
| 124 | + a = reshape(arange(3*4, device=device), (3, 4)) |
| 125 | + idx = asarray([1, 0, 1, 2, -1], device=device) |
122 | 126 | a_idx = a[idx, 1] |
123 | | - |
124 | | - a_idx_loop = asarray([a[idx[i], 1] for i in range(idx.shape[0])]) |
| 127 | + a_idx_loop = stack([a[idx[i], 1] for i in range(idx.shape[0])]) |
125 | 128 | assert all(a_idx == a_idx_loop) |
| 129 | + assert a_idx.shape == idx.shape |
| 130 | + assert a.device == idx.device == a_idx.device |
126 | 131 |
|
127 | 132 | # index with two arrays |
128 | 133 | a_idx = a[idx, idx] |
129 | | - a_idx_loop = asarray([a[idx[i], idx[i]] for i in range(idx.shape[0])]) |
| 134 | + a_idx_loop = stack([a[idx[i], idx[i]] for i in range(idx.shape[0])]) |
130 | 135 | assert all(a_idx == a_idx_loop) |
| 136 | + assert a_idx.shape == a_idx.shape |
| 137 | + assert a.device == idx.device == a_idx.device |
131 | 138 |
|
132 | 139 | # setitem with arrays is not allowed |
133 | 140 | with assert_raises(IndexError): |
134 | 141 | a[idx, idx] = 42 |
135 | 142 |
|
136 | 143 | # smoke test indexing with ndim > 1 arrays |
137 | 144 | idx = idx[..., None] |
138 | | - a[idx, idx] |
| 145 | + a_idx = a[idx, idx] |
| 146 | + assert a.device == idx.device == a_idx.device |
| 147 | + |
| 148 | + |
| 149 | +def test_indexing_arrays_different_devices(): |
| 150 | + # Ensure indexing via array on different device errors |
| 151 | + device1 = Device("CPU_DEVICE") |
| 152 | + device2 = Device("device1") |
| 153 | + |
| 154 | + a = arange(5, device=device1) |
| 155 | + idx1 = asarray([1, 0, 1, 2, -1], device=device2) |
| 156 | + idx2 = asarray([1, 0, 1, 2, -1], device=device1) |
| 157 | + |
| 158 | + with pytest.raises(ValueError, match="Array indexing is only allowed when"): |
| 159 | + a[idx1] |
| 160 | + |
| 161 | + with pytest.raises(ValueError, match="Array indexing is only allowed when"): |
| 162 | + a[idx1, idx2] |
139 | 163 |
|
140 | 164 |
|
141 | 165 | def test_promoted_scalar_inherits_device(): |
|
0 commit comments