@@ -800,6 +800,79 @@ def _nonzero_impl(ary):
800800 return res
801801
802802
803+ def _validate_indices (inds , queue_list , usm_type_list ):
804+ """
805+ Utility for validating indices are usm_ndarray of integral dtype or Python
806+ integers. At least one must be an array.
807+
808+ For each array, the queue and usm type are appended to `queue_list` and
809+ `usm_type_list`, respectively.
810+ """
811+ any_usmarray = False
812+ for ind in inds :
813+ if isinstance (ind , dpt .usm_ndarray ):
814+ any_usmarray = True
815+ if ind .dtype .kind not in "ui" :
816+ raise IndexError (
817+ "arrays used as indices must be of integer (or boolean) "
818+ "type"
819+ )
820+ queue_list .append (ind .sycl_queue )
821+ usm_type_list .append (ind .usm_type )
822+ elif not isinstance (ind , Integral ):
823+ raise TypeError (
824+ "all elements of `ind` expected to be usm_ndarrays "
825+ f"or integers, found { type (ind )} "
826+ )
827+ if not any_usmarray :
828+ raise TypeError (
829+ "at least one element of `inds` expected to be a usm_ndarray"
830+ )
831+ return inds
832+
833+
834+ def _prepare_indices_arrays (inds , q , usm_type ):
835+ """
836+ Utility taking a mix of usm_ndarray and possibly Python int scalar indices,
837+ a queue (assumed to be common to arrays in inds), and a usm type.
838+
839+ Python scalar integers are promoted to arrays on the provided queue and
840+ with the provided usm type. All arrays are then promoted to a common
841+ integral type (if possible) before being broadcast to a common shape.
842+ """
843+ # scalar integers -> arrays
844+ inds = tuple (
845+ map (
846+ lambda ind : (
847+ ind
848+ if isinstance (ind , dpt .usm_ndarray )
849+ else dpt .asarray (ind , usm_type = usm_type , sycl_queue = q )
850+ ),
851+ inds ,
852+ )
853+ )
854+
855+ # promote to a common integral type if possible
856+ ind_dt = dpt .result_type (* inds )
857+ if ind_dt .kind not in "ui" :
858+ raise ValueError (
859+ "cannot safely promote indices to an integer data type"
860+ )
861+ inds = tuple (
862+ map (
863+ lambda ind : (
864+ ind if ind .dtype == ind_dt else dpt .astype (ind , ind_dt )
865+ ),
866+ inds ,
867+ )
868+ )
869+
870+ # broadcast
871+ inds = dpt .broadcast_arrays (* inds )
872+
873+ return inds
874+
875+
803876def _take_multi_index (ary , inds , p , mode = 0 ):
804877 if not isinstance (ary , dpt .usm_ndarray ):
805878 raise TypeError (
@@ -820,26 +893,8 @@ def _take_multi_index(ary, inds, p, mode=0):
820893 ]
821894 if not isinstance (inds , (list , tuple )):
822895 inds = (inds ,)
823- any_usmarray = False
824- for ind in inds :
825- if isinstance (ind , dpt .usm_ndarray ):
826- any_usmarray = True
827- if ind .dtype .kind not in "ui" :
828- raise IndexError (
829- "arrays used as indices must be of integer (or boolean) "
830- "type"
831- )
832- queues_ .append (ind .sycl_queue )
833- usm_types_ .append (ind .usm_type )
834- elif not isinstance (ind , Integral ):
835- raise TypeError (
836- "all elements of `ind` expected to be usm_ndarrays "
837- "or integers"
838- )
839- if not any_usmarray :
840- raise TypeError (
841- "at least one element of `ind` expected to be a usm_ndarray"
842- )
896+
897+ _validate_indices (inds , queues_ , usm_types_ )
843898 res_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
844899 exec_q = dpctl .utils .get_execution_queue (queues_ )
845900 if exec_q is None :
@@ -849,34 +904,10 @@ def _take_multi_index(ary, inds, p, mode=0):
849904 "Use `usm_ndarray.to_device` method to migrate data to "
850905 "be associated with the same queue."
851906 )
907+
852908 if len (inds ) > 1 :
853- inds = tuple (
854- map (
855- lambda ind : (
856- ind
857- if isinstance (ind , dpt .usm_ndarray )
858- else dpt .asarray (
859- ind , usm_type = res_usm_type , sycl_queue = exec_q
860- )
861- ),
862- inds ,
863- )
864- )
865- ind_dt = dpt .result_type (* inds )
866- # ind arrays have been checked to be of integer dtype
867- if ind_dt .kind not in "ui" :
868- raise ValueError (
869- "cannot safely promote indices to an integer data type"
870- )
871- inds = tuple (
872- map (
873- lambda ind : (
874- ind if ind .dtype == ind_dt else dpt .astype (ind , ind_dt )
875- ),
876- inds ,
877- )
878- )
879- inds = dpt .broadcast_arrays (* inds )
909+ inds = _prepare_indices_arrays (inds , exec_q , res_usm_type )
910+
880911 ind0 = inds [0 ]
881912 ary_sh = ary .shape
882913 p_end = p + len (inds )
@@ -992,26 +1023,9 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
9921023 ]
9931024 if not isinstance (inds , (list , tuple )):
9941025 inds = (inds ,)
995- any_usmarray = False
996- for ind in inds :
997- if isinstance (ind , dpt .usm_ndarray ):
998- any_usmarray = True
999- if ind .dtype .kind not in "ui" :
1000- raise IndexError (
1001- "arrays used as indices must be of integer (or boolean) "
1002- "type"
1003- )
1004- queues_ .append (ind .sycl_queue )
1005- usm_types_ .append (ind .usm_type )
1006- elif not isinstance (ind , Integral ):
1007- raise TypeError (
1008- "all elements of `ind` expected to be usm_ndarrays "
1009- "or integers"
1010- )
1011- if not any_usmarray :
1012- raise TypeError (
1013- "at least one element of `ind` expected to be a usm_ndarray"
1014- )
1026+
1027+ _validate_indices (inds , queues_ , usm_types_ )
1028+
10151029 vals_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
10161030 exec_q = dpctl .utils .get_execution_queue (queues_ )
10171031 if exec_q is not None :
@@ -1028,34 +1042,10 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
10281042 "Use `usm_ndarray.to_device` method to migrate data to "
10291043 "be associated with the same queue."
10301044 )
1045+
10311046 if len (inds ) > 1 :
1032- inds = tuple (
1033- map (
1034- lambda ind : (
1035- ind
1036- if isinstance (ind , dpt .usm_ndarray )
1037- else dpt .asarray (
1038- ind , usm_type = vals_usm_type , sycl_queue = exec_q
1039- )
1040- ),
1041- inds ,
1042- )
1043- )
1044- ind_dt = dpt .result_type (* inds )
1045- # ind arrays have been checked to be of integer dtype
1046- if ind_dt .kind not in "ui" :
1047- raise ValueError (
1048- "cannot safely promote indices to an integer data type"
1049- )
1050- inds = tuple (
1051- map (
1052- lambda ind : (
1053- ind if ind .dtype == ind_dt else dpt .astype (ind , ind_dt )
1054- ),
1055- inds ,
1056- )
1057- )
1058- inds = dpt .broadcast_arrays (* inds )
1047+ inds = _prepare_indices_arrays (inds , exec_q , vals_usm_type )
1048+
10591049 ind0 = inds [0 ]
10601050 ary_sh = ary .shape
10611051 p_end = p + len (inds )
0 commit comments