@@ -845,14 +845,16 @@ def _nonzero_impl(ary):
845845 return res
846846
847847
848- def _validate_indices (inds , queue_list , usm_type_list ):
848+ def _get_indices_queue_usm_type (inds , queue , usm_type ):
849849 """
850850 Utility for validating indices are NumPy ndarray or usm_ndarray of integral
851851 dtype or Python integers. At least one must be an array.
852852
853853 For each array, the queue and usm type are appended to `queue_list` and
854854 `usm_type_list`, respectively.
855855 """
856+ queues = [queue ]
857+ usm_types = [usm_type ]
856858 any_array = False
857859 for ind in inds :
858860 if isinstance (ind , (np .ndarray , dpt .usm_ndarray )):
@@ -863,8 +865,8 @@ def _validate_indices(inds, queue_list, usm_type_list):
863865 "type"
864866 )
865867 if isinstance (ind , dpt .usm_ndarray ):
866- queue_list .append (ind .sycl_queue )
867- usm_type_list .append (ind .usm_type )
868+ queues .append (ind .sycl_queue )
869+ usm_types .append (ind .usm_type )
868870 elif not isinstance (ind , Integral ):
869871 raise TypeError (
870872 "all elements of `ind` expected to be usm_ndarrays, "
@@ -874,7 +876,9 @@ def _validate_indices(inds, queue_list, usm_type_list):
874876 raise TypeError (
875877 "at least one element of `inds` expected to be an array"
876878 )
877- return inds
879+ usm_type = dpctl .utils .get_coerced_usm_type (usm_types )
880+ q = dpctl .utils .get_execution_queue (queues )
881+ return q , usm_type
878882
879883
880884def _prepare_indices_arrays (inds , q , usm_type ):
@@ -931,18 +935,12 @@ def _take_multi_index(ary, inds, p, mode=0):
931935 raise ValueError (
932936 "Invalid value for mode keyword, only 0 or 1 is supported"
933937 )
934- queues_ = [
935- ary .sycl_queue ,
936- ]
937- usm_types_ = [
938- ary .usm_type ,
939- ]
940938 if not isinstance (inds , (list , tuple )):
941939 inds = (inds ,)
942940
943- _validate_indices ( inds , queues_ , usm_types_ )
944- res_usm_type = dpctl . utils . get_coerced_usm_type ( usm_types_ )
945- exec_q = dpctl . utils . get_execution_queue ( queues_ )
941+ exec_q , res_usm_type = _get_indices_queue_usm_type (
942+ inds , ary . sycl_queue , ary . usm_type
943+ )
946944 if exec_q is None :
947945 raise dpctl .utils .ExecutionPlacementError (
948946 "Can not automatically determine where to allocate the "
@@ -1068,23 +1066,13 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
10681066 raise ValueError (
10691067 "Invalid value for mode keyword, only 0 or 1 is supported"
10701068 )
1071- if isinstance (vals , dpt .usm_ndarray ):
1072- queues_ = [ary .sycl_queue , vals .sycl_queue ]
1073- usm_types_ = [ary .usm_type , vals .usm_type ]
1074- else :
1075- queues_ = [
1076- ary .sycl_queue ,
1077- ]
1078- usm_types_ = [
1079- ary .usm_type ,
1080- ]
10811069 if not isinstance (inds , (list , tuple )):
10821070 inds = (inds ,)
10831071
1084- _validate_indices (inds , queues_ , usm_types_ )
1072+ exec_q , vals_usm_type = _get_indices_queue_usm_type (
1073+ inds , ary .sycl_queue , ary .usm_type
1074+ )
10851075
1086- vals_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
1087- exec_q = dpctl .utils .get_execution_queue (queues_ )
10881076 if exec_q is not None :
10891077 if not isinstance (vals , dpt .usm_ndarray ):
10901078 vals = dpt .asarray (
0 commit comments