@@ -989,15 +989,22 @@ def _place_impl(ary, ary_mask, vals, axis=0):
989989 ary_mask .sycl_queue ,
990990 )
991991 )
992+ coerced_usm_type = dpctl .utils .get_coerced_usm_type (
993+ (
994+ ary .usm_type ,
995+ ary_mask .usm_type ,
996+ )
997+ )
992998 if exec_q is None :
993999 raise dpctl .utils .ExecutionPlacementError (
9941000 "arrays have different associated queues. "
9951001 "Use `y.to_device(x.device)` to migrate."
9961002 )
9971003 elif isinstance (ary_mask , np .ndarray ):
9981004 exec_q = ary .sycl_queue
1005+ coerced_usm_type = ary .usm_type
9991006 ary_mask = dpt .asarray (
1000- ary_mask , usm_type = ary . usm_type , sycl_queue = exec_q
1007+ ary_mask , usm_type = coerced_usm_type , sycl_queue = exec_q
10011008 )
10021009 else :
10031010 raise TypeError (
@@ -1006,9 +1013,20 @@ def _place_impl(ary, ary_mask, vals, axis=0):
10061013 )
10071014 if exec_q is not None :
10081015 if not isinstance (vals , dpt .usm_ndarray ):
1009- vals = dpt .asarray (vals , dtype = ary .dtype , sycl_queue = exec_q )
1016+ vals = dpt .asarray (
1017+ vals ,
1018+ dtype = ary .dtype ,
1019+ usm_type = coerced_usm_type ,
1020+ sycl_queue = exec_q ,
1021+ )
10101022 else :
10111023 exec_q = dpctl .utils .get_execution_queue ((exec_q , vals .sycl_queue ))
1024+ coerced_usm_type = dpctl .utils .get_coerced_usm_type (
1025+ (
1026+ coerced_usm_type ,
1027+ vals .usm_type ,
1028+ )
1029+ )
10121030 if exec_q is None :
10131031 raise dpctl .utils .ExecutionPlacementError (
10141032 "arrays have different associated queues. "
@@ -1023,7 +1041,12 @@ def _place_impl(ary, ary_mask, vals, axis=0):
10231041 )
10241042 mask_nelems = ary_mask .size
10251043 cumsum_dt = dpt .int32 if mask_nelems < int32_t_max else dpt .int64
1026- cumsum = dpt .empty (mask_nelems , dtype = cumsum_dt , device = ary_mask .device )
1044+ cumsum = dpt .empty (
1045+ mask_nelems ,
1046+ dtype = cumsum_dt ,
1047+ usm_type = coerced_usm_type ,
1048+ device = ary_mask .device ,
1049+ )
10271050 exec_q = cumsum .sycl_queue
10281051 _manager = dpctl .utils .SequentialOrderManager [exec_q ]
10291052 dep_ev = _manager .submitted_events
@@ -1069,17 +1092,26 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
10691092 if not isinstance (inds , (list , tuple )):
10701093 inds = (inds ,)
10711094
1072- exec_q , vals_usm_type = _get_indices_queue_usm_type (
1095+ exec_q , coerced_usm_type = _get_indices_queue_usm_type (
10731096 inds , ary .sycl_queue , ary .usm_type
10741097 )
10751098
10761099 if exec_q is not None :
10771100 if not isinstance (vals , dpt .usm_ndarray ):
10781101 vals = dpt .asarray (
1079- vals , dtype = ary .dtype , usm_type = vals_usm_type , sycl_queue = exec_q
1102+ vals ,
1103+ dtype = ary .dtype ,
1104+ usm_type = coerced_usm_type ,
1105+ sycl_queue = exec_q ,
10801106 )
10811107 else :
10821108 exec_q = dpctl .utils .get_execution_queue ((exec_q , vals .sycl_queue ))
1109+ coerced_usm_type = dpctl .utils .get_coerced_usm_type (
1110+ (
1111+ coerced_usm_type ,
1112+ vals .usm_type ,
1113+ )
1114+ )
10831115 if exec_q is None :
10841116 raise dpctl .utils .ExecutionPlacementError (
10851117 "Can not automatically determine where to allocate the "
@@ -1088,7 +1120,7 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
10881120 "be associated with the same queue."
10891121 )
10901122
1091- inds = _prepare_indices_arrays (inds , exec_q , vals_usm_type )
1123+ inds = _prepare_indices_arrays (inds , exec_q , coerced_usm_type )
10921124
10931125 ind0 = inds [0 ]
10941126 ary_sh = ary .shape
0 commit comments