@@ -54,6 +54,8 @@ struct exp2_helper;
5454template<typename T NBL_STRUCT_CONSTRAINABLE>
5555struct log_helper;
5656template<typename T NBL_STRUCT_CONSTRAINABLE>
57+ struct log2_helper;
58+ template<typename T NBL_STRUCT_CONSTRAINABLE>
5759struct abs_helper;
5860template<typename T NBL_STRUCT_CONSTRAINABLE>
5961struct cos_helper;
@@ -64,7 +66,7 @@ struct acos_helper;
6466template<typename T NBL_STRUCT_CONSTRAINABLE>
6567struct sqrt_helper;
6668template<typename T, typename U NBL_STRUCT_CONSTRAINABLE>
67- struct lerp_helper ;
69+ struct mix_helper ;
6870template<typename T NBL_STRUCT_CONSTRAINABLE>
6971struct modf_helper;
7072
@@ -82,12 +84,23 @@ struct HELPER_NAME<T NBL_PARTIAL_REQ_BOT(always_true<decltype(spirv::SPIRV_FUNCT
8284};
8385
8486AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (sin_helper, sin, T)
85- AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (cos_helper, cos, T)
87+ //AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(cos_helper, cos, T)
88+
89+ template<typename T> NBL_PARTIAL_REQ_TOP (is_same_v<decltype (spirv::cos<T>(experimental::declval<T>())), T>)
90+ struct cos_helper<T NBL_PARTIAL_REQ_BOT (is_same_v<decltype (spirv::cos<T>(experimental::declval<T>())), T>) >
91+ {
92+ static T __call (T arg)
93+ {
94+ return spirv::cos<T>(arg);
95+ }
96+ };
97+
8698AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (acos_helper, acos, T)
8799AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (abs_helper, sAbs, T)
88100AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (abs_helper, fAbs, T)
89101AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (sqrt_helper, sqrt, T)
90102AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (log_helper, log, T)
103+ AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (log2_helper, log2, T)
91104AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (exp2_helper, exp2, T)
92105AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (exp_helper, exp, T)
93106AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (floor_helper, floor, T)
@@ -108,7 +121,7 @@ struct pow_helper<T NBL_PARTIAL_REQ_BOT(always_true<decltype(spirv::pow<T>(exper
108121};
109122
110123template<typename T, typename U> NBL_PARTIAL_REQ_TOP (always_true<decltype (spirv::fMix<T>(experimental::declval<T>(), experimental::declval<T>(), experimental::declval<U>()))>)
111- struct lerp_helper <T, U NBL_PARTIAL_REQ_BOT (always_true<decltype (spirv::fMix<T>(experimental::declval<T>(), experimental::declval<T>(), experimental::declval<U>()))>) >
124+ struct mix_helper <T, U NBL_PARTIAL_REQ_BOT (always_true<decltype (spirv::fMix<T>(experimental::declval<T>(), experimental::declval<T>(), experimental::declval<U>()))>) >
112125{
113126 using return_t = conditional_t<is_vector_v<T>, vector <typename vector_traits<T>::scalar_type, vector_traits<T>::Dimension>, T>;
114127 static inline return_t __call (const T x, const T y, const U a)
@@ -150,8 +163,33 @@ struct modf_helper<T NBL_PARTIAL_REQ_BOT(concepts::FloatingPoint<T> && is_vector
150163 }
151164};
152165
166+ template<typename FloatingPoint>
167+ NBL_PARTIAL_REQ_TOP (concepts::FloatingPointScalar<FloatingPoint>)
168+ struct erf_helper<FloatingPoint NBL_PARTIAL_REQ_BOT (concepts::FloatingPointScalar<FloatingPoint>) >
169+ {
170+ static FloatingPoint __call (NBL_CONST_REF_ARG (FloatingPoint) _x)
171+ {
172+ const FloatingPoint a1 = 0.254829592 ;
173+ const FloatingPoint a2 = -0.284496736 ;
174+ const FloatingPoint a3 = 1.421413741 ;
175+ const FloatingPoint a4 = -1.453152027 ;
176+ const FloatingPoint a5 = 1.061405429 ;
177+ const FloatingPoint p = 0.3275911 ;
178+
179+ FloatingPoint sign = sign (_x);
180+ FloatingPoint x = abs (_x);
181+
182+ FloatingPoint t = 1.0 / (1.0 + p * x);
183+ FloatingPoint y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * exp (-x * x);
184+
185+ return sign * y;
186+ }
187+ };
188+
153189#else // C++ only specializations
154190
191+
// The macro below calls std::STD_FUNCTION_NAME without explicit template arguments,
// because not every standard math function used here is a template.
155193#define AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (HELPER_NAME, REQUIREMENT, STD_FUNCTION_NAME, RETURN_TYPE)\
156194template<typename T>\
157195requires REQUIREMENT \
@@ -170,6 +208,7 @@ AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER(acos_helper, concepts::FloatingPointScalar<T
170208AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (sqrt_helper, concepts::FloatingPointScalar<T>, sqrt, T)
171209AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (abs_helper, concepts::Scalar<T>, abs, T)
172210AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (log_helper, concepts::Scalar<T>, log, T)
211+ AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (log2_helper, concepts::FloatingPointScalar<T>, log2, T)
173212AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (exp2_helper, concepts::Scalar<T>, exp2, T)
174213AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (exp_helper, concepts::Scalar<T>, exp, T)
175214AUTO_SPECIALIZE_TRIVIAL_CASE_HELPER (floor_helper, concepts::FloatingPointScalar<T>, floor, T)
@@ -182,7 +221,7 @@ struct pow_helper<T>
182221 using return_t = T;
183222 static inline return_t __call (const T x, const T y)
184223 {
185- return std::pow (x, y);
224+ return std::pow<T> (x, y);
186225 }
187226};
188227
@@ -228,7 +267,7 @@ struct isnan_helper<T>
228267
229268template<typename T, typename U>
230269requires concepts::FloatingPoint<T> && (concepts::FloatingPoint<T> || concepts::Boolean<T>)
231- struct lerp_helper <T, U>
270+ struct mix_helper <T, U>
232271{
233272 using return_t = T;
234273 static inline return_t __call (const T x, const T y, const U a)
@@ -237,48 +276,29 @@ struct lerp_helper<T, U>
237276 }
238277};
239278
240- #endif
241-
242- // C++ and HLSL specializations
243-
244279template<typename FloatingPoint>
245280NBL_PARTIAL_REQ_TOP (concepts::FloatingPointScalar<FloatingPoint>)
246281struct erf_helper<FloatingPoint NBL_PARTIAL_REQ_BOT (concepts::FloatingPointScalar<FloatingPoint>) >
247282{
248- static FloatingPoint __call (NBL_CONST_REF_ARG (FloatingPoint) _x )
283+ static FloatingPoint __call (NBL_CONST_REF_ARG (FloatingPoint) x )
249284 {
250- #ifdef __HLSL_VERSION
251- const FloatingPoint a1 = 0.254829592 ;
252- const FloatingPoint a2 = -0.284496736 ;
253- const FloatingPoint a3 = 1.421413741 ;
254- const FloatingPoint a4 = -1.453152027 ;
255- const FloatingPoint a5 = 1.061405429 ;
256- const FloatingPoint p = 0.3275911 ;
285+ return std::erf<FloatingPoint>(x);
286+ }
287+ };
257288
258- FloatingPoint sign = sign (_x);
259- FloatingPoint x = abs (_x);
289+ #endif // C++ only specializations
260290
261- FloatingPoint t = 1.0 / (1.0 + p * x);
262- FloatingPoint y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * exp (-x * x);
291+ // C++ and HLSL specializations
263292
264- return sign * y;
265- #else
266- return std::erf (_x);
267- #endif
268- }
269- };
270293template<typename FloatingPoint>
271294NBL_PARTIAL_REQ_TOP (concepts::FloatingPointScalar<FloatingPoint>)
272295struct erfInv_helper<FloatingPoint NBL_PARTIAL_REQ_BOT (concepts::FloatingPointScalar<FloatingPoint>) >
273296{
274297 static FloatingPoint __call (NBL_CONST_REF_ARG (FloatingPoint) _x)
275298 {
276299 FloatingPoint x = clamp <FloatingPoint>(_x, -0.99999 , 0.99999 );
277- #ifdef __HLSL_VERSION
278- FloatingPoint w = -log ((1.0 - x) * (1.0 + x));
279- #else
280- FloatingPoint w = -std::log ((1.0 - x) * (1.0 + x));
281- #endif
300+
301+ FloatingPoint w = -log_helper<FloatingPoint>::__call ((1.0 - x) * (1.0 + x));
282302 FloatingPoint p;
283303 if (w < 5.0 )
284304 {
@@ -295,11 +315,7 @@ struct erfInv_helper<FloatingPoint NBL_PARTIAL_REQ_BOT(concepts::FloatingPointSc
295315 }
296316 else
297317 {
298- #ifdef __HLSL_VERSION
299- w = sqrt (w) - 3.0 ;
300- #else
301- w = std::sqrt (w) - 3.0 ;
302- #endif
318+ w = sqrt_helper<FloatingPoint>::__call (w) - 3.0 ;
303319 p = -0.000200214257 ;
304320 p = 0.000100950558 + p * w;
305321 p = 0.00134934322 + p * w;
@@ -345,6 +361,7 @@ struct HELPER_NAME<T NBL_PARTIAL_REQ_BOT(VECTOR_SPECIALIZATION_CONCEPT) >\
345361AUTO_SPECIALIZE_HELPER_FOR_VECTOR (sqrt_helper, T)
346362AUTO_SPECIALIZE_HELPER_FOR_VECTOR (abs_helper, T)
347363AUTO_SPECIALIZE_HELPER_FOR_VECTOR (log_helper, T)
364+ AUTO_SPECIALIZE_HELPER_FOR_VECTOR (log2_helper, T)
348365AUTO_SPECIALIZE_HELPER_FOR_VECTOR (exp2_helper, T)
349366AUTO_SPECIALIZE_HELPER_FOR_VECTOR (exp_helper, T)
350367AUTO_SPECIALIZE_HELPER_FOR_VECTOR (floor_helper, T)
0 commit comments