4040"""
4141
4242
43- import numpy
44-
4543from dpnp .dpnp_algo import *
46- from dpnp .dparray import dparray
4744from dpnp .dpnp_utils import *
4845import dpnp
4946import dpnp .config as config
5047
48+ import numpy
49+
5150
5251__all__ = [
5352 "dot" ,
@@ -92,15 +91,14 @@ def dot(x1, x2, **kwargs):
9291
9392 """
9493
95- is_x1_dparray = isinstance (x1 , dparray )
96- is_x2_dparray = isinstance (x2 , dparray )
97-
98- if (not use_origin_backend (x1 ) and is_x1_dparray and is_x2_dparray and not kwargs ):
99- dim1 = x1 .ndim
100- dim2 = x2 .ndim
94+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
95+ x2_desc = dpnp .get_dpnp_descriptor (x2 )
96+ if x1_desc and x2_desc and not kwargs :
97+ dim1 = x1_desc .ndim
98+ dim2 = x2_desc .ndim
10199
102- if not (dim1 >= 2 and dim2 == 1 ) and not (dim1 >= 2 and dim2 >= 2 ) and (x1 .dtype == x2 .dtype ):
103- result = dpnp_dot (x1 , x2 )
100+ if not (dim1 >= 2 and dim2 == 1 ) and not (dim1 >= 2 and dim2 >= 2 ) and (x1_desc .dtype == x2_desc .dtype ):
101+ result = dpnp_dot (x1_desc , x2_desc )
104102
105103 # scalar returned
106104 if result .shape == (1 ,):
@@ -186,16 +184,15 @@ def inner(x1, x2, **kwargs):
186184
187185 """
188186
189- is_x1_dparray = isinstance (x1 , dparray )
190- is_x2_dparray = isinstance (x2 , dparray )
191-
192- if (not use_origin_backend (x1 ) and is_x1_dparray and is_x2_dparray and not kwargs ):
193- return dpnp_inner (x1 , x2 )
187+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
188+ x2_desc = dpnp .get_dpnp_descriptor (x2 )
189+ if 0 and x1_desc and x2_desc and not kwargs :
190+ return dpnp_inner (x1_desc , x2_desc )
194191
195192 return call_origin (numpy .inner , x1 , x2 , ** kwargs )
196193
197194
198- def kron (a , b ):
195+ def kron (x1 , x2 ):
199196 """
200197 Returns the kronecker product of two arrays.
201198
@@ -205,23 +202,15 @@ def kron(a, b):
205202
206203 """
207204
208- if not use_origin_backend (a ):
209- if dpnp .isscalar (a ):
210- a = dpnp .array (a )
211- if dpnp .isscalar (b ):
212- b = dpnp .array (b )
213-
214- if not isinstance (a , dparray ):
215- pass
216- elif not isinstance (b , dparray ):
217- pass
218- else :
219- return dpnp_kron (a , b )
205+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
206+ x2_desc = dpnp .get_dpnp_descriptor (x2 )
207+ if x1_desc and x2_desc :
208+ return dpnp_kron (x1_desc , x2_desc )
220209
221- return call_origin (numpy .kron , a , b )
210+ return call_origin (numpy .kron , x1 , x2 )
222211
223212
224- def matmul (in_array1 , in_array2 , out = None , ** kwargs ):
213+ def matmul (x1 , x2 , out = None , ** kwargs ):
225214 """
226215 Matrix product of two arrays.
227216
@@ -257,32 +246,42 @@ def matmul(in_array1, in_array2, out=None, **kwargs):
257246
258247 """
259248
260- if not use_origin_backend (in_array1 ) and not kwargs :
261- if not isinstance (in_array1 , dparray ):
249+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
250+ x2_desc = dpnp .get_dpnp_descriptor (x2 )
251+ out_desc = dpnp .get_dpnp_descriptor (x2 )
252+ if x1_desc and x2_desc and out_desc and not kwargs :
253+ if x1_desc .size != x2_desc .size :
254+ pass
255+ elif not x1_desc .ndim :
256+ pass
257+ elif not x2_desc .ndim :
262258 pass
263- elif not isinstance ( in_array2 , dparray ) :
259+ elif not x1_desc . size :
264260 pass
265- elif out is not None and not isinstance ( out , dparray ) :
261+ elif not x2_desc . size :
266262 pass
267263 else :
268- """
269- Cost model checks
270- """
271-
272- dparray1_size = in_array1 .size
273- dparray2_size = in_array2 .size
274- cost_size = 4096 # 2D array shape(64, 64)
275-
276- if ((in_array1 .dtype == numpy .float64 ) or (in_array1 .dtype == numpy .float32 )):
264+ if 0 :
277265 """
278- Floating point types are handled via original math library better than SYCL math library
266+ Cost model checks
279267 """
280- cost_size = 262144 # 2D array shape(512, 512)
281268
282- if (dparray1_size > cost_size ) and (dparray2_size > cost_size ):
283- return dpnp_matmul (in_array1 , in_array2 , out = out )
269+ dparray1_size = x1_desc .size
270+ dparray2_size = x2_desc .size
271+ cost_size = 4096 # 2D array shape(64, 64)
284272
285- return call_origin (numpy .matmul , in_array1 , in_array2 , out = out , ** kwargs )
273+ if ((x1_desc .dtype == numpy .float64 ) or (x1_desc .dtype == numpy .float32 )):
274+ """
275+ Floating point types are handled via original math library better than SYCL math library
276+ """
277+ cost_size = 262144 # 2D array shape(512, 512)
278+
279+ if (dparray1_size > cost_size ) and (dparray2_size > cost_size ):
280+ return dpnp_matmul (x1_desc , x2_desc , out )
281+ else :
282+ return dpnp_matmul (x1_desc , x2_desc , out )
283+
284+ return call_origin (numpy .matmul , x1 , x2 , out = out , ** kwargs )
286285
287286
288287def outer (x1 , x2 , ** kwargs ):
@@ -314,11 +313,10 @@ def outer(x1, x2, **kwargs):
314313
315314 """
316315
317- is_x1_dparray = isinstance (x1 , dparray )
318- is_x2_dparray = isinstance (x2 , dparray )
319-
320- if (not use_origin_backend (x1 ) and is_x1_dparray and is_x2_dparray and not kwargs ):
321- return dpnp_outer (x1 , x2 )
316+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
317+ x2_desc = dpnp .get_dpnp_descriptor (x2 )
318+ if 0 and x1_desc and x2_desc and not kwargs :
319+ return dpnp_outer (x1_desc , x2_desc )
322320
323321 return call_origin (numpy .outer , x1 , x2 , ** kwargs )
324322
@@ -353,11 +351,10 @@ def tensordot(x1, x2, axes=2):
353351
354352 """
355353
356- is_x1_dparray = isinstance (x1 , dparray )
357- is_x2_dparray = isinstance (x2 , dparray )
358-
359- if (not use_origin_backend (x1 ) and is_x1_dparray and is_x2_dparray and (axes == 1 )):
360- return dpnp_tensordot (x1 , x2 ) # dpnp_matmul
354+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
355+ x2_desc = dpnp .get_dpnp_descriptor (x2 )
356+ if x1_desc and x2_desc and (axes == 1 ):
357+ return dpnp_tensordot_not_implemented (x1_desc , x2_desc ) # dpnp_matmul
361358
362359 return call_origin (numpy .tensordot , x1 , x2 , axes )
363360
0 commit comments