1414# See the License for the specific language governing permissions and
1515# limitations under the License.
1616
17- import ctypes
18-
1917import numpy as np
2018import pytest
2119
22- import dpctl
2320import dpctl .tensor as dpt
2421from dpctl .tensor ._type_utils import _can_cast
2522from dpctl .tests .helper import get_queue_or_skip , skip_if_dtype_not_supported
2623
27- from .utils import _all_dtypes , _compare_dtypes , _usm_types
24+ from .utils import _all_dtypes , _compare_dtypes
2825
2926
3027@pytest .mark .parametrize ("op1_dtype" , _all_dtypes )
@@ -61,100 +58,6 @@ def test_multiply_dtype_matrix(op1_dtype, op2_dtype):
6158 assert (dpt .asnumpy (r ) == expected .astype (r .dtype )).all ()
6259
6360
64- @pytest .mark .parametrize ("op1_usm_type" , _usm_types )
65- @pytest .mark .parametrize ("op2_usm_type" , _usm_types )
66- def test_multiply_usm_type_matrix (op1_usm_type , op2_usm_type ):
67- get_queue_or_skip ()
68-
69- sz = 128
70- ar1 = dpt .ones (sz , dtype = "i4" , usm_type = op1_usm_type )
71- ar2 = dpt .ones_like (ar1 , dtype = "i4" , usm_type = op2_usm_type )
72-
73- r = dpt .multiply (ar1 , ar2 )
74- assert isinstance (r , dpt .usm_ndarray )
75- expected_usm_type = dpctl .utils .get_coerced_usm_type (
76- (op1_usm_type , op2_usm_type )
77- )
78- assert r .usm_type == expected_usm_type
79-
80-
81- def test_multiply_order ():
82- get_queue_or_skip ()
83-
84- ar1 = dpt .ones ((20 , 20 ), dtype = "i4" , order = "C" )
85- ar2 = dpt .ones ((20 , 20 ), dtype = "i4" , order = "C" )
86- r1 = dpt .multiply (ar1 , ar2 , order = "C" )
87- assert r1 .flags .c_contiguous
88- r2 = dpt .multiply (ar1 , ar2 , order = "F" )
89- assert r2 .flags .f_contiguous
90- r3 = dpt .multiply (ar1 , ar2 , order = "A" )
91- assert r3 .flags .c_contiguous
92- r4 = dpt .multiply (ar1 , ar2 , order = "K" )
93- assert r4 .flags .c_contiguous
94-
95- ar1 = dpt .ones ((20 , 20 ), dtype = "i4" , order = "F" )
96- ar2 = dpt .ones ((20 , 20 ), dtype = "i4" , order = "F" )
97- r1 = dpt .multiply (ar1 , ar2 , order = "C" )
98- assert r1 .flags .c_contiguous
99- r2 = dpt .multiply (ar1 , ar2 , order = "F" )
100- assert r2 .flags .f_contiguous
101- r3 = dpt .multiply (ar1 , ar2 , order = "A" )
102- assert r3 .flags .f_contiguous
103- r4 = dpt .multiply (ar1 , ar2 , order = "K" )
104- assert r4 .flags .f_contiguous
105-
106- ar1 = dpt .ones ((40 , 40 ), dtype = "i4" , order = "C" )[:20 , ::- 2 ]
107- ar2 = dpt .ones ((40 , 40 ), dtype = "i4" , order = "C" )[:20 , ::- 2 ]
108- r4 = dpt .multiply (ar1 , ar2 , order = "K" )
109- assert r4 .strides == (20 , - 1 )
110-
111- ar1 = dpt .ones ((40 , 40 ), dtype = "i4" , order = "C" )[:20 , ::- 2 ].mT
112- ar2 = dpt .ones ((40 , 40 ), dtype = "i4" , order = "C" )[:20 , ::- 2 ].mT
113- r4 = dpt .multiply (ar1 , ar2 , order = "K" )
114- assert r4 .strides == (- 1 , 20 )
115-
116-
117- def test_multiply_broadcasting ():
118- get_queue_or_skip ()
119-
120- m = dpt .ones ((100 , 5 ), dtype = "i4" )
121- v = dpt .arange (1 , 6 , dtype = "i4" )
122-
123- r = dpt .multiply (m , v )
124-
125- expected = np .multiply (
126- np .ones ((100 , 5 ), dtype = "i4" ), np .arange (1 , 6 , dtype = "i4" )
127- )
128- assert (dpt .asnumpy (r ) == expected .astype (r .dtype )).all ()
129-
130- r2 = dpt .multiply (v , m )
131- expected2 = np .multiply (
132- np .arange (1 , 6 , dtype = "i4" ), np .ones ((100 , 5 ), dtype = "i4" )
133- )
134- assert (dpt .asnumpy (r2 ) == expected2 .astype (r2 .dtype )).all ()
135-
136-
137- @pytest .mark .parametrize ("arr_dt" , _all_dtypes )
138- def test_multiply_python_scalar (arr_dt ):
139- q = get_queue_or_skip ()
140- skip_if_dtype_not_supported (arr_dt , q )
141-
142- X = dpt .ones ((10 , 10 ), dtype = arr_dt , sycl_queue = q )
143- py_ones = (
144- bool (1 ),
145- int (1 ),
146- float (1 ),
147- complex (1 ),
148- np .float32 (1 ),
149- ctypes .c_int (1 ),
150- )
151- for sc in py_ones :
152- R = dpt .multiply (X , sc )
153- assert isinstance (R , dpt .usm_ndarray )
154- R = dpt .multiply (sc , X )
155- assert isinstance (R , dpt .usm_ndarray )
156-
157-
15861@pytest .mark .parametrize ("arr_dt" , _all_dtypes )
15962@pytest .mark .parametrize ("sc" , [bool (1 ), int (1 ), float (1 ), complex (1 )])
16063def test_multiply_python_scalar_gh1219 (arr_dt , sc ):
@@ -175,22 +78,6 @@ def test_multiply_python_scalar_gh1219(arr_dt, sc):
17578 assert _compare_dtypes (R .dtype , Rnp .dtype , sycl_queue = q )
17679
17780
178- @pytest .mark .parametrize ("dtype" , _all_dtypes )
179- def test_multiply_inplace_python_scalar (dtype ):
180- q = get_queue_or_skip ()
181- skip_if_dtype_not_supported (dtype , q )
182- X = dpt .ones ((10 , 10 ), dtype = dtype , sycl_queue = q )
183- dt_kind = X .dtype .kind
184- if dt_kind in "ui" :
185- X *= int (1 )
186- elif dt_kind == "f" :
187- X *= float (1 )
188- elif dt_kind == "c" :
189- X *= complex (1 )
190- elif dt_kind == "b" :
191- X *= bool (1 )
192-
193-
19481@pytest .mark .parametrize ("op1_dtype" , _all_dtypes )
19582@pytest .mark .parametrize ("op2_dtype" , _all_dtypes )
19683def test_multiply_inplace_dtype_matrix (op1_dtype , op2_dtype ):
0 commit comments