@@ -70,10 +70,12 @@ template <typename argTy, typename outTy>
7070struct TypePairSupportDataForProdAccumulation
7171{
7272 static constexpr bool is_defined = std::disjunction<
73+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, bool >,
7374 td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::int32_t >,
7475 td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::int64_t >,
7576
7677 // input int8_t
78+ td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int8_t >,
7779 td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int32_t >,
7880 td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int64_t >,
7981
@@ -138,7 +140,9 @@ struct CumProd1DContigFactory
138140 if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
139141 dstTy>::is_defined)
140142 {
141- using ScanOpT = sycl::multiplies<dstTy>;
143+ using ScanOpT = std::conditional_t <std::is_same_v<dstTy, bool >,
144+ sycl::logical_and<dstTy>,
145+ sycl::multiplies<dstTy>>;
142146 constexpr bool include_initial = false ;
143147 if constexpr (std::is_same_v<srcTy, dstTy>) {
144148 using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -171,7 +175,9 @@ struct CumProd1DIncludeInitialContigFactory
171175 if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
172176 dstTy>::is_defined)
173177 {
174- using ScanOpT = sycl::multiplies<dstTy>;
178+ using ScanOpT = std::conditional_t <std::is_same_v<dstTy, bool >,
179+ sycl::logical_and<dstTy>,
180+ sycl::multiplies<dstTy>>;
175181 constexpr bool include_initial = true ;
176182 if constexpr (std::is_same_v<srcTy, dstTy>) {
177183 using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -204,7 +210,9 @@ struct CumProdStridedFactory
204210 if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
205211 dstTy>::is_defined)
206212 {
207- using ScanOpT = sycl::multiplies<dstTy>;
213+ using ScanOpT = std::conditional_t <std::is_same_v<dstTy, bool >,
214+ sycl::logical_and<dstTy>,
215+ sycl::multiplies<dstTy>>;
208216 constexpr bool include_initial = false ;
209217 if constexpr (std::is_same_v<srcTy, dstTy>) {
210218 using dpctl::tensor::kernels::accumulators::NoOpTransformer;
@@ -237,7 +245,9 @@ struct CumProdIncludeInitialStridedFactory
237245 if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
238246 dstTy>::is_defined)
239247 {
240- using ScanOpT = sycl::multiplies<dstTy>;
248+ using ScanOpT = std::conditional_t <std::is_same_v<dstTy, bool >,
249+ sycl::logical_and<dstTy>,
250+ sycl::multiplies<dstTy>>;
241251 constexpr bool include_initial = true ;
242252 if constexpr (std::is_same_v<srcTy, dstTy>) {
243253 using dpctl::tensor::kernels::accumulators::NoOpTransformer;
0 commit comments