@@ -756,20 +756,28 @@ def _extract_impl(ary, ary_mask, axis=0):
756756 raise TypeError (
757757 f"Expecting type dpctl.tensor.usm_ndarray, got { type (ary )} "
758758 )
759- if not isinstance (ary_mask , dpt .usm_ndarray ):
760- raise TypeError (
761- f"Expecting type dpctl.tensor.usm_ndarray, got { type ( ary_mask ) } "
759+ if isinstance (ary_mask , dpt .usm_ndarray ):
760+ dst_usm_type = dpctl . utils . get_coerced_usm_type (
761+ ( ary . usm_type , ary_mask . usm_type )
762762 )
763- dst_usm_type = dpctl .utils .get_coerced_usm_type (
764- (ary .usm_type , ary_mask .usm_type )
765- )
766- exec_q = dpctl .utils .get_execution_queue (
767- (ary .sycl_queue , ary_mask .sycl_queue )
768- )
769- if exec_q is None :
770- raise dpctl .utils .ExecutionPlacementError (
771- "arrays have different associated queues. "
772- "Use `y.to_device(x.device)` to migrate."
763+ exec_q = dpctl .utils .get_execution_queue (
764+ (ary .sycl_queue , ary_mask .sycl_queue )
765+ )
766+ if exec_q is None :
767+ raise dpctl .utils .ExecutionPlacementError (
768+ "arrays have different associated queues. "
769+ "Use `y.to_device(x.device)` to migrate."
770+ )
771+ elif isinstance (ary_mask , np .ndarray ):
772+ dst_usm_type = ary .usm_type
773+ exec_q = ary .sycl_queue
774+ ary_mask = dpt .asarray (
775+ ary_mask , usm_type = dst_usm_type , sycl_queue = exec_q
776+ )
777+ else :
778+ raise TypeError (
779+ "Expecting type dpctl.tensor.usm_ndarray or numpy.ndarray, got "
780+ f"{ type (ary_mask )} "
773781 )
774782 ary_nd = ary .ndim
775783 pp = normalize_axis_index (operator .index (axis ), ary_nd )
@@ -839,31 +847,32 @@ def _nonzero_impl(ary):
839847
840848def _validate_indices (inds , queue_list , usm_type_list ):
841849 """
842- Utility for validating indices are usm_ndarray of integral dtype or Python
843- integers. At least one must be an array.
850+ Utility for validating indices are NumPy ndarray or usm_ndarray of integral
851+ dtype or Python integers. At least one must be an array.
844852
845853 For each array, the queue and usm type are appended to `queue_list` and
846854 `usm_type_list`, respectively.
847855 """
848- any_usmarray = False
856+ any_array = False
849857 for ind in inds :
850- if isinstance (ind , dpt .usm_ndarray ):
851- any_usmarray = True
858+ if isinstance (ind , ( np . ndarray , dpt .usm_ndarray ) ):
859+ any_array = True
852860 if ind .dtype .kind not in "ui" :
853861 raise IndexError (
854862 "arrays used as indices must be of integer (or boolean) "
855863 "type"
856864 )
857- queue_list .append (ind .sycl_queue )
858- usm_type_list .append (ind .usm_type )
865+ if isinstance (ind , dpt .usm_ndarray ):
866+ queue_list .append (ind .sycl_queue )
867+ usm_type_list .append (ind .usm_type )
859868 elif not isinstance (ind , Integral ):
860869 raise TypeError (
861- "all elements of `ind` expected to be usm_ndarrays "
862- f"or integers, found { type (ind )} "
870+ "all elements of `ind` expected to be usm_ndarrays, "
871+ f"NumPy arrays, or integers, found { type (ind )} "
863872 )
864- if not any_usmarray :
873+ if not any_array :
865874 raise TypeError (
866- "at least one element of `inds` expected to be a usm_ndarray "
875+ "at least one element of `inds` expected to be an array "
867876 )
868877 return inds
869878
@@ -942,8 +951,7 @@ def _take_multi_index(ary, inds, p, mode=0):
942951 "be associated with the same queue."
943952 )
944953
945- if len (inds ) > 1 :
946- inds = _prepare_indices_arrays (inds , exec_q , res_usm_type )
954+ inds = _prepare_indices_arrays (inds , exec_q , res_usm_type )
947955
948956 ind0 = inds [0 ]
949957 ary_sh = ary .shape
@@ -976,16 +984,28 @@ def _place_impl(ary, ary_mask, vals, axis=0):
976984 raise TypeError (
977985 f"Expecting type dpctl.tensor.usm_ndarray, got { type (ary )} "
978986 )
979- if not isinstance (ary_mask , dpt .usm_ndarray ):
980- raise TypeError (
981- f"Expecting type dpctl.tensor.usm_ndarray, got { type (ary_mask )} "
987+ if isinstance (ary_mask , dpt .usm_ndarray ):
988+ exec_q = dpctl .utils .get_execution_queue (
989+ (
990+ ary .sycl_queue ,
991+ ary_mask .sycl_queue ,
992+ )
982993 )
983- exec_q = dpctl .utils .get_execution_queue (
984- (
985- ary .sycl_queue ,
986- ary_mask .sycl_queue ,
994+ if exec_q is None :
995+ raise dpctl .utils .ExecutionPlacementError (
996+ "arrays have different associated queues. "
997+ "Use `y.to_device(x.device)` to migrate."
998+ )
999+ elif isinstance (ary_mask , np .ndarray ):
1000+ exec_q = ary .sycl_queue
1001+ ary_mask = dpt .asarray (
1002+ ary_mask , usm_type = ary .usm_type , sycl_queue = exec_q
1003+ )
1004+ else :
1005+ raise TypeError (
1006+ "Expecting type dpctl.tensor.usm_ndarray or numpy.ndarray, got "
1007+ f"{ type (ary_mask )} "
9871008 )
988- )
9891009 if exec_q is not None :
9901010 if not isinstance (vals , dpt .usm_ndarray ):
9911011 vals = dpt .asarray (vals , dtype = ary .dtype , sycl_queue = exec_q )
@@ -1080,8 +1100,7 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
10801100 "be associated with the same queue."
10811101 )
10821102
1083- if len (inds ) > 1 :
1084- inds = _prepare_indices_arrays (inds , exec_q , vals_usm_type )
1103+ inds = _prepare_indices_arrays (inds , exec_q , vals_usm_type )
10851104
10861105 ind0 = inds [0 ]
10871106 ary_sh = ary .shape
0 commit comments