1+ #ifndef _NBL_HLSL_FORMAT_SHARED_EXP_HLSL_
2+ #define _NBL_HLSL_FORMAT_SHARED_EXP_HLSL_
3+
4+ #include "nbl/builtin/hlsl/cpp_compat.hlsl"
5+ #include "nbl/builtin/hlsl/type_traits.hlsl"
6+ #include "nbl/builtin/hlsl/limits.hlsl"
7+
8+ namespace nbl
9+ {
10+ namespace hlsl
11+ {
12+
13+ namespace format
14+ {
15+
// Packed shared-exponent value: `_Components` mantissas sharing one `_ExponentBits` wide exponent,
// kept as raw bits in `storage` (the unsigned counterpart of `IntT`).
// Conversions to and from float vectors are provided by the `_static_cast_helper` specializations in this file.
template<typename IntT, uint16_t _Components, uint16_t _ExponentBits>
struct shared_exp// : enable_if_t<_ExponentBits<16> need a way to static_assert in SPIRV!
{
    using this_t = shared_exp<IntT,_Components,_ExponentBits>;
    using storage_t = typename make_unsigned<IntT>::type;
    NBL_CONSTEXPR_STATIC_INLINE uint16_t Components = _Components;
    NBL_CONSTEXPR_STATIC_INLINE uint16_t ExponentBits = _ExponentBits;

    // Not even going to consider fp16 and fp64 dependence on device traits
    using decode_t = float32_t;

    // Bitwise equality of the packed representation (no decoding is performed)
    bool operator==(const this_t other)
    {
        return storage==other.storage;
    }
    // Bitwise inequality of the packed representation, always the negation of `operator==`
    bool operator!=(const this_t other)
    {
        return storage!=other.storage;
    }

    // the raw packed bits
    storage_t storage;
};
38+
39+ // all of this because DXC has bugs in partial template spec
40+ namespace impl
41+ {
// Backing implementation of `numeric_limits` for `format::shared_exp<IntT,_Components,_ExponentBits>`.
// `value_type` is the type the format decodes to, `__storage_t` is the unsigned integer the bits are packed into.
template<typename IntT, uint16_t _Components, uint16_t _ExponentBits>
struct numeric_limits_shared_exp
{
    using type = format::shared_exp<IntT,_Components,_ExponentBits>;
    using value_type = typename type::decode_t;
    using __storage_t = typename type::storage_t;

    NBL_CONSTEXPR_STATIC_INLINE bool is_specialized = true;
    // signedness of the format follows the signedness of `IntT`
    NBL_CONSTEXPR_STATIC_INLINE bool is_signed = is_signed_v<IntT>;
    NBL_CONSTEXPR_STATIC_INLINE bool is_integer = false;
    NBL_CONSTEXPR_STATIC_INLINE bool is_exact = false;

    // shared exponent formats can represent neither infinity nor NaN
    NBL_CONSTEXPR_STATIC_INLINE bool has_infinity = false;
    NBL_CONSTEXPR_STATIC_INLINE bool has_quiet_NaN = false;
    NBL_CONSTEXPR_STATIC_INLINE bool has_signaling_NaN = false;

    // there is no implicit leading 1 in the mantissa, so "denormalized" is not really a concept here
    // (one could equally argue that every value is a denorm)
    NBL_CONSTEXPR_STATIC_INLINE bool has_denorm = false;
    NBL_CONSTEXPR_STATIC_INLINE bool has_denorm_loss = false;

    // rounding is truncation, `round_style` is intentionally not declared:
    // NBL_CONSTEXPR_STATIC_INLINE float_round_style round_style = round_to_nearest;
    NBL_CONSTEXPR_STATIC_INLINE bool is_iec559 = false;
    NBL_CONSTEXPR_STATIC_INLINE bool is_bounded = true;
    NBL_CONSTEXPR_STATIC_INLINE bool is_modulo = false;

    // mantissa bits per component: what is left of `IntT` after removing the exponent
    // and (for signed formats) one bit per component, divided evenly between the components
    NBL_CONSTEXPR_STATIC_INLINE int32_t digits = (sizeof(IntT)*8-(is_signed ? _Components:0)-_ExponentBits)/_Components;
    NBL_CONSTEXPR_STATIC_INLINE int32_t radix = 2;
    NBL_CONSTEXPR_STATIC_INLINE int32_t max_exponent = 1<<(_ExponentBits-1);
    NBL_CONSTEXPR_STATIC_INLINE int32_t min_exponent = 1-max_exponent;
    NBL_CONSTEXPR_STATIC_INLINE bool traps = false;

    // extras: masks selecting a single mantissa / the exponent once shifted down to bit 0
    NBL_CONSTEXPR_STATIC_INLINE __storage_t MantissaMask = ((__storage_t(1))<<digits)-__storage_t(1);
    NBL_CONSTEXPR_STATIC_INLINE uint16_t ExponentBits = _ExponentBits;
    NBL_CONSTEXPR_STATIC_INLINE uint16_t ExponentMask = uint16_t((1<<_ExponentBits)-1);

    // TODO: functions done as vars
    // not provided yet:
    // NBL_CONSTEXPR_STATIC_INLINE value_type min = base::min();

    // `max` is assembled from an exponent field and a mantissa field positioned using `numeric_limits<value_type>`
    // (NOTE(review): presumably the bit pattern of the largest value as a `value_type` - confirm).
    // The mantissa is shifted down by 1 to drop the explicit leading 1 (implicit in `value_type`),
    // the exponent gets +1 to compensate.
    NBL_CONSTEXPR_STATIC_INLINE __storage_t max =
        ((max_exponent+1-numeric_limits<value_type>::min_exponent)<<(numeric_limits<value_type>::digits-1))|
        ((MantissaMask>>1)<<(numeric_limits<value_type>::digits-digits));
    // signed formats: `max` with the top bit of `__storage_t` set, unsigned formats: zero
    NBL_CONSTEXPR_STATIC_INLINE __storage_t lowest = is_signed ? ((__storage_t(1)<<(sizeof(__storage_t)*8-1))|max):__storage_t(0);
    /* not provided yet:
    NBL_CONSTEXPR_STATIC_INLINE value_type epsilon = base::epsilon();
    NBL_CONSTEXPR_STATIC_INLINE value_type round_error = base::round_error();
    */
};
88+ }
89+
90+ }
91+
// `numeric_limits` for shared exponent formats, every member is inherited from `format::impl::numeric_limits_shared_exp`
template<typename IntT, uint16_t _Components, uint16_t _ExponentBits>
struct numeric_limits<format::shared_exp<IntT,_Components,_ExponentBits> >
    : format::impl::numeric_limits_shared_exp<IntT,_Components,_ExponentBits>
{};
97+
98+ namespace impl
99+ {
// TODO: remove after the `emulated_float` merge
// Declaration only: `To` is the result type of the cast, `From` the argument type.
template<typename To, typename From>
struct _static_cast_helper;
103+
104+ // TODO: versions for `float16_t`
105+
// decode: unpacks a `shared_exp` into a vector of `_Components` values of `decode_t`
template<typename IntT, uint16_t _Components, uint16_t _ExponentBits>
struct _static_cast_helper<
    vector<typename format::shared_exp<IntT,_Components,_ExponentBits>::decode_t,_Components>,
    format::shared_exp<IntT,_Components,_ExponentBits>
>
{
    using U = format::shared_exp<IntT,_Components,_ExponentBits>;
    using T = vector<typename U::decode_t,_Components>;

    // Returns `mantissa[i]*2^(exponent-digits+min_exponent)` for every component, negated when its sign bit is set (signed formats only)
    T operator()(U val)
    {
        using storage_t = typename U::storage_t;
        // DXC error: error: expression class 'DependentScopeDeclRefExpr' unimplemented, doesn't matter as decode_t is always float32_t for now
        //using decode_t = typename T::decode_t;
        using decode_t = float32_t;
        // no clue why the compiler doesn't pick up the partial specialization and tries to use the general one
        using limits_t = format::impl::numeric_limits_shared_exp<IntT,_Components,_ExponentBits>;

        // mantissas are packed back to back starting at bit 0, `limits_t::digits` bits each
        T retval;
        for (uint16_t i=0; i<_Components; i++)
            retval[i] = decode_t((val.storage>>storage_t(limits_t::digits*i))&limits_t::MantissaMask);

        // everything above ALL the mantissas is the exponent (plus sign bits for signed formats),
        // so the shift has to scale with the component count instead of assuming 3 components
        uint16_t exponent = uint16_t(val.storage>>storage_t(limits_t::digits*_Components));
        if (limits_t::is_signed)
        {
            // sign bit of component `i` sits `i` bits above the exponent field
            for (uint16_t i=0; i<_Components; i++)
            if (exponent&(uint16_t(1)<<(_ExponentBits+i)))
                retval[i] = -retval[i];
            // strip the sign bits to leave only the shared exponent
            exponent &= limits_t::ExponentMask;
        }
        // mantissas have no implicit leading 1 and are integers, hence the extra `-digits` in the scale
        return retval*exp2(int32_t(exponent-limits_t::digits)+limits_t::min_exponent);
    }
};
// encode (WARNING DOES NOT CHECK THAT INPUT IS IN THE RANGE!)
// Packs a vector of `_Components` values of `decode_t` into a `shared_exp`, mantissas are truncated (no rounding).
template<typename IntT, uint16_t _Components, uint16_t _ExponentBits>
struct _static_cast_helper<
    format::shared_exp<IntT,_Components,_ExponentBits>,
    vector<typename format::shared_exp<IntT,_Components,_ExponentBits>::decode_t,_Components>
>
{
    using T = format::shared_exp<IntT,_Components,_ExponentBits>;
    using U = vector<typename T::decode_t,_Components>;

    // Produces the bit layout read back by the decode helper: mantissa `i` at bit `digits*i`,
    // the shared exponent right above the last mantissa and, for signed formats, sign bit `i` at `i` bits above the exponent.
    // NOTE: assumes `limits_t::digits <= numeric_limits<decode_t>::digits`, mantissas are only ever shifted down.
    T operator()(U val)
    {
        using storage_t = typename T::storage_t;
        // DXC error: error: expression class 'DependentScopeDeclRefExpr' unimplemented, doesn't matter as decode_t is always float32_t for now
        //using decode_t = typename T::decode_t;
        using decode_t = float32_t;
        // unsigned integer of the same size as `decode_t`, used for all the bit twiddling
        using decode_bits_t = unsigned_integer_of_size<sizeof(decode_t)>::type;
        // no clue why the compiler doesn't pick up the partial specialization and tries to use the general one
        using limits_t = format::impl::numeric_limits_shared_exp<IntT,_Components,_ExponentBits>;

        // get exponents (still with the bias of `decode_t`)
        vector<uint16_t,_Components> exponentsDecBias;
        const int32_t dec_MantissaStoredBits = numeric_limits<decode_t>::digits-1;
        for (uint16_t i=0; i<_Components; i++)
        {
            decode_t v = val[i];
            // get rid of the sign bit so it doesn't end up in the exponent
            if (limits_t::is_signed)
                v = abs(v);
            exponentsDecBias[i] = uint16_t(asuint(v)>>dec_MantissaStoredBits);
        }

        // get the maximum exponent
        uint16_t sharedExponentDecBias = exponentsDecBias[0];
        for (uint16_t i=1; i<_Components; i++)
            sharedExponentDecBias = max(exponentsDecBias[i],sharedExponentDecBias);

        // NOTE: we don't consider clamping against `limits_t::max_exponent`, should be ensured by clamping the inputs against `limits_t::max` before casting!

        // Biased `decode_t` exponent which maps to a stored shared exponent of 0. The +1 is there because
        // the lowest biased exponent of `decode_t` is reserved for denorms, so `min_exponent` corresponds to a biased exponent of 1.
        const int32_t minSharedExponentDecBias = limits_t::min_exponent-numeric_limits<decode_t>::min_exponent+1;

        // we need to stop "shifting up" implicit leading 1. to farthest left position if the exponent too small,
        // otherwise the re-biased exponent below would go negative and wrap around
        uint16_t clampedSharedExponentDecBias = sharedExponentDecBias;
        if (minSharedExponentDecBias>0) // if ofc its needed at all
            clampedSharedExponentDecBias = max(sharedExponentDecBias,uint16_t(minSharedExponentDecBias));

        // we always shift down, the question is how much:
        // the component owning the shared exponent only loses the bits that don't fit in the narrower mantissa,
        // every other component additionally shifts by its exponent difference (capped, at which point everything is shifted out)
        const uint16_t baseMantissaShift = uint16_t(numeric_limits<decode_t>::digits-limits_t::digits);
        vector<uint16_t,_Components> mantissaShifts;
        for (uint16_t i=0; i<_Components; i++)
            mantissaShifts[i] = min(uint16_t(clampedSharedExponentDecBias+baseMantissaShift-exponentsDecBias[i]),uint16_t(numeric_limits<decode_t>::digits));

        // finally lets re-bias our exponent (it will always be non-negative thanks to the clamp above)
        const uint16_t sharedExponentEncBias = uint16_t(int32_t(clampedSharedExponentDecBias)-minSharedExponentDecBias);

        // exponent goes right above ALL the mantissas, so its position depends on the component count
        T retval;
        retval.storage = storage_t(sharedExponentEncBias)<<(limits_t::digits*_Components);
        const decode_bits_t dec_MantissaMask = (decode_bits_t(1)<<dec_MantissaStoredBits)-1;
        for (uint16_t i=0; i<_Components; i++)
        {
            decode_bits_t origBitPattern = bit_cast<decode_bits_t>(val[i])&dec_MantissaMask;
            // put the implicit 1 in (don't care about denormalized because its probably less than our `limits_t::min` (TODO: static assert it)
            origBitPattern |= decode_bits_t(1)<<dec_MantissaStoredBits;
            // shift and put in the right place
            retval.storage |= storage_t(origBitPattern>>mantissaShifts[i])<<(limits_t::digits*i);
        }
        if (limits_t::is_signed)
        {
            // doing ops on smaller integers is faster, so gather the signs first: component `i` goes to bit `i`
            decode_bits_t signs = decode_bits_t(0);
            for (uint16_t i=0; i<_Components; i++)
                signs |= (bit_cast<decode_bits_t>(val[i])>>(sizeof(decode_t)*8-1))<<i;
            // sign bits live directly above the exponent field, in component order, which is where the decoder looks for them
            retval.storage |= storage_t(signs)<<(limits_t::digits*_Components+_ExponentBits);
        }
        return retval;
    }
};
217+ }
218+ }
219+ }
220+ #endif
0 commit comments