77namespace kernel_float {
88
// Return type of `zip(fun, left, right)` for a functor type F and operand types L and R.
// In this diff the explicit spelling (removed lines: a `vector` whose element type is
// `result_t<F, value-of-L, value-of-R>` and whose extent is the broadcast extent of L and R)
// is replaced by the alias `map_type<F, L, R>` (added line). `map_type` is declared
// outside this view.
99template <typename F, typename L, typename R>
10- using zip_type = vector<
11- result_t <F, vector_value_type<L>, vector_value_type<R>>,
12- broadcast_vector_extent_type<L, R>>;
10+ using zip_type = map_type<F, L, R>;
1311
1412/* *
1513 * Combines the elements from the two inputs (`left` and `right`) element-wise, applying a provided binary
@@ -25,20 +23,7 @@ using zip_type = vector<
2523 */
2624template <typename F, typename L, typename R>
2725KERNEL_FLOAT_INLINE zip_type<F, L, R> zip (F fun, const L& left, const R& right) {
// Removed implementation: derived the element types A and B of both operands, the result
// element type O and the broadcast extent E, broadcast both operands to E, and invoked
// detail::apply_impl directly on the raw data pointers.
28- using A = vector_value_type<L>;
29- using B = vector_value_type<R>;
30- using O = result_t <F, A, B>;
31- using E = broadcast_vector_extent_type<L, R>;
32- vector_storage<O, E::value> result;
33-
34- detail::apply_impl<F, E::value, O, A, B>::call (
35- fun,
36- result.data (),
37- detail::broadcast_impl<A, vector_extent_type<L>, E>::call (into_vector_storage (left)).data (),
38- detail::broadcast_impl<B, vector_extent_type<R>, E>::call (into_vector_storage (right))
39- .data ());
40-
41- return result;
// Added implementation: forwards the functor and both operands unchanged to
// `::kernel_float::map`, which is defined outside this view.
// NOTE(review): the call is fully qualified, presumably so that argument-dependent lookup
// cannot select a different `map` -- confirm the intent.
26+ return ::kernel_float::map (fun, left, right);
4227}
4328
4429template <typename F, typename L, typename R>
@@ -67,7 +52,14 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
6752
6853 vector_storage<O, E::value> result;
6954
70- detail::apply_impl<F, E::value, O, T, T>::call (
55+ // Use the `apply_fastmath_impl` if KERNEL_FLOAT_FAST_MATH is enabled
56+ #if KERNEL_FLOAT_FAST_MATH
57+ using apply_impl = detail::apply_fastmath_impl<F, E::value, O, T, T>;
58+ #else
59+ using apply_impl = detail::apply_impl<F, E::value, O, T, T>;
60+ #endif
61+
62+ apply_impl::call (
7163 fun,
7264 result.data (),
7365 detail::convert_impl<vector_value_type<L>, vector_extent_type<L>, T, E>::call (
@@ -277,36 +269,17 @@ KERNEL_FLOAT_DEFINE_BINARY(
// Defines the binary function `rhypot` through KERNEL_FLOAT_DEFINE_BINARY. The arguments
// after the name are, in order, the expression for a generic element type T, the expression
// for double and the expression for float (the same ordering is visible in the removed
// KERNEL_FLOAT_DEFINE_BINARY_FAST macro further down in this diff).
//
// This diff changes only the generic-T expression in both branches: the explicit
// `T(1) / hypot` is replaced by `ops::rcp<T>( hypot )`.
// NOTE(review): `ops::rcp<T>` is used as a functor *type* by apply_fastmath_impl elsewhere
// in this file, and sibling functors are invoked as `ops::hypot<T>()(left, right)`. The
// added spelling `ops::rcp<T>( value )` therefore looks like it constructs the functor from
// the value instead of calling it; `ops::rcp<T>()( value )` may have been intended -- confirm
// against the definition of ops::rcp.
//
// Device branch: the double and float expressions call ::rhypot / ::rhypotf.
277269#if KERNEL_FLOAT_IS_DEVICE
278270KERNEL_FLOAT_DEFINE_BINARY (
279271 rhypot,
280- (T( 1 ) / ops::hypot<T>()(left, right)),
272+ (ops::rcp<T>( ops::hypot<T>()(left, right) )),
281273 ::rhypot(left, right),
282274 ::rhypotf(left, right))
// Host branch: the double and float expressions divide 1 by ::hypot / ::hypotf.
283275#else
284276KERNEL_FLOAT_DEFINE_BINARY (
285277 rhypot,
286- (T( 1 ) / ops::hypot<T>()(left, right)),
278+ (ops::rcp<T>( ops::hypot<T>()(left, right) )),
287279 (double (1 ) / ::hypot(left, right)),
288280 (float (1 ) / ::hypotf(left, right)))
289281#endif
290282
291- #if KERNEL_FLOAT_IS_DEVICE
292- #define KERNEL_FLOAT_DEFINE_BINARY_FAST (FUN_NAME, OP_NAME, FLOAT_FUN ) \
293- KERNEL_FLOAT_DEFINE_BINARY ( \
294- FUN_NAME, \
295- ops::OP_NAME<T> {}(left, right), \
296- ops::OP_NAME<double> {}(left, right), \
297- ops::OP_NAME<float > {}(left, right))
298- #else
299- #define KERNEL_FLOAT_DEFINE_BINARY_FAST (FUN_NAME, OP_NAME, FLOAT_FUN ) \
300- KERNEL_FLOAT_DEFINE_BINARY ( \
301- FUN_NAME, \
302- ops::OP_NAME<T> {}(left, right), \
303- ops::OP_NAME<double> {}(left, right), \
304- ops::OP_NAME<float > {}(left, right))
305- #endif
306-
307- KERNEL_FLOAT_DEFINE_BINARY_FAST (fast_div, divide, __fdividef)
308- KERNEL_FLOAT_DEFINE_BINARY_FAST(fast_pow, pow, __powf)
309-
310283namespace ops {
311284template <>
312285struct add <bool > {
@@ -323,6 +296,52 @@ struct multiply<bool> {
323296};
324297}; // namespace ops
325298
// Fast-math specialisations of apply_fastmath_impl for element-wise division.
// The primary apply_fastmath_impl template is declared outside this view.
299+ namespace detail {
// Generic case (any T, N elements): result[i] is computed as lhs[i] multiplied by the
// reciprocal of rhs[i], by chaining the rcp and multiply fast-math implementations.
// The `fun` parameter is not used by the body.
// NOTE(review): `T rhs_rcp[N]` would be a zero-length array if N can be 0 -- confirm that
// this specialisation is never instantiated with an empty extent.
300+ template <typename T, size_t N>
301+ struct apply_fastmath_impl <ops::divide<T>, N, T, T, T> {
302+ KERNEL_FLOAT_INLINE static void
303+ call (ops::divide<T> fun, T* result, const T* lhs, const T* rhs) {
// Scratch buffer holding the reciprocals of all N right-hand values.
304+ T rhs_rcp[N];
305+
306+ // Fast way to perform division is to multiply by the reciprocal
307+ apply_fastmath_impl<ops::rcp<T>, N, T, T, T>::call ({}, rhs_rcp, rhs);
308+ apply_fastmath_impl<ops::multiply<T>, N, T, T, T>::call ({}, result, lhs, rhs_rcp);
309+ }
310+ };
311+
// Device-only case for float: more specialised than the generic case above, so it is the
// one selected for ops::divide<float> when KERNEL_FLOAT_IS_DEVICE is set. Each element is
// divided with the `__fdividef` intrinsic instead of the reciprocal-multiply sequence.
// The `fun` parameter is not used by the body.
312+ #if KERNEL_FLOAT_IS_DEVICE
313+ template <size_t N>
314+ struct apply_fastmath_impl <ops::divide<float >, N, float , float , float > {
315+ KERNEL_FLOAT_INLINE static void
316+ call (ops::divide<float > fun, float * result, const float * lhs, const float * rhs) {
// The trip count N is a compile-time constant; the pragma asks the compiler to unroll.
317+ #pragma unroll
318+ for (size_t i = 0 ; i < N; i++) {
319+ result[i] = __fdividef (lhs[i], rhs[i]);
320+ }
321+ }
322+ };
323+ #endif
324+ } // namespace detail
325+
// Element-wise division of `left` by `right` that always goes through
// detail::apply_fastmath_impl; unlike zip_common, there is no KERNEL_FLOAT_FAST_MATH
// check in this function.
//
// T defaults to promoted_vector_value_type<L, R>; both operands are converted to element
// type T and to the broadcast extent E before the division is applied.
//
// NOTE(review): the return type is spelled `zip_common_type<ops::divide<T>, T, T>`, whereas
// zip_common in this file returns `zip_common_type<F, L, R>` with the operand types. With
// `T, T` the extent of the declared return type may not match E for vector operands --
// confirm against the definition of zip_common_type.
326+ template <typename L, typename R, typename T = promoted_vector_value_type<L, R>>
327+ KERNEL_FLOAT_INLINE zip_common_type<ops::divide<T>, T, T>
328+ fast_divide (const L& left, const R& right) {
329+ using E = broadcast_vector_extent_type<L, R>;
330+ vector_storage<T, E::value> result;
331+
// Arguments: the divide functor, the output buffer, then the converted left and right data.
332+ detail::apply_fastmath_impl<ops::divide<T>, E::value, T, T, T>::call (
333+ ops::divide<T> {},
334+ result.data (),
335+ detail::convert_impl<vector_value_type<L>, vector_extent_type<L>, T, E>::call (
336+ into_vector_storage (left))
337+ .data (),
338+ detail::convert_impl<vector_value_type<R>, vector_extent_type<R>, T, E>::call (
339+ into_vector_storage (right))
340+ .data ());
341+
342+ return result;
343+ }
344+
326345namespace detail {
327346template <typename T>
328347struct cross_impl {
0 commit comments