|
37 | 37 |
|
38 | 38 | """ |
39 | 39 |
|
| 40 | +# pylint: disable=no-name-in-module |
40 | 41 | import numpy |
41 | 42 |
|
42 | 43 | import dpnp |
43 | 44 |
|
| 45 | +from .dpnp_utils import map_dtype_to_device |
44 | 46 | from .dpnp_utils.dpnp_utils_einsum import dpnp_einsum |
45 | 47 | from .dpnp_utils.dpnp_utils_linearalgebra import ( |
46 | 48 | dpnp_dot, |
|
66 | 68 | ] |
67 | 69 |
|
68 | 70 |
|
| 71 | +# TODO: implement a specific scalar-array kernel |
| 72 | +def _call_multiply(a, b, out=None): |
| 73 | + """Call multiply function for special cases of scalar-array dots.""" |
| 74 | + |
| 75 | + sc, arr = (a, b) if dpnp.isscalar(a) else (b, a) |
| 76 | + sc_dtype = map_dtype_to_device(type(sc), arr.sycl_device) |
| 77 | + res_dtype = dpnp.result_type(sc_dtype, arr) |
| 78 | + if out is not None and out.dtype == arr.dtype: |
| 79 | + res = dpnp.multiply(a, b, out=out) |
| 80 | + else: |
| 81 | + res = dpnp.multiply(a, b, dtype=res_dtype) |
| 82 | + return dpnp.get_result_array(res, out, casting="no") |
| 83 | + |
| 84 | + |
69 | 85 | def dot(a, b, out=None): |
70 | 86 | """ |
71 | 87 | Dot product of `a` and `b`. |
@@ -139,8 +155,7 @@ def dot(a, b, out=None): |
139 | 155 | raise ValueError("Only C-contiguous array is acceptable.") |
140 | 156 |
|
141 | 157 | if dpnp.isscalar(a) or dpnp.isscalar(b): |
142 | | - # TODO: use specific scalar-vector kernel |
143 | | - return dpnp.multiply(a, b, out=out) |
| 158 | + return _call_multiply(a, b, out=out) |
144 | 159 |
|
145 | 160 | a_ndim = a.ndim |
146 | 161 | b_ndim = b.ndim |
@@ -635,8 +650,7 @@ def inner(a, b): |
635 | 650 | dpnp.check_supported_arrays_type(a, b, scalar_type=True) |
636 | 651 |
|
637 | 652 | if dpnp.isscalar(a) or dpnp.isscalar(b): |
638 | | - # TODO: use specific scalar-vector kernel |
639 | | - return dpnp.multiply(a, b) |
| 653 | + return _call_multiply(a, b) |
640 | 654 |
|
641 | 655 | if a.ndim == 0 or b.ndim == 0: |
642 | 656 | # TODO: use specific scalar-vector kernel |
@@ -714,8 +728,7 @@ def kron(a, b): |
714 | 728 | dpnp.check_supported_arrays_type(a, b, scalar_type=True) |
715 | 729 |
|
716 | 730 | if dpnp.isscalar(a) or dpnp.isscalar(b): |
717 | | - # TODO: use specific scalar-vector kernel |
718 | | - return dpnp.multiply(a, b) |
| 731 | + return _call_multiply(a, b) |
719 | 732 |
|
720 | 733 | a_ndim = a.ndim |
721 | 734 | b_ndim = b.ndim |
@@ -1199,8 +1212,7 @@ def tensordot(a, b, axes=2): |
1199 | 1212 | raise ValueError( |
1200 | 1213 | "One of the inputs is scalar, axes should be zero." |
1201 | 1214 | ) |
1202 | | - # TODO: use specific scalar-vector kernel |
1203 | | - return dpnp.multiply(a, b) |
| 1215 | + return _call_multiply(a, b) |
1204 | 1216 |
|
1205 | 1217 | return dpnp_tensordot(a, b, axes=axes) |
1206 | 1218 |
|
@@ -1263,13 +1275,13 @@ def vdot(a, b): |
1263 | 1275 | if b.size != 1: |
1264 | 1276 | raise ValueError("The second array should be of size one.") |
1265 | 1277 | a_conj = numpy.conj(a) |
1266 | | - return dpnp.multiply(a_conj, b) |
| 1278 | + return _call_multiply(a_conj, b) |
1267 | 1279 |
|
1268 | 1280 | if dpnp.isscalar(b): |
1269 | 1281 | if a.size != 1: |
1270 | 1282 | raise ValueError("The first array should be of size one.") |
1271 | 1283 | a_conj = dpnp.conj(a) |
1272 | | - return dpnp.multiply(a_conj, b) |
| 1284 | + return _call_multiply(a_conj, b) |
1273 | 1285 |
|
1274 | 1286 | if a.ndim == 1 and b.ndim == 1: |
1275 | 1287 | return dpnp_dot(a, b, out=None, conjugate=True) |
|
0 commit comments