5050]
5151
5252
53- def _compute_res_dtype (* arrays , sycl_queue , dtype = None , casting = "no" ):
53+ def _compute_res_dtype (* arrays , sycl_queue , dtype = None , out = None , casting = "no" ):
5454 """
55- Determines the output array data type and an intermediate data type
56- used in performing calculations related to a specific math function.
57- If dtype is ``None``, the output array data type of the operation is
58- determined based on the Promotion Type Rule and device capabilities.
59- Otherwise, `dtype` is used as output array dtype, if input arrays
60- can cast to it according to the casting rule determined. If casting
61- cannot be done, a ``TypeError`` is raised.
62- The intermediate data type is the data type used for performing the math
63- function calculations. If output array dtype is a floating-point data type,
64- it is also used for the intermediate data type. If output array dtype is an
65- integral data type, the default floating point data type of the device where
66- input arrays are allocated on are used for intermediate data type.
55+ Determines the output array data type.
56+ If `dtype` and `out` are ``None``, the output array data type of the
57+ operation is determined based on the Promotion Type Rule and device
58+ capabilities. If `out` is given, its data type is used as the output
59+ array dtype. Otherwise, `dtype` is used as the output array dtype.
60+ If input arrays cannot be cast to the determined output array dtype,
61+ a ``TypeError`` is raised.
6762
6863 Parameters
6964 ----------
7065 arrays : {dpnp.ndarray, usm_ndarray}
7166 Input arrays.
7267 dtype : dtype
68+ If not ``None`` and `out` is not given, data type of the output array.
69+ out : {dpnp.ndarray, usm_ndarray}
7370 If not ``None``, data type of the output array.
7471 casting : {"no", "equiv", "safe", "same_kind", "unsafe"}, optional
7572 Controls what kind of data casting may occur.
@@ -78,17 +75,23 @@ def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):
7875
7976 Returns
8077 -------
81- compute_dtype, res_dtype :
82- `compute_dtype` is the data type used in performing math function calculations.
83- The input arrays of the math function are cast to `compute_dtype` and then
84- the calculations are performed.
85- `res_dtype` is the output data type. When the result is obtained, it is cast
86- to `res_dtype`.
78+ res_dtype :
79+ `res_dtype` is the output data type. When the result is obtained,
80+ it is cast to `res_dtype`.
8781
8882 """
8983
9084 res_dtype = dpnp .result_type (* arrays )
91- default_dtype = dpnp .default_float_type (sycl_queue = sycl_queue )
85+
86+ # If inputs are boolean and `out` is given and it is not boolean, the
87+ # calculation should be performed in boolean and at the end the result
88+ # is cast to out dtype. This differs from the general case, where the inputs
89+ # are cast to out dtype and then calculation is performed. Even when inputs
90+ # are boolean and `dtype` is given, the casting is done first and then the
91+ # calculation is performed.
92+ if out is not None and res_dtype != dpnp .bool :
93+ # out dtype is prioritized over a given dtype
94+ dtype = out .dtype
9295
9396 if dtype is not None :
9497 if dpnp .can_cast (res_dtype , dtype , casting = casting ):
@@ -98,11 +101,7 @@ def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):
98101 f"Cannot cast from dtype({ res_dtype } ) to dtype({ dtype } ) with casting rule { casting } "
99102 )
100103
101- compute_dtype = (
102- res_dtype if dpnp .issubdtype (res_dtype , dpnp .inexact ) else default_dtype
103- )
104-
105- return compute_dtype , res_dtype
104+ return res_dtype
106105
107106
108107def _copy_array (x , copy_flag = False , dtype = None , order = "C" ):
@@ -504,6 +503,23 @@ def _gemm_matmul(exec_q, x1, x2, res):
504503 return res
505504
506505
506+ def _gemm_special_case (x1 , x2 , res_dtype , call_flag ):
507+ """
508+ `gemm` and `gemm_batch` support the special case of `int8` inputs with an
509+ `int32` or `float32` result, while `gemv` does not.
510+
511+ """
512+ # TODO: replace with dpnp.int8 when it is added
513+ is_int8 = x1 .dtype == numpy .int8 and x2 .dtype == numpy .int8
514+ is_int32_or_f32 = res_dtype in [dpnp .int32 , dpnp .float32 ]
515+ flag = is_int8 and is_int32_or_f32 and call_flag in ["gemm" , "gemm_batch" ]
516+
517+ # onemkl_interfaces does not support these data types
518+ onemkl_interfaces = bi ._using_onemkl_interfaces ()
519+
520+ return flag and not onemkl_interfaces
521+
522+
507523def _shape_error (shape1 , shape2 , func , err_msg ):
508524 """Validate the shapes of input and output arrays."""
509525
@@ -749,17 +765,19 @@ def dpnp_dot(a, b, /, out=None, *, casting="same_kind", conjugate=False):
749765 _validate_out_array (out , exec_q )
750766
751767 # Determine the appropriate data types
752- dot_dtype , res_dtype = _compute_res_dtype (a , b , sycl_queue = exec_q )
768+ res_dtype = _compute_res_dtype (
769+ a , b , out = out , casting = casting , sycl_queue = exec_q
770+ )
753771
754772 result = _create_result_array (
755- a , b , out , (), dot_dtype , res_usm_type , exec_q
773+ a , b , out , (), res_dtype , res_usm_type , exec_q
756774 )
757775
758776 # input arrays should have the proper data type
759777 if dpnp .issubdtype (res_dtype , dpnp .inexact ):
760778 # copying is needed if dtypes of input arrays are different
761- a = _copy_array (a , dtype = dot_dtype )
762- b = _copy_array (b , dtype = dot_dtype )
779+ a = _copy_array (a , dtype = res_dtype )
780+ b = _copy_array (b , dtype = res_dtype )
763781
764782 _manager = dpu .SequentialOrderManager [exec_q ]
765783
@@ -777,14 +795,11 @@ def dpnp_dot(a, b, /, out=None, *, casting="same_kind", conjugate=False):
777795 )
778796 _manager .add_event_pair (ht_ev , dot_ev )
779797 else :
780- # oneapi::mkl::blas::dot is slow for integer data type ,
798+ # oneapi::mkl::blas::dot does not support integer dtypes ,
781799 # so using dpctl.tensor.vecdot instead
782- dpt_a = dpnp .get_usm_ndarray (a )
783- dpt_b = dpnp .get_usm_ndarray (b )
784- result = dpnp_array ._create_from_usm_ndarray (dpt .vecdot (dpt_a , dpt_b ))
785-
786- if dot_dtype != res_dtype :
787- result = result .astype (res_dtype , copy = False )
800+ a_usm = dpnp .get_usm_ndarray (a )
801+ b_usm = dpnp .get_usm_ndarray (b )
802+ result = dpnp_array ._create_from_usm_ndarray (dpt .vecdot (a_usm , b_usm ))
788803
789804 return dpnp .get_result_array (result , out , casting = casting )
790805
@@ -902,8 +917,8 @@ def dpnp_multiplication(
902917 axes_res = normalize_axis_tuple (axes_res , len (result_shape ), "axes" )
903918
904919 # Determine the appropriate data types
905- compute_dtype , res_dtype = _compute_res_dtype (
906- x1 , x2 , dtype = dtype , casting = casting , sycl_queue = exec_q
920+ res_dtype = _compute_res_dtype (
921+ x1 , x2 , dtype = dtype , out = out , casting = casting , sycl_queue = exec_q
907922 )
908923
909924 call_flag = None
@@ -998,7 +1013,7 @@ def dpnp_multiplication(
9981013 x2 ,
9991014 out ,
10001015 res_shape ,
1001- compute_dtype ,
1016+ res_dtype ,
10021017 res_usm_type ,
10031018 exec_q ,
10041019 res_order ,
@@ -1010,64 +1025,82 @@ def dpnp_multiplication(
10101025 elif x1 .size == 0 or x2 .size == 0 :
10111026 result .fill (0 )
10121027 else :
1013- # input arrays should have the proper data type and
1014- # their base (last 2-dimensions) to be c-contiguous or f-contiguous
1015- x1 = _copy_array (
1016- x1 ,
1017- copy_flag = not x1_contig_flag ,
1018- dtype = compute_dtype ,
1019- order = res_order ,
1020- )
1021- x2 = _copy_array (
1022- x2 ,
1023- copy_flag = not x2_contig_flag ,
1024- dtype = compute_dtype ,
1025- order = res_order ,
1026- )
1027-
1028- if call_flag == "gemv" :
1029- if transpose :
1030- a_usm = dpnp .get_usm_ndarray (x2 )
1031- x_usm = dpnp .get_usm_ndarray (x1 )
1032- else :
1033- a_usm = dpnp .get_usm_ndarray (x1 )
1034- x_usm = dpnp .get_usm_ndarray (x2 )
1035-
1036- _manager = dpu .SequentialOrderManager [exec_q ]
1037-
1038- ht_ev , gemv_ev = bi ._gemv (
1039- exec_q ,
1040- a_usm ,
1041- x_usm ,
1042- dpnp .get_usm_ndarray (result ),
1043- transpose ,
1044- depends = _manager .submitted_events ,
1028+ if _gemm_special_case (x1 , x2 , res_dtype , call_flag ):
1029+ x1 = _copy_array (
1030+ x1 , copy_flag = not x1_contig_flag , order = res_order
10451031 )
1046- _manager .add_event_pair (ht_ev , gemv_ev )
1047- elif call_flag == "gemm" :
1048- result = _gemm_matmul (
1049- exec_q ,
1050- x1 ,
1051- x2 ,
1052- result ,
1032+ x2 = _copy_array (
1033+ x2 , copy_flag = not x2_contig_flag , order = res_order
10531034 )
1054- else : # call_flag == "gemm_batch"
1055- assert call_flag == "gemm_batch"
1056- result = _gemm_batch_matmul (
1057- exec_q ,
1035+ if call_flag == "gemm" :
1036+ result = _gemm_matmul (exec_q , x1 , x2 , result )
1037+ else :
1038+ assert call_flag == "gemm_batch"
1039+ result = _gemm_batch_matmul (exec_q , x1 , x2 , result )
1040+ elif dpnp .issubdtype (res_dtype , dpnp .inexact ):
1041+ # copying is needed if dtypes of input arrays are different or
1042+ # their base (last 2-dimensions) is not c-contiguous or f-contiguous
1043+ x1 = _copy_array (
10581044 x1 ,
1045+ copy_flag = not x1_contig_flag ,
1046+ dtype = res_dtype ,
1047+ order = res_order ,
1048+ )
1049+ x2 = _copy_array (
10591050 x2 ,
1060- result ,
1051+ copy_flag = not x2_contig_flag ,
1052+ dtype = res_dtype ,
1053+ order = res_order ,
1054+ )
1055+
1056+ if call_flag == "gemv" :
1057+ if transpose :
1058+ a_usm = dpnp .get_usm_ndarray (x2 )
1059+ x_usm = dpnp .get_usm_ndarray (x1 )
1060+ else :
1061+ a_usm = dpnp .get_usm_ndarray (x1 )
1062+ x_usm = dpnp .get_usm_ndarray (x2 )
1063+
1064+ _manager = dpu .SequentialOrderManager [exec_q ]
1065+
1066+ ht_ev , gemv_ev = bi ._gemv (
1067+ exec_q ,
1068+ a_usm ,
1069+ x_usm ,
1070+ dpnp .get_usm_ndarray (result ),
1071+ transpose ,
1072+ depends = _manager .submitted_events ,
1073+ )
1074+ _manager .add_event_pair (ht_ev , gemv_ev )
1075+ elif call_flag == "gemm" :
1076+ result = _gemm_matmul (exec_q , x1 , x2 , result )
1077+ else :
1078+ assert call_flag == "gemm_batch"
1079+ result = _gemm_batch_matmul (exec_q , x1 , x2 , result )
1080+ else :
1081+ # oneapi::mkl::blas::gemm/gemv do not support integer dtypes,
1082+ # except for special cases determined in `_gemm_special_case`,
1083+ # use dpctl.tensor.matmul for unsupported cases
1084+
1085+ # `dpt.matmul` does not support `casting` kwarg.
1086+ # We may need to change input dtypes based on given `casting`.
1087+ # The possibility of casting is already validated in
1088+ # `_compute_res_dtype`.
1089+ x1 = _copy_array (x1 , dtype = res_dtype , order = res_order )
1090+ x2 = _copy_array (x2 , dtype = res_dtype , order = res_order )
1091+
1092+ x1_usm = dpnp .get_usm_ndarray (x1 )
1093+ x2_usm = dpnp .get_usm_ndarray (x2 )
1094+ out_usm = dpnp .get_usm_ndarray (result )
1095+ dpt .matmul (
1096+ x1_usm , x2_usm , out = out_usm , dtype = dtype , order = order
10611097 )
10621098
10631099 if NumPy_special_case :
10641100 result = dpnp .tile (result , out .shape )
10651101 elif res_shape != result_shape :
10661102 result = dpnp .reshape (result , result_shape )
10671103
1068- if compute_dtype != res_dtype :
1069- result = dpnp .astype (result , res_dtype , copy = False )
1070-
10711104 if out is None :
10721105 if axes is not None :
10731106 # Move the data back to the appropriate axes of the result array
@@ -1207,8 +1240,8 @@ def dpnp_vecdot(
12071240 )
12081241
12091242 # Determine the appropriate data types
1210- _ , res_dtype = _compute_res_dtype (
1211- x1 , x2 , dtype = dtype , casting = casting , sycl_queue = exec_q
1243+ res_dtype = _compute_res_dtype (
1244+ x1 , x2 , dtype = dtype , out = out , casting = casting , sycl_queue = exec_q
12121245 )
12131246
12141247 _ , x1_is_1D , _ = _define_dim_flags (x1 , axis = - 1 )
0 commit comments