@@ -1071,19 +1071,20 @@ def multiply(x1, x2, dtype=None, out=None, where=True, **kwargs):
10711071 [1, 4, 9, 16, 25]
10721072
10731073 """
1074- x1_is_scalar , x2_is_scalar = dpnp .isscalar (x1 ), dpnp .isscalar (x2 )
1075- x1_is_dparray , x2_is_dparray = isinstance (x1 , dparray ), isinstance (x2 , dparray )
10761074
1077- if not use_origin_backend (x1 ) and not kwargs :
1078- if not x1_is_dparray and not x1_is_scalar :
1079- pass
1080- elif not x2_is_dparray and not x2_is_scalar :
1075+ x1_is_scalar = dpnp .isscalar (x1 )
1076+ x2_is_scalar = dpnp .isscalar (x2 )
1077+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
1078+ x2_desc = dpnp .get_dpnp_descriptor (x2 )
1079+
1080+ if x1_desc and x2_desc and not kwargs :
1081+ if not x2_desc and not x2_is_scalar :
10811082 pass
10821083 elif x1_is_scalar and x2_is_scalar :
10831084 pass
1084- elif x1_is_dparray and x1 .ndim == 0 :
1085+ elif x1_desc and x1_desc .ndim == 0 :
10851086 pass
1086- elif x2_is_dparray and x2 .ndim == 0 :
1087+ elif x2_desc and x2_desc .ndim == 0 :
10871088 pass
10881089 elif dtype is not None :
10891090 pass
@@ -1092,7 +1093,7 @@ def multiply(x1, x2, dtype=None, out=None, where=True, **kwargs):
10921093 elif not where :
10931094 pass
10941095 else :
1095- return dpnp_multiply (x1 , x2 , dtype = dtype , out = out , where = where )
1096+ return dpnp_multiply (x1_desc , x2_desc , dtype = dtype , out = out , where = where )
10961097
10971098 return call_origin (numpy .multiply , x1 , x2 , dtype = dtype , out = out , where = where , ** kwargs )
10981099
0 commit comments