Skip to content

Commit 4d3d356

Browse files
committed
Redo how fastmath functions are implemented
Previously, each fast math function was implemented by having a corresponding `ops::fast_X<T>` functor for each `ops::X<T>` functor. This commit rewrites this logic to instead use a special `apply_fastmath_impl<...>` helper that falls back to `apply_impl<...>` if no fast math version is available. This means the normal apply logic is used when there is no fast math version, which was previously not the case. Additionally, this commit adds a global `KERNEL_FLOAT_FAST_MATH` define that can be used to turn on fast math mode.
1 parent 41246ab commit 4d3d356

File tree

11 files changed

+446
-200
lines changed

11 files changed

+446
-200
lines changed

include/kernel_float/apply.h

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ struct apply_recur_impl<1> {
152152
result[0] = fun(inputs[0]...);
153153
}
154154
};
155+
156+
// Dispatch point for the fast-math evaluation of functor `F` over `N` elements.
// The primary template adds nothing of its own: it inherits from `apply_impl`, so a functor
// without a dedicated fast-math specialization is evaluated by the regular apply logic.
template<typename F, size_t N, typename Output, typename... Args>
157+
struct apply_fastmath_impl: apply_impl<F, N, Output, Args...> {};
155158
} // namespace detail
156159

157160
template<typename F, typename... Args>
@@ -174,7 +177,34 @@ KERNEL_FLOAT_INLINE map_type<F, Args...> map(F fun, const Args&... args) {
174177
using E = broadcast_vector_extent_type<Args...>;
175178
vector_storage<Output, E::value> result;
176179

177-
detail::apply_impl<F, E::value, Output, vector_value_type<Args>...>::call(
180+
// Use the `apply_fastmath_impl` if KERNEL_FLOAT_FAST_MATH is enabled
181+
#if KERNEL_FLOAT_FAST_MATH
182+
using apply_impl = detail::apply_fastmath_impl<F, E::value, Output, vector_value_type<Args>...>;
183+
#else
184+
using apply_impl = detail::apply_impl<F, E::value, Output, vector_value_type<Args>...>;
185+
#endif
186+
187+
apply_impl::call(
188+
fun,
189+
result.data(),
190+
(detail::broadcast_impl<vector_value_type<Args>, vector_extent_type<Args>, E>::call(
191+
into_vector_storage(args))
192+
.data())...);
193+
194+
return result;
195+
}
196+
197+
/**
198+
* Apply the function `fun` to each element of the input vectors `args` and return the results as a new vector. This
199+
* uses fast-math if available for the given function `F`, otherwise this function behaves like `map`.
200+
*/
201+
template<typename F, typename... Args>
202+
KERNEL_FLOAT_INLINE map_type<F, Args...> fast_map(F fun, const Args&... args) {
203+
using Output = result_t<F, vector_value_type<Args>...>;
204+
using E = broadcast_vector_extent_type<Args...>;
205+
vector_storage<Output, E::value> result;
206+
207+
detail::apply_fastmath_impl<F, E::value, Output, vector_value_type<Args>...>::call(
178208
fun,
179209
result.data(),
180210
(detail::broadcast_impl<vector_value_type<Args>, vector_extent_type<Args>, E>::call(

include/kernel_float/base.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,20 @@ struct extent<N> {
8989
static constexpr size_t size = N;
9090
};
9191

92+
namespace detail {
93+
// Indicates that elements of type `T` offer less precision than floats, thus operations
94+
// on elements of type `T` can be performed by upcasting them to `float`.
95+
// The primary template answers `false`; a type opts in through a specialization.
template<typename T>
96+
struct allow_float_fallback {
97+
static constexpr bool value = false;
98+
};
99+
100+
// `float` itself is marked as allowing the fallback.
template<>
101+
struct allow_float_fallback<float> {
102+
static constexpr bool value = true;
103+
};
104+
} // namespace detail
105+
92106
template<typename T>
93107
struct into_vector_impl {
94108
using value_type = T;

include/kernel_float/bf16.h

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,7 @@ KERNEL_FLOAT_BF16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt)
7272
KERNEL_FLOAT_BF16_UNARY_FUN(sin, ::hsin, ::h2sin)
7373
KERNEL_FLOAT_BF16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt)
7474
KERNEL_FLOAT_BF16_UNARY_FUN(trunc, ::htrunc, ::h2trunc)
75-
76-
KERNEL_FLOAT_BF16_UNARY_FUN(fast_exp, ::hexp, ::h2exp)
77-
KERNEL_FLOAT_BF16_UNARY_FUN(fast_log, ::hlog, ::h2log)
78-
KERNEL_FLOAT_BF16_UNARY_FUN(fast_cos, ::hcos, ::h2cos)
79-
KERNEL_FLOAT_BF16_UNARY_FUN(fast_sin, ::hsin, ::h2sin)
75+
KERNEL_FLOAT_BF16_UNARY_FUN(rcp, ::hrcp, ::h2rcp)
8076
#endif
8177

8278
#if KERNEL_FLOAT_CUDA_ARCH >= 800
@@ -114,8 +110,6 @@ KERNEL_FLOAT_BF16_BINARY_FUN(divide, __hdiv, __h2div)
114110
KERNEL_FLOAT_BF16_BINARY_FUN(min, __hmin, __hmin2)
115111
KERNEL_FLOAT_BF16_BINARY_FUN(max, __hmax, __hmax2)
116112

117-
KERNEL_FLOAT_BF16_BINARY_FUN(fast_div, __hdiv, __h2div)
118-
119113
KERNEL_FLOAT_BF16_BINARY_FUN(equal_to, __heq, __heq2)
120114
KERNEL_FLOAT_BF16_BINARY_FUN(not_equal_to, __hneu, __hneu2)
121115
KERNEL_FLOAT_BF16_BINARY_FUN(less, __hlt, __hlt2)

include/kernel_float/binops.h

Lines changed: 58 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
namespace kernel_float {
88

99
template<typename F, typename L, typename R>
10-
using zip_type = vector<
11-
result_t<F, vector_value_type<L>, vector_value_type<R>>,
12-
broadcast_vector_extent_type<L, R>>;
10+
using zip_type = map_type<F, L, R>;
1311

1412
/**
1513
* Combines the elements from the two inputs (`left` and `right`) element-wise, applying a provided binary
@@ -25,20 +23,7 @@ using zip_type = vector<
2523
*/
2624
template<typename F, typename L, typename R>
2725
KERNEL_FLOAT_INLINE zip_type<F, L, R> zip(F fun, const L& left, const R& right) {
28-
using A = vector_value_type<L>;
29-
using B = vector_value_type<R>;
30-
using O = result_t<F, A, B>;
31-
using E = broadcast_vector_extent_type<L, R>;
32-
vector_storage<O, E::value> result;
33-
34-
detail::apply_impl<F, E::value, O, A, B>::call(
35-
fun,
36-
result.data(),
37-
detail::broadcast_impl<A, vector_extent_type<L>, E>::call(into_vector_storage(left)).data(),
38-
detail::broadcast_impl<B, vector_extent_type<R>, E>::call(into_vector_storage(right))
39-
.data());
40-
41-
return result;
26+
return ::kernel_float::map(fun, left, right);
4227
}
4328

4429
template<typename F, typename L, typename R>
@@ -67,7 +52,14 @@ KERNEL_FLOAT_INLINE zip_common_type<F, L, R> zip_common(F fun, const L& left, co
6752

6853
vector_storage<O, E::value> result;
6954

70-
detail::apply_impl<F, E::value, O, T, T>::call(
55+
// Use the `apply_fastmath_impl` if KERNEL_FLOAT_FAST_MATH is enabled
56+
#if KERNEL_FLOAT_FAST_MATH
57+
using apply_impl = detail::apply_fastmath_impl<F, E::value, O, T, T>;
58+
#else
59+
using apply_impl = detail::apply_impl<F, E::value, O, T, T>;
60+
#endif
61+
62+
apply_impl::call(
7163
fun,
7264
result.data(),
7365
detail::convert_impl<vector_value_type<L>, vector_extent_type<L>, T, E>::call(
@@ -277,36 +269,17 @@ KERNEL_FLOAT_DEFINE_BINARY(
277269
#if KERNEL_FLOAT_IS_DEVICE
278270
KERNEL_FLOAT_DEFINE_BINARY(
279271
rhypot,
280-
(T(1) / ops::hypot<T>()(left, right)),
272+
(ops::rcp<T>(ops::hypot<T>()(left, right))),
281273
::rhypot(left, right),
282274
::rhypotf(left, right))
283275
#else
284276
KERNEL_FLOAT_DEFINE_BINARY(
285277
rhypot,
286-
(T(1) / ops::hypot<T>()(left, right)),
278+
(ops::rcp<T>(ops::hypot<T>()(left, right))),
287279
(double(1) / ::hypot(left, right)),
288280
(float(1) / ::hypotf(left, right)))
289281
#endif
290282

291-
#if KERNEL_FLOAT_IS_DEVICE
292-
#define KERNEL_FLOAT_DEFINE_BINARY_FAST(FUN_NAME, OP_NAME, FLOAT_FUN) \
293-
KERNEL_FLOAT_DEFINE_BINARY( \
294-
FUN_NAME, \
295-
ops::OP_NAME<T> {}(left, right), \
296-
ops::OP_NAME<double> {}(left, right), \
297-
ops::OP_NAME<float> {}(left, right))
298-
#else
299-
#define KERNEL_FLOAT_DEFINE_BINARY_FAST(FUN_NAME, OP_NAME, FLOAT_FUN) \
300-
KERNEL_FLOAT_DEFINE_BINARY( \
301-
FUN_NAME, \
302-
ops::OP_NAME<T> {}(left, right), \
303-
ops::OP_NAME<double> {}(left, right), \
304-
ops::OP_NAME<float> {}(left, right))
305-
#endif
306-
307-
KERNEL_FLOAT_DEFINE_BINARY_FAST(fast_div, divide, __fdividef)
308-
KERNEL_FLOAT_DEFINE_BINARY_FAST(fast_pow, pow, __powf)
309-
310283
namespace ops {
311284
template<>
312285
struct add<bool> {
@@ -323,6 +296,52 @@ struct multiply<bool> {
323296
};
324297
}; // namespace ops
325298

299+
namespace detail {
300+
template<typename T, size_t N>
301+
struct apply_fastmath_impl<ops::divide<T>, N, T, T, T> {
302+
KERNEL_FLOAT_INLINE static void
303+
call(ops::divide<T> fun, T* result, const T* lhs, const T* rhs) {
304+
T rhs_rcp[N];
305+
306+
// Fast way to perform division is to multiply by the reciprocal
307+
apply_fastmath_impl<ops::rcp<T>, N, T, T, T>::call({}, rhs_rcp, rhs);
308+
apply_fastmath_impl<ops::multiply<T>, N, T, T, T>::call({}, result, lhs, rhs_rcp);
309+
}
310+
};
311+
312+
#if KERNEL_FLOAT_IS_DEVICE
313+
template<size_t N>
314+
struct apply_fastmath_impl<ops::divide<float>, N, float, float, float> {
315+
KERNEL_FLOAT_INLINE static void
316+
call(ops::divide<float> fun, float* result, const float* lhs, const float* rhs) {
317+
#pragma unroll
318+
for (size_t i = 0; i < N; i++) {
319+
result[i] = __fdividef(lhs[i], rhs[i]);
320+
}
321+
}
322+
};
323+
#endif
324+
} // namespace detail
325+
326+
template<typename L, typename R, typename T = promoted_vector_value_type<L, R>>
327+
KERNEL_FLOAT_INLINE zip_common_type<ops::divide<T>, T, T>
328+
fast_divide(const L& left, const R& right) {
329+
using E = broadcast_vector_extent_type<L, R>;
330+
vector_storage<T, E::value> result;
331+
332+
detail::apply_fastmath_impl<ops::divide<T>, E::value, T, T, T>::call(
333+
ops::divide<T> {},
334+
result.data(),
335+
detail::convert_impl<vector_value_type<L>, vector_extent_type<L>, T, E>::call(
336+
into_vector_storage(left))
337+
.data(),
338+
detail::convert_impl<vector_value_type<R>, vector_extent_type<R>, T, E>::call(
339+
into_vector_storage(right))
340+
.data());
341+
342+
return result;
343+
}
344+
326345
namespace detail {
327346
template<typename T>
328347
struct cross_impl {

include/kernel_float/constant.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ struct constant {
3030
constexpr constant(T value = {}) : value_(value) {}
3131

3232
KERNEL_FLOAT_INLINE
33-
constexpr constant(const constant<T>& that) : value_(that.value) {}
33+
constexpr constant(const constant<T>& that) : value_(that.value_) {}
3434

3535
/**
3636
* Create a new constant from another constant of type `R`.
@@ -129,7 +129,9 @@ struct cast<constant<T>, R, m> {
129129
KERNEL_FLOAT_INLINE constant<T> operator OP( \
130130
const constant<L>& left, \
131131
const constant<R>& right) { \
132-
return constant<T>(left.get()) OP constant<T>(right.get()); \
132+
auto fl = ops::cast<L, T>(); \
133+
auto fr = ops::cast<R, T>(); \
134+
return fl(left.get()) OP fr(right.get()); \
133135
}
134136

135137
KERNEL_FLOAT_CONSTANT_DEFINE_OP(+)

include/kernel_float/fp16.h

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,7 @@ KERNEL_FLOAT_FP16_UNARY_FUN(rsqrt, ::hrsqrt, ::h2rsqrt)
6868
KERNEL_FLOAT_FP16_UNARY_FUN(sin, ::hsin, ::h2sin)
6969
KERNEL_FLOAT_FP16_UNARY_FUN(sqrt, ::hsqrt, ::h2sqrt)
7070
KERNEL_FLOAT_FP16_UNARY_FUN(trunc, ::htrunc, ::h2trunc)
71-
72-
KERNEL_FLOAT_FP16_UNARY_FUN(fast_exp, ::hexp, ::h2exp)
73-
KERNEL_FLOAT_FP16_UNARY_FUN(fast_log, ::hlog, ::h2log)
74-
KERNEL_FLOAT_FP16_UNARY_FUN(fast_cos, ::hcos, ::h2cos)
75-
KERNEL_FLOAT_FP16_UNARY_FUN(fast_sin, ::hsin, ::h2sin)
71+
KERNEL_FLOAT_FP16_UNARY_FUN(rcp, ::hrcp, ::h2rcp)
7672

7773
#if KERNEL_FLOAT_IS_DEVICE
7874
#define KERNEL_FLOAT_FP16_BINARY_FUN(NAME, FUN1, FUN2) \
@@ -104,7 +100,6 @@ KERNEL_FLOAT_FP16_BINARY_FUN(multiply, __hmul, __hmul2)
104100
KERNEL_FLOAT_FP16_BINARY_FUN(divide, __hdiv, __h2div)
105101
KERNEL_FLOAT_FP16_BINARY_FUN(min, __hmin, __hmin2)
106102
KERNEL_FLOAT_FP16_BINARY_FUN(max, __hmax, __hmax2)
107-
KERNEL_FLOAT_FP16_BINARY_FUN(fast_div, __hdiv, __h2div)
108103

109104
KERNEL_FLOAT_FP16_BINARY_FUN(equal_to, __heq, __heq2)
110105
KERNEL_FLOAT_FP16_BINARY_FUN(not_equal_to, __hneu, __hneu2)

include/kernel_float/macros.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,8 @@
6363

6464
#define KERNEL_FLOAT_MAX_ALIGNMENT (32)
6565

66+
#ifndef KERNEL_FLOAT_FAST_MATH
67+
#define KERNEL_FLOAT_FAST_MATH (0)
68+
#endif
69+
6670
#endif //KERNEL_FLOAT_MACROS_H

include/kernel_float/meta.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,9 @@ struct enable_if_impl<true, T> {
270270
template<bool C, typename T = void>
271271
using enable_if_t = typename detail::enable_if_impl<C, T>::type;
272272

273+
// Alias that always resolves to `T`; any further template arguments are accepted and ignored.
template<typename T, typename...>
274+
using identity_t = T;
275+
273276
KERNEL_FLOAT_INLINE
274277
constexpr size_t round_up_to_power_of_two(size_t n) {
275278
size_t result = 1;

0 commit comments

Comments
 (0)