@@ -118,7 +118,7 @@ def _reduction_over_axis(
118118 dpt .full (
119119 res_shape ,
120120 _identity ,
121- dtype = _default_reduction_type_fn ( inp_dt , q ) ,
121+ dtype = dtype ,
122122 usm_type = res_usm_type ,
123123 sycl_queue = q ,
124124 ),
@@ -142,21 +142,51 @@ def _reduction_over_axis(
142142 "Automatically determined reduction data type does not "
143143 "have direct implementation"
144144 )
145- tmp_dt = _default_reduction_type_fn (inp_dt , q )
146- tmp = dpt .empty (
147- res_shape , dtype = tmp_dt , usm_type = res_usm_type , sycl_queue = q
148- )
149- ht_e_tmp , r_e = _reduction_fn (
150- src = arr2 , trailing_dims_to_reduce = red_nd , dst = tmp , sycl_queue = q
151- )
152- host_tasks_list .append (ht_e_tmp )
153- res = dpt .empty (
154- res_shape , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
155- )
156- ht_e , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
157- src = tmp , dst = res , sycl_queue = q , depends = [r_e ]
158- )
159- host_tasks_list .append (ht_e )
145+ if _dtype_supported (res_dt , res_dt , res_usm_type , q ):
146+ tmp = dpt .empty (
147+ arr2 .shape , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
148+ )
149+ ht_e_cpy , cpy_e = ti ._copy_usm_ndarray_into_usm_ndarray (
150+ src = arr2 , dst = tmp , sycl_queue = q
151+ )
152+ host_tasks_list .append (ht_e_cpy )
153+ res = dpt .empty (
154+ res_shape , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
155+ )
156+ ht_e_red , _ = _reduction_fn (
157+ src = tmp ,
158+ trailing_dims_to_reduce = red_nd ,
159+ dst = res ,
160+ sycl_queue = q ,
161+ depends = [cpy_e ],
162+ )
163+ host_tasks_list .append (ht_e_red )
164+ else :
165+ buf_dt = _default_reduction_type_fn (inp_dt , q )
166+ tmp = dpt .empty (
167+ arr2 .shape , dtype = buf_dt , usm_type = res_usm_type , sycl_queue = q
168+ )
169+ ht_e_cpy , cpy_e = ti ._copy_usm_ndarray_into_usm_ndarray (
170+ src = arr2 , dst = tmp , sycl_queue = q
171+ )
172+ tmp_res = dpt .empty (
173+ res_shape , dtype = buf_dt , usm_type = res_usm_type , sycl_queue = q
174+ )
175+ host_tasks_list .append (ht_e_cpy )
176+ res = dpt .empty (
177+ res_shape , dtype = res_dt , usm_type = res_usm_type , sycl_queue = q
178+ )
179+ ht_e_red , r_e = _reduction_fn (
180+ src = tmp ,
181+ trailing_dims_to_reduce = red_nd ,
182+ dst = tmp_res ,
183+ sycl_queue = q ,
184+ depends = [cpy_e ],
185+ )
186+ ht_e_cpy2 , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
187+ src = tmp_res , dst = res , sycl_queue = q , depends = [r_e ]
188+ )
189+ host_tasks_list .append (ht_e_cpy2 )
160190
161191 if keepdims :
162192 res_shape = res_shape + (1 ,) * red_nd
0 commit comments