@@ -53,27 +53,27 @@ template <typename argT1, typename argT2, typename resT>
5353struct FloorDivideFunctor
5454{
5555
56- using supports_sg_loadstore =
57- std::negation<std::disjunction<tu_ns::is_complex<argT1>,
58- tu_ns::is_complex<argT2>>>; // TRUE
59- using supports_vec = std::negation<std::disjunction<
60- tu_ns::is_complex<argT1>,
61- tu_ns::is_complex<argT2>,
62- std::conjunction<std::is_integral<argT1>, std::is_signed<argT1>>,
63- std::conjunction<std::is_integral<argT2>, std::is_signed<argT2>>>>;
64- // no vec overload for signed integers to avoid loop
56+ using supports_sg_loadstore = std::negation<
57+ std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
58+ using supports_vec = std::negation<
59+ std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
6560
6661 resT operator ()(const argT1 &in1, const argT2 &in2)
6762 {
6863 auto tmp = in1 / in2;
6964 if constexpr (std::is_integral_v<decltype (tmp)>) {
7065 if constexpr (std::is_unsigned_v<decltype (tmp)>) {
71- return tmp;
66+ return (in2 == argT2 ( 0 )) ? resT ( 0 ) : tmp;
7267 }
7368 else {
74- auto rem = in1 % in2;
75- auto corr = (rem != 0 && ((rem < 0 ) != (in2 < 0 )));
76- return (tmp - corr);
69+ if (in2 == argT2 (0 )) {
70+ return resT (0 );
71+ }
72+ else {
73+ auto rem = in1 % in2;
74+ auto corr = (rem != 0 && ((rem < 0 ) != (in2 < 0 )));
75+ return (tmp - corr);
76+ }
7777 }
7878 }
7979 else {
@@ -86,17 +86,37 @@ struct FloorDivideFunctor
8686 const sycl::vec<argT2, vec_sz> &in2)
8787 {
8888 auto tmp = in1 / in2;
89- if constexpr (std::is_same_v<resT,
90- typename decltype (tmp)::element_type> &&
91- std::is_integral_v<resT>)
92- {
93- return tmp;
94- }
95- else if constexpr (std::is_integral_v<typename decltype (
96- tmp)::element_type>) {
97- using dpctl::tensor::type_utils::vec_cast;
98- return vec_cast<resT, typename decltype (tmp)::element_type, vec_sz>(
99- tmp);
89+ using tmpT = typename decltype (tmp)::element_type;
90+ if constexpr (std::is_integral_v<tmpT>) {
91+ if constexpr (std::is_signed_v<tmpT>) {
92+ auto rem_tmp = in1 % in2;
93+ #pragma unroll
94+ for (int i = 0 ; i < vec_sz; ++i) {
95+ if (in2[i] == argT2 (0 )) {
96+ tmp[i] = tmpT (0 );
97+ }
98+ else {
99+ tmpT corr = (rem_tmp[i] != 0 &&
100+ ((rem_tmp[i] < 0 ) != (in2[i] < 0 )));
101+ tmp[i] -= corr;
102+ }
103+ }
104+ }
105+ else {
106+ #pragma unroll
107+ for (int i = 0 ; i < vec_sz; ++i) {
108+ if (in2[i] == argT2 (0 )) {
109+ tmp[i] = tmpT (0 );
110+ }
111+ }
112+ }
113+ if constexpr (std::is_same_v<resT, tmpT>) {
114+ return tmp;
115+ }
116+ else {
117+ using dpctl::tensor::type_utils::vec_cast;
118+ return vec_cast<resT, tmpT, vec_sz>(tmp);
119+ }
100120 }
101121 else {
102122 sycl::vec<resT, vec_sz> res = sycl::floor (tmp);
0 commit comments