11// ================================================================================
22// this file has been auto-generated, do not modify its contents!
3- // date: 2023-09-19 20:45:16.880746
4- // git hash: da0a46b533ef9d25638748eb951284f14e7c48bb
3+ // date: 2023-09-21 09:37:28.638971
4+ // git hash: 07af0ad9ff5c16595790d579577244bc482f0999
55// ================================================================================
66
77#ifndef KERNEL_FLOAT_MACROS_H
@@ -497,7 +497,7 @@ struct extent<N> {
497497};
498498
499499template <typename T>
500- struct into_vector_traits {
500+ struct into_vector_impl {
501501 using value_type = T;
502502 using extent_type = extent<1 >;
503503
@@ -508,7 +508,7 @@ struct into_vector_traits {
508508};
509509
510510template <typename T, size_t N>
511- struct into_vector_traits <T[N]> {
511+ struct into_vector_impl <T[N]> {
512512 using value_type = T;
513513 using extent_type = extent<N>;
514514
@@ -526,19 +526,19 @@ struct into_vector_traits<T[N]> {
526526};
527527
// Qualifier-forwarding specializations: a `const V`, `V&`, `const V&` or `V&&` argument
// inherits every member from the specialization for the unqualified type `V`, so each
// conversion only has to be written once for the bare type.
// In this diff view every `-` line carries the former struct name (`into_vector_traits`)
// and the `+` line directly below it carries the new name (`into_vector_impl`).
528528template <typename V>
529- struct into_vector_traits <const V>: into_vector_traits <V> {};
529+ struct into_vector_impl <const V>: into_vector_impl <V> {};
530530
// Lvalue reference: forwards to the specialization for `V`.
531531template <typename V>
532- struct into_vector_traits <V&>: into_vector_traits <V> {};
532+ struct into_vector_impl <V&>: into_vector_impl <V> {};
533533
// Const lvalue reference: forwards to the specialization for `V`.
534534template <typename V>
535- struct into_vector_traits <const V&>: into_vector_traits <V> {};
535+ struct into_vector_impl <const V&>: into_vector_impl <V> {};
536536
// Rvalue reference: forwards to the specialization for `V`.
537537template <typename V>
538- struct into_vector_traits <V&&>: into_vector_traits <V> {};
538+ struct into_vector_impl <V&&>: into_vector_impl <V> {};
539539
540540template <typename T, size_t N, size_t A>
541- struct into_vector_traits <aligned_array<T, N, A>> {
541+ struct into_vector_impl <aligned_array<T, N, A>> {
542542 using value_type = T;
543543 using extent_type = extent<N>;
544544
@@ -550,7 +550,7 @@ struct into_vector_traits<aligned_array<T, N, A>> {
550550
551551#define KERNEL_FLOAT_DEFINE_VECTOR_TYPE (T, T1, T2, T3, T4 ) \
552552 template <> \
553- struct into_vector_traits <::T1> { \
553+ struct into_vector_impl <::T1> { \
554554 using value_type = T; \
555555 using extent_type = extent<1 >; \
556556 \
@@ -561,7 +561,7 @@ struct into_vector_traits<aligned_array<T, N, A>> {
561561 }; \
562562 \
563563 template <> \
564- struct into_vector_traits <::T2> { \
564+ struct into_vector_impl <::T2> { \
565565 using value_type = T; \
566566 using extent_type = extent<2 >; \
567567 \
@@ -572,7 +572,7 @@ struct into_vector_traits<aligned_array<T, N, A>> {
572572 }; \
573573 \
574574 template <> \
575- struct into_vector_traits <::T3> { \
575+ struct into_vector_impl <::T3> { \
576576 using value_type = T; \
577577 using extent_type = extent<3 >; \
578578 \
@@ -583,7 +583,7 @@ struct into_vector_traits<aligned_array<T, N, A>> {
583583 }; \
584584 \
585585 template <> \
586- struct into_vector_traits <::T4> { \
586+ struct into_vector_impl <::T4> { \
587587 using value_type = T; \
588588 using extent_type = extent<4 >; \
589589 \
@@ -612,7 +612,7 @@ template<typename T, typename E, typename S = vector_storage<T, E::size>>
612612struct vector;
613613
614614template <typename T, typename E, typename S>
615- struct into_vector_traits <vector<T, E, S>> {
615+ struct into_vector_impl <vector<T, E, S>> {
616616 using value_type = T;
617617 using extent_type = E;
618618
@@ -634,10 +634,10 @@ struct vector_traits<vector<T, E, S>> {
634634};
635635
// Alias for the `value_type` member of `into_vector_impl<V>`.
// The `-` line is the spelling before the rename from `into_vector_traits`.
636636template <typename V>
637- using vector_value_type = typename into_vector_traits <V>::value_type;
637+ using vector_value_type = typename into_vector_impl <V>::value_type;
638638
// Alias for the `extent_type` member of `into_vector_impl<V>` (an `extent<...>` type in the
// specializations visible in this file).
639639template <typename V>
640- using vector_extent_type = typename into_vector_traits <V>::extent_type;
640+ using vector_extent_type = typename into_vector_impl <V>::extent_type;
641641
// The `value` constant of `vector_extent_type<V>` as a `size_t`
// (presumably the number of elements -- confirm against the definition of `extent`).
642642template <typename V>
643643static constexpr size_t vector_extent = vector_extent_type<V>::value;
@@ -653,7 +653,7 @@ using promoted_vector_value_type = promote_t<vector_value_type<Vs>...>;
653653
/**
 * Converts `input` into `vector_storage_type<V>` by perfect-forwarding it to
 * `into_vector_impl<V>::call`.
 *
 * @param input Value to convert; forwarded with `std::forward<V>`.
 * @return Whatever `into_vector_impl<V>::call` yields, as `vector_storage_type<V>`.
 */
654654template <typename V>
655655KERNEL_FLOAT_INLINE vector_storage_type<V> into_vector_storage (V&& input) {
656- return into_vector_traits <V>::call (std::forward<V>(input));
656+ return into_vector_impl <V>::call (std::forward<V>(input));
657657}
658658
659659} // namespace kernel_float
@@ -1732,7 +1732,10 @@ namespace kernel_float {
17321732template <typename T = double >
17331733struct constant {
17341734 template <typename R>
1735- KERNEL_FLOAT_INLINE explicit constexpr constant (const constant<R>& that) : value_(that.get()) {}
1735+ KERNEL_FLOAT_INLINE explicit constexpr constant (const constant<R>& that) {
1736+ auto f = ops::cast<R, T>();
1737+ value_ = f (that.get ());
1738+ }
17361739
/**
 * Value constructor: stores `value` in `value_`.
 *
 * The argument defaults to `{}`, so this also serves as the default constructor.
 * It is not `explicit`, so a `T` converts to `constant<T>` implicitly.
 *
 * @param value Value to wrap.
 */
17371740 KERNEL_FLOAT_INLINE
17381741 constexpr constant (T value = {}) : value_(value) {}
@@ -1793,28 +1796,43 @@ struct cast<constant<T>, R, m> {
17931796};
17941797} // namespace ops
17951798
1796- #define KERNEL_FLOAT_CONSTANT_DEFINE_OP (OP ) \
1797- template <typename L, typename R> \
1798- R operator OP (const constant<L>& left, const R& right) { \
1799- using T = vector_value_type<R>; \
1800- return operator OP (T (left.get ()), right); \
1801- } \
1802- \
1803- template <typename L, typename R> \
1804- L operator OP (const L& left, const constant<R>& right) { \
1805- using T = vector_value_type<L>; \
1806- return operator OP (left, T (right.get ())); \
1807- } \
1808- \
1809- template <typename L, typename R, typename T = promote_t <L, R>> \
1810- constant<T> operator OP (const constant<L>& left, const constant<R>& right) { \
1811- return constant<T>(operator OP (T (left.get ()), T (right.get ()))); \
1812- }
1813-
1814- // KERNEL_FLOAT_CONSTANT_DEFINE_OP(+)
1815- // KERNEL_FLOAT_CONSTANT_DEFINE_OP(-)
1816- // KERNEL_FLOAT_CONSTANT_DEFINE_OP(*)
1817- // KERNEL_FLOAT_CONSTANT_DEFINE_OP(/)
1799+ #define KERNEL_FLOAT_CONSTANT_DEFINE_OP (OP ) \
1800+ template <typename L, typename R> \
1801+ KERNEL_FLOAT_INLINE auto operator OP (const constant<L>& left, const R& right) { \
1802+ auto f = ops::cast<L, vector_value_type<R>>(); \
1803+ return f (left.get ()) OP right; \
1804+ } \
1805+ \
1806+ template <typename L, typename R> \
1807+ KERNEL_FLOAT_INLINE auto operator OP (const L& left, const constant<R>& right) { \
1808+ auto f = ops::cast<R, vector_value_type<L>>(); \
1809+ return left OP f (right.get ()); \
1810+ } \
1811+ \
1812+ template <typename L, typename R, typename E> \
1813+ KERNEL_FLOAT_INLINE auto operator OP (const constant<L>& left, const vector<R, E>& right) { \
1814+ auto f = ops::cast<L, R>(); \
1815+ return f (left.get ()) OP right; \
1816+ } \
1817+ \
1818+ template <typename L, typename R, typename E> \
1819+ KERNEL_FLOAT_INLINE auto operator OP (const vector<L, E>& left, const constant<R>& right) { \
1820+ auto f = ops::cast<R, L>(); \
1821+ return left OP f (right.get ()); \
1822+ } \
1823+ \
1824+ template <typename L, typename R, typename T = promote_t <L, R>> \
1825+ KERNEL_FLOAT_INLINE constant<T> operator OP ( \
1826+ const constant<L>& left, \
1827+ const constant<R>& right) { \
1828+ return constant<T>(left.get ()) OP constant<T>(right.get ()); \
1829+ }
1830+
1831+ KERNEL_FLOAT_CONSTANT_DEFINE_OP (+)
1832+ KERNEL_FLOAT_CONSTANT_DEFINE_OP (-)
1833+ KERNEL_FLOAT_CONSTANT_DEFINE_OP (*)
1834+ KERNEL_FLOAT_CONSTANT_DEFINE_OP (/)
1835+ KERNEL_FLOAT_CONSTANT_DEFINE_OP (%)
18181836
18191837} // namespace kernel_float
18201838
@@ -2731,7 +2749,7 @@ namespace ops {
/**
 * Generic multiply-add functor: returns `a * b + c`, written with the ordinary `*` and `+`
 * operators of `T`.
 *
 * The `-` line below is the previous body, which evaluated `a + b * c` instead, i.e. it
 * multiplied the last two arguments and added the first.
 */
27312749template <typename T>
27322750struct fma {
27332751 KERNEL_FLOAT_INLINE T operator ()(T a, T b, T c) {
2734- return a + b * c;
2752+ return a * b + c;
27352753 }
27362754};
27372755
@@ -3066,7 +3084,7 @@ struct vector: public S {
30663084 */
30673085template <typename V>
30683086KERNEL_FLOAT_INLINE into_vector_type<V> into_vector (V&& input) {
3069- return into_vector_traits <V>::call (std::forward<V>(input));
3087+ return into_vector_impl <V>::call (std::forward<V>(input));
30703088}
30713089
30723090template <typename T>
@@ -3136,7 +3154,7 @@ KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __half)
31363154KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (double , __half)
31373155
31383156template <>
3139- struct into_vector_traits <__half2> {
3157+ struct into_vector_impl <__half2> {
31403158 using value_type = __half;
31413159 using extent_type = extent<2 >;
31423160
@@ -3440,7 +3458,7 @@ KERNEL_FLOAT_DEFINE_PROMOTED_TYPE(float, __nv_bfloat16)
34403458KERNEL_FLOAT_DEFINE_PROMOTED_TYPE (double , __nv_bfloat16)
34413459
34423460template <>
3443- struct into_vector_traits <__nv_bfloat162> {
3461+ struct into_vector_impl <__nv_bfloat162> {
34443462 using value_type = __nv_bfloat16;
34453463 using extent_type = extent<2 >;
34463464
0 commit comments