@@ -114,15 +114,12 @@ def _reduction_over_axis(
114114 res_shape = res_shape + (1 ,) * red_nd
115115 inv_perm = sorted (range (nd ), key = lambda d : perm [d ])
116116 res_shape = tuple (res_shape [i ] for i in inv_perm )
117- return dpt .astype (
118- dpt .full (
119- res_shape ,
120- _identity ,
121- dtype = _default_reduction_type_fn (inp_dt , q ),
122- usm_type = res_usm_type ,
123- sycl_queue = q ,
124- ),
125- res_dt ,
117+ return dpt .full (
118+ res_shape ,
119+ _identity ,
120+ dtype = res_dt ,
121+ usm_type = res_usm_type ,
122+ sycl_queue = q ,
126123 )
127124 if red_nd == 0 :
128125 return dpt .astype (x , res_dt , copy = False )
@@ -142,21 +139,51 @@ def _reduction_over_axis(
142139 "Automatically determined reduction data type does not "
143140 "have direct implementation"
144141 )
145- tmp_dt = _default_reduction_type_fn (inp_dt , q )
146- tmp = dpt .empty (
147- res_shape , dtype = tmp_dt , usm_type = res_usm_type , sycl_queue = q
148- )
149- ht_e_tmp , r_e = _reduction_fn (
150- src = arr2 , trailing_dims_to_reduce = red_nd , dst = tmp , sycl_queue = q
151- )
152- host_tasks_list .append (ht_e_tmp )
153- res = dpt .empty (
154- res_shape , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
155- )
156- ht_e , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
157- src = tmp , dst = res , sycl_queue = q , depends = [r_e ]
158- )
159- host_tasks_list .append (ht_e )
142+ if _dtype_supported (res_dt , res_dt , res_usm_type , q ):
143+ tmp = dpt .empty (
144+ arr2 .shape , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
145+ )
146+ ht_e_cpy , cpy_e = ti ._copy_usm_ndarray_into_usm_ndarray (
147+ src = arr2 , dst = tmp , sycl_queue = q
148+ )
149+ host_tasks_list .append (ht_e_cpy )
150+ res = dpt .empty (
151+ res_shape , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
152+ )
153+ ht_e_red , _ = _reduction_fn (
154+ src = tmp ,
155+ trailing_dims_to_reduce = red_nd ,
156+ dst = res ,
157+ sycl_queue = q ,
158+ depends = [cpy_e ],
159+ )
160+ host_tasks_list .append (ht_e_red )
161+ else :
162+ buf_dt = _default_reduction_type_fn (inp_dt , q )
163+ tmp = dpt .empty (
164+ arr2 .shape , dtype = buf_dt , usm_type = res_usm_type , sycl_queue = q
165+ )
166+ ht_e_cpy , cpy_e = ti ._copy_usm_ndarray_into_usm_ndarray (
167+ src = arr2 , dst = tmp , sycl_queue = q
168+ )
169+ tmp_res = dpt .empty (
170+ res_shape , dtype = buf_dt , usm_type = res_usm_type , sycl_queue = q
171+ )
172+ host_tasks_list .append (ht_e_cpy )
173+ res = dpt .empty (
174+ res_shape , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
175+ )
176+ ht_e_red , r_e = _reduction_fn (
177+ src = tmp ,
178+ trailing_dims_to_reduce = red_nd ,
179+ dst = tmp_res ,
180+ sycl_queue = q ,
181+ depends = [cpy_e ],
182+ )
183+ ht_e_cpy2 , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
184+ src = tmp_res , dst = res , sycl_queue = q , depends = [r_e ]
185+ )
186+ host_tasks_list .append (ht_e_cpy2 )
160187
161188 if keepdims :
162189 res_shape = res_shape + (1 ,) * red_nd
@@ -445,7 +472,7 @@ def _comparison_over_axis(x, axis, keepdims, _reduction_fn):
445472
446473
447474def max (x , axis = None , keepdims = False ):
448- """max(x, axis=None, dtype=None, keepdims=False)
475+ """max(x, axis=None, keepdims=False)
449476
450477 Calculates the maximum value of the input array `x`.
451478
@@ -473,7 +500,7 @@ def max(x, axis=None, keepdims=False):
473500
474501
475502def min (x , axis = None , keepdims = False ):
476- """min(x, axis=None, dtype=None, keepdims=False)
503+ """min(x, axis=None, keepdims=False)
477504
478505 Calculates the minimum value of the input array `x`.
479506
@@ -550,7 +577,7 @@ def _search_over_axis(x, axis, keepdims, _reduction_fn):
550577
551578
552579def argmax (x , axis = None , keepdims = False ):
553- """argmax(x, axis=None, dtype=None, keepdims=False)
580+ """argmax(x, axis=None, keepdims=False)
554581
555582 Returns the indices of the maximum values of the input array `x` along a
556583 specified axis.
@@ -582,7 +609,7 @@ def argmax(x, axis=None, keepdims=False):
582609
583610
584611def argmin (x , axis = None , keepdims = False ):
585- """argmin(x, axis=None, dtype=None, keepdims=False)
612+ """argmin(x, axis=None, keepdims=False)
586613
587614 Returns the indices of the minimum values of the input array `x` along a
588615 specified axis.
0 commit comments