@@ -23,85 +23,135 @@ namespace fft
 {
 
 // ---------------------------------- Utils -----------------------------------------------
-template<typename SharedMemoryAdaptor, typename Scalar>
-struct exchangeValues;
 
-template<typename SharedMemoryAdaptor>
-struct exchangeValues<SharedMemoryAdaptor, float16_t>
+// No need to expose these
+namespace impl
 {
-    static void __call(NBL_REF_ARG(complex_t<float16_t>) lo, NBL_REF_ARG(complex_t<float16_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
+    template<typename SharedMemoryAdaptor, typename Scalar>
+    struct exchangeValues
     {
-        const bool topHalf = bool(threadID & stride);
-        // Pack two halves into a single uint32_t
-        uint32_t toExchange = bit_cast<uint32_t, float16_t2>(topHalf ? float16_t2(lo.real(), lo.imag()) : float16_t2(hi.real(), hi.imag()));
-        shuffleXor<SharedMemoryAdaptor, uint32_t>::__call(toExchange, stride, sharedmemAdaptor);
-        float16_t2 exchanged = bit_cast<float16_t2, uint32_t>(toExchange);
-        if (topHalf)
+        static void __call(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
         {
-            lo.real(exchanged.x);
-            lo.imag(exchanged.y);
+            const bool topHalf = bool(threadID & stride);
+            // Pack into float vector because ternary operator does not support structs
+            vector<Scalar, 2> exchanged = topHalf ? vector<Scalar, 2>(lo.real(), lo.imag()) : vector<Scalar, 2>(hi.real(), hi.imag());
+            shuffleXor<SharedMemoryAdaptor, vector<Scalar, 2> >(exchanged, stride, sharedmemAdaptor);
+            if (topHalf)
+            {
+                lo.real(exchanged.x);
+                lo.imag(exchanged.y);
+            }
+            else
+            {
+                hi.real(exchanged.x);
+                hi.imag(exchanged.y);
+            }
         }
-        else
-        {
-            hi.real(exchanged.x);
-            lo.imag(exchanged.y);
-        }
-    }
-};
+    };
 
-template<typename SharedMemoryAdaptor>
-struct exchangeValues<SharedMemoryAdaptor, float32_t>
-{
-    static void __call(NBL_REF_ARG(complex_t<float32_t>) lo, NBL_REF_ARG(complex_t<float32_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
+    template<uint16_t N, uint16_t H>
+    enable_if_t<(H <= N) && (N < 32), uint32_t> circularBitShiftRightHigher(uint32_t i)
     {
-        const bool topHalf = bool(threadID & stride);
-        // pack into `float32_t2` because ternary operator doesn't support structs
-        float32_t2 exchanged = topHalf ? float32_t2(lo.real(), lo.imag()) : float32_t2(hi.real(), hi.imag());
-        shuffleXor<SharedMemoryAdaptor, float32_t2>::__call(exchanged, stride, sharedmemAdaptor);
-        if (topHalf)
-        {
-            lo.real(exchanged.x);
-            lo.imag(exchanged.y);
-        }
-        else
-        {
-            hi.real(exchanged.x);
-            hi.imag(exchanged.y);
-        }
+        // Highest H bits are numbered N-1 through N - H
+        // N - H is then the middle bit
+        // Lowest bits numbered from 0 through N - H - 1
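+        // Net effect: the top H bits of i rotate right by one position while the low N - H bits stay put
+        // e.g. for N = 4, H = 3 and i = 0b0110, the top bits "011" rotate into "101", giving 0b1010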
+        NBL_CONSTEXPR_STATIC_INLINE uint32_t lowMask = (1 << (N - H)) - 1;
+        NBL_CONSTEXPR_STATIC_INLINE uint32_t midMask = 1 << (N - H);
+        NBL_CONSTEXPR_STATIC_INLINE uint32_t highMask = ~(lowMask | midMask);
+
+        uint32_t low = i & lowMask;
+        uint32_t mid = i & midMask;
+        uint32_t high = i & highMask;
+
+        high >>= 1;
+        mid <<= H - 1;
+
+        return mid | high | low;
     }
-};
 
-template<typename SharedMemoryAdaptor>
-struct exchangeValues<SharedMemoryAdaptor, float64_t>
-{
-    static void __call(NBL_REF_ARG(complex_t<float64_t>) lo, NBL_REF_ARG(complex_t<float64_t>) hi, uint32_t threadID, uint32_t stride, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
+    template<uint16_t N, uint16_t H>
+    enable_if_t<(H <= N) && (N < 32), uint32_t> circularBitShiftLeftHigher(uint32_t i)
     {
-        const bool topHalf = bool(threadID & stride);
-        // pack into `float64_t2` because ternary operator doesn't support structs
-        float64_t2 exchanged = topHalf ? float64_t2(lo.real(), lo.imag()) : float64_t2(hi.real(), hi.imag());
-        shuffleXor<SharedMemoryAdaptor, float64_t2>::__call(exchanged, stride, sharedmemAdaptor);
-        if (topHalf)
-        {
-            lo.real(exchanged.x);
-            lo.imag(exchanged.y);
-        }
-        else
-        {
-            hi.real(exchanged.x);
-            hi.imag(exchanged.y);
-        }
+        // Highest H bits are numbered N-1 through N - H
+        // N - 1 is then the highest bit, and N - 2 through N - H are the middle bits
+        // Lowest bits numbered from 0 through N - H - 1
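+        // Net effect: the top H bits of i rotate left by one position, exactly undoing circularBitShiftRightHigher
+        // e.g. for N = 4, H = 3 and i = 0b1010, the top bits "101" rotate into "011", giving back 0b0110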
+        NBL_CONSTEXPR_STATIC_INLINE uint32_t lowMask = (1 << (N - H)) - 1;
+        NBL_CONSTEXPR_STATIC_INLINE uint32_t highMask = 1 << (N - 1);
+        NBL_CONSTEXPR_STATIC_INLINE uint32_t midMask = ~(lowMask | highMask);
+
+        uint32_t low = i & lowMask;
+        uint32_t mid = i & midMask;
+        uint32_t high = i & highMask;
+
+        mid <<= 1;
+        high >>= H - 1;
+
+        return mid | high | low;
     }
-};
+} //namespace impl
 
 // Get the required size (in number of uint32_t elements) of the workgroup shared memory array needed for the FFT
-template <typename scalar_t, uint32_t WorkgroupSize>
+template <typename scalar_t, uint16_t WorkgroupSize>
 NBL_CONSTEXPR uint32_t SharedMemoryDWORDs = (sizeof(complex_t<scalar_t>) / sizeof(uint32_t)) * WorkgroupSize;
 
+// Util to unpack two values from the packed FFT of x + iy - outputs are returned in the same arguments, storing the DFT of x to lo and that of y to hi
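+// If z = x + iy with x, y real, then Z[k] = X[k] + iY[k] and conj(Z[N-k]) = X[k] - iY[k], which gives X[k] = (Z[k] + conj(Z[N-k])) * 0.5 and Y[k] = (Z[k] - conj(Z[N-k])) * 0.5 / i
+// `lo` and `hi` are therefore assumed to hold the packed FFT at a frequency and at its mirrored (negative) frequency, respectively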
+template<typename Scalar>
+void unpack(NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi)
+{
+    complex_t<Scalar> x = (lo + conj(hi)) * Scalar(0.5);
+    hi = rotateRight<Scalar>(lo - conj(hi)) * Scalar(0.5);
+    lo = x;
+}
+
+template<uint16_t ElementsPerInvocation, uint16_t WorkgroupSize>
+struct FFTIndexingUtils
+{
+    // This function maps the index `idx` in the output array of a Nabla FFT to the index `freqIdx` in the DFT such that `DFT[freqIdx] = NablaFFT[idx]`
+    // This is because Cooley-Tukey + subgroup operations end up spewing out the outputs in a weird order
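+    // For illustration: with ElementsPerInvocation = 2 and a toy WorkgroupSize = 4 (so FFTSize = 8), working the formula below through by hand
+    // shows that NablaFFT[0..7] holds DFT bins 0, 2, 1, 3, 4, 6, 5, 7 in that order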
+    static uint32_t getDFTIndex(uint32_t outputIdx)
+    {
+        return impl::circularBitShiftRightHigher<FFTSizeLog2, FFTSizeLog2 - ElementsPerInvocationLog2 + 1>(glsl::bitfieldReverse<uint32_t>(outputIdx) >> (32 - FFTSizeLog2));
+    }
+
+    // This function maps the index `freqIdx` in the DFT to the index `idx` in the output array of a Nabla FFT such that `DFT[freqIdx] = NablaFFT[idx]`
+    // It is essentially the inverse of `getDFTIndex`
+    static uint32_t getNablaIndex(uint32_t freqIdx)
+    {
+        return glsl::bitfieldReverse<uint32_t>(impl::circularBitShiftLeftHigher<FFTSizeLog2, FFTSizeLog2 - ElementsPerInvocationLog2 + 1>(freqIdx)) >> (32 - FFTSizeLog2);
+    }
+
+    // Mirrors an index about the Nyquist frequency in the DFT order
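+    // e.g. with FFTSize = 8, idx = 3 maps to 5, while idx = 0 (DC, its own mirror) maps to 0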
+    static uint32_t getDFTMirrorIndex(uint32_t idx)
+    {
+        return (FFTSize - idx) & (FFTSize - 1);
+    }
+
+    // Given an index `idx` of an element into the Nabla FFT, get the index into the Nabla FFT of the element corresponding to its negative frequency
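+    // Continuing the FFTSize = 8 illustration above: idx = 1 holds DFT bin 2, whose negative frequency is bin 6, which lives at Nabla index 5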
+    static uint32_t getNablaMirrorIndex(uint32_t idx)
+    {
+        return getNablaIndex(getDFTMirrorIndex(getDFTIndex(idx)));
+    }
+
+    NBL_CONSTEXPR_STATIC_INLINE uint16_t ElementsPerInvocationLog2 = mpl::log2<ElementsPerInvocation>::value;
+    NBL_CONSTEXPR_STATIC_INLINE uint16_t FFTSizeLog2 = ElementsPerInvocationLog2 + mpl::log2<WorkgroupSize>::value;
+    NBL_CONSTEXPR_STATIC_INLINE uint32_t FFTSize = uint32_t(WorkgroupSize) * uint32_t(ElementsPerInvocation);
+};
+
 } //namespace fft
 
 // ----------------------------------- End Utils -----------------------------------------------
 
-template<uint16_t ElementsPerInvocation, bool Inverse, uint32_t WorkgroupSize, typename Scalar, class device_capabilities=void>
+template<uint16_t ElementsPerInvocation, bool Inverse, uint16_t WorkgroupSize, typename Scalar, class device_capabilities=void>
 struct FFT;
 
 // For the FFT methods below, we assume:
@@ -121,13 +171,13 @@ struct FFT;
 // * void workgroupExecutionAndMemoryBarrier();
 
 // 2 items per invocation forward specialization
-template<uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
+template<uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
 struct FFT<2,false, WorkgroupSize, Scalar, device_capabilities>
 {
     template<typename SharedMemoryAdaptor>
     static void FFT_loop(uint32_t stride, NBL_REF_ARG(complex_t<Scalar>) lo, NBL_REF_ARG(complex_t<Scalar>) hi, uint32_t threadID, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)
     {
-        fft::exchangeValues<SharedMemoryAdaptor, Scalar>::__call(lo, hi, threadID, stride, sharedmemAdaptor);
+        fft::impl::exchangeValues<SharedMemoryAdaptor, Scalar>::__call(lo, hi, threadID, stride, sharedmemAdaptor);
 
         // Get twiddle with k = threadID mod stride, halfN = stride
         hlsl::fft::DIF<Scalar>::radix2(hlsl::fft::twiddle<false, Scalar>(threadID & (stride - 1), stride), lo, hi);
@@ -167,7 +217,7 @@ struct FFT<2,false, WorkgroupSize, Scalar, device_capabilities>
         }
 
         // special last workgroup-shuffle
-        fft::exchangeValues<adaptor_t, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
+        fft::impl::exchangeValues<adaptor_t, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
 
         // Remember to update the accessor's state
         sharedmemAccessor = sharedmemAdaptor.accessor;
@@ -185,7 +235,7 @@ struct FFT<2,false, WorkgroupSize, Scalar, device_capabilities>
 
 
 // 2 items per invocation inverse specialization
-template<uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
+template<uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
 struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
 {
     template<typename SharedMemoryAdaptor>
@@ -194,7 +244,7 @@ struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
         // Get twiddle with k = threadID mod stride, halfN = stride
         hlsl::fft::DIT<Scalar>::radix2(hlsl::fft::twiddle<true, Scalar>(threadID & (stride - 1), stride), lo, hi);
 
-        fft::exchangeValues<SharedMemoryAdaptor, Scalar>::__call(lo, hi, threadID, stride, sharedmemAdaptor);
+        fft::impl::exchangeValues<SharedMemoryAdaptor, Scalar>::__call(lo, hi, threadID, stride, sharedmemAdaptor);
     }
 
 
@@ -223,7 +273,7 @@ struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
         sharedmemAdaptor.accessor = sharedmemAccessor;
 
         // special first workgroup-shuffle
-        fft::exchangeValues<adaptor_t, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
+        fft::impl::exchangeValues<adaptor_t, Scalar>::__call(lo, hi, threadID, glsl::gl_SubgroupSize(), sharedmemAdaptor);
 
         // The bigger steps
         [unroll]
@@ -251,7 +301,7 @@ struct FFT<2,true, WorkgroupSize, Scalar, device_capabilities>
 };
 
 // Forward FFT
-template<uint32_t K, uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
+template<uint32_t K, uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
 struct FFT<K, false, WorkgroupSize, Scalar, device_capabilities>
 {
     template<typename Accessor, typename SharedMemoryAccessor>
@@ -294,7 +344,7 @@ struct FFT<K, false, WorkgroupSize, Scalar, device_capabilities>
 };
 
 // Inverse FFT
-template<uint32_t K, uint32_t WorkgroupSize, typename Scalar, class device_capabilities>
+template<uint32_t K, uint16_t WorkgroupSize, typename Scalar, class device_capabilities>
 struct FFT<K, true, WorkgroupSize, Scalar, device_capabilities>
 {
     template<typename Accessor, typename SharedMemoryAccessor>