@@ -103,7 +103,8 @@ template<typename T NBL_STRUCT_CONSTRAINABLE>
103103struct nMax_helper;
// Primary-template declarations only; the specializations appear further down in the file.
104104template<typename T NBL_STRUCT_CONSTRAINABLE>
105105struct nClamp_helper;
106-
// Added by this change: primary template for the fused multiply-add helper, declared with
// the same NBL_STRUCT_CONSTRAINABLE parameter list as the helper declared just above.
106+ template<typename T NBL_STRUCT_CONSTRAINABLE>
107+ struct fma_helper;
107108
108109#ifdef __HLSL_VERSION // HLSL only specializations
109110
@@ -134,6 +135,7 @@ template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(find_lsb_helper, findIL
134135#undef FIND_MSB_LSB_RETURN_TYPE
135136
// Macro-generated helper specializations. NOTE(review): the definition of
// AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER is outside this view; presumably each line forwards
// the helper's __call to the intrinsic named by the second macro argument - confirm.
136137template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (bitReverse_helper, bitReverse, (T), (T), T)
// Added by this change: dot_helper is instantiated with argument list (T)(T) and result type
// vector_traits<T>::scalar_type, the same result type length_helper uses below.
138+ template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (dot_helper, dot, (T), (T)(T), typename vector_traits<T>::scalar_type)
137139template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (transpose_helper, transpose, (T), (T), T)
138140template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (length_helper, length, (T), (T), typename vector_traits<T>::scalar_type)
139141template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (normalize_helper, normalize, (T), (T), T)
@@ -162,6 +164,7 @@ template<typename T, typename U> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(refract_hel
// Macro-generated helpers whose result type is T itself. NOTE(review): the macro body is not
// visible here; the argument-list parameter is read as one parenthesised type per operand - confirm.
162164template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (nMax_helper, nMax, (T), (T)(T), T)
163165template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (nMin_helper, nMin, (T), (T)(T), T)
164166template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (nClamp_helper, nClamp, (T), (T)(T), T)
// Added by this change: fma_helper is instantiated with three T entries in its argument list
// and T as the result type.
167+ template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (fma_helper, fma, (T), (T)(T)(T), T)
165168
// Result type used for bitCount_helper: a vector of int32_t with T's Dimension when
// is_vector_v<T> is true, otherwise a plain int32_t.
// NOTE(review): the macro name spells RETURN as "RETRUN"; a rename would have to update every
// use together, including any #undef that lies outside this view.
166169#define BITCOUNT_HELPER_RETRUN_TYPE conditional_t<is_vector_v<T>, vector <int32_t, vector_traits<T>::Dimension>, int32_t>
167170template<typename T> AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (bitCount_helper, bitCount, (T), (T), BITCOUNT_HELPER_RETRUN_TYPE)
@@ -599,6 +602,16 @@ struct nClamp_helper<T>
599602 }
600603};
601604
// Added by this change: fma_helper specialization for scalar floating-point types. It sits
// just before the "#endif // C++ only specializations" marker, and is selected when
// concepts::FloatingPointScalar<FloatingPoint> is satisfied.
605+ template<typename FloatingPoint>
606+ requires concepts::FloatingPointScalar<FloatingPoint>
607+ struct fma_helper<FloatingPoint>
608+ {
// Returns the result of std::fma(x, y, z), i.e. the standard library's fused x * y + z.
609+ static FloatingPoint __call (NBL_CONST_REF_ARG (FloatingPoint) x, NBL_CONST_REF_ARG (FloatingPoint) y, NBL_CONST_REF_ARG (FloatingPoint) z)
610+ {
611+ return std::fma (x, y, z);
612+ }
613+ };
614+
602615#endif // C++ only specializations
603616
604617// C++ and HLSL specializations
@@ -613,25 +626,6 @@ struct bitReverseAs_helper<T NBL_PARTIAL_REQ_BOT(concepts::UnsignedIntegralScala
613626 }
614627};
615628
// Removed side of the hunk ("-" lines): the previous generic dot_helper, constrained only by
// concepts::Vectorial<Vectorial>.
616- template<typename Vectorial>
617- NBL_PARTIAL_REQ_TOP (concepts::Vectorial<Vectorial>)
618- struct dot_helper<Vectorial NBL_PARTIAL_REQ_BOT (concepts::Vectorial<Vectorial>) >
619- {
620- using scalar_type = typename vector_traits<Vectorial>::scalar_type;
621-
// Dot product of lhs and rhs: component 0 seeds the sum, the remaining components are added in.
622- static inline scalar_type __call (NBL_CONST_REF_ARG (Vectorial) lhs, NBL_CONST_REF_ARG (Vectorial) rhs)
623- {
624- static const uint32_t ArrayDim = vector_traits<Vectorial>::Dimension;
625- static array_get<Vectorial, scalar_type> getter;
626-
627- scalar_type retval = getter (lhs, 0 ) * getter (rhs, 0 );
628- for (uint32_t i = 1 ; i < ArrayDim; ++i)
// Accumulates with a separate multiply and add on scalar_type.
629- retval = retval + getter (lhs, i) * getter (rhs, i);
630-
631- return retval;
632- }
633- };
634-
635629#ifdef __HLSL_VERSION
636630// SPIR-V already defines specializations for builtin vector types
637631#define VECTOR_SPECIALIZATION_CONCEPT concepts::Vectorial<T> && !is_vector_v<T>
@@ -888,8 +882,54 @@ struct mix_helper<T, U NBL_PARTIAL_REQ_BOT(concepts::Vectorial<T> && concepts::B
888882 }
889883};
890884
// Added by this change: component-wise fma_helper for types matching
// VECTOR_SPECIALIZATION_CONCEPT. NOTE(review): the only definition of that macro visible in
// this view is the __HLSL_VERSION one (concepts::Vectorial<T> && !is_vector_v<T>); the
// non-HLSL definition is outside this view - confirm.
885+ template<typename T>
886+ NBL_PARTIAL_REQ_TOP (VECTOR_SPECIALIZATION_CONCEPT)
887+ struct fma_helper<T NBL_PARTIAL_REQ_BOT (VECTOR_SPECIALIZATION_CONCEPT) >
888+ {
889+ using return_t = T;
// Returns a T whose i-th component is fma_helper<scalar_type>::__call(x[i], y[i], z[i]).
890+ static return_t __call (NBL_CONST_REF_ARG (T) x, NBL_CONST_REF_ARG (T) y, NBL_CONST_REF_ARG (T) z)
891+ {
892+ using traits = hlsl::vector_traits<T>;
// Component read/write goes through array_get / array_set rather than operator[].
893+ array_get<T, typename traits::scalar_type> getter;
894+ array_set<T, typename traits::scalar_type> setter;
895+
// output is declared without an initializer; the loop below passes every index
// in [0, traits::Dimension) to setter before output is returned.
896+ return_t output;
897+ for (uint32_t i = 0 ; i < traits::Dimension; ++i)
898+ setter (output, i, fma_helper<typename traits::scalar_type>::__call (getter (x, i), getter (y, i), getter (z, i)));
899+
900+ return output;
901+ }
902+ };
903+
// Constraint for the generic dot_helper below. Under __HLSL_VERSION, types for which
// is_vector_v is true are excluded; otherwise any concepts::Vectorial type qualifies.
// NOTE(review): presumably the exclusion avoids overlapping with the macro-generated
// dot_helper specialization in the HLSL-only section - confirm.
904+ #ifdef __HLSL_VERSION
905+ #define DOT_HELPER_REQUIREMENT (concepts::Vectorial<Vectorial> && !is_vector_v<Vectorial>)
906+ #else
907+ #define DOT_HELPER_REQUIREMENT concepts::Vectorial<Vectorial>
908+ #endif
909+
// Generic dot product over the components of a Vectorial type.
910+ template<typename Vectorial>
911+ NBL_PARTIAL_REQ_TOP (DOT_HELPER_REQUIREMENT)
912+ struct dot_helper<Vectorial NBL_PARTIAL_REQ_BOT (DOT_HELPER_REQUIREMENT) >
913+ {
914+ using scalar_type = typename vector_traits<Vectorial>::scalar_type;
915+
// Returns the sum over i of lhs[i] * rhs[i], accumulated left to right starting from component 0.
916+ static inline scalar_type __call (NBL_CONST_REF_ARG (Vectorial) lhs, NBL_CONST_REF_ARG (Vectorial) rhs)
917+ {
918+ static const uint32_t ArrayDim = vector_traits<Vectorial>::Dimension;
919+ static array_get<Vectorial, scalar_type> getter;
920+
// Component 0 seeds the accumulator with a plain multiply; the loop starts at index 1.
921+ scalar_type retval = getter (lhs, 0 ) * getter (rhs, 0 );
922+ for (uint32_t i = 1 ; i < ArrayDim; ++i)
// Each step is lhs[i] * rhs[i] + retval routed through fma_helper instead of separate * and +.
// NOTE(review): this requires fma_helper<scalar_type> to be defined for every scalar type this
// helper is used with; the only scalar fma_helper definition visible in this view is constrained
// to floating-point scalars - confirm integer vectors are covered elsewhere or not expected.
923+ retval = fma_helper<scalar_type>::__call (getter (lhs, i), getter (rhs, i), retval);
924+
925+ return retval;
926+ }
927+ };
928+
// The requirement macro is local to this definition and is dropped right after it.
929+ #undef DOT_HELPER_REQUIREMENT
930+
891931}
892932}
893933}
894934
895- #endif
935+ #endif
0 commit comments