1212#include <numeric>
1313
1414#include <vector>
15- #include "s2fft_callbacks.h"
15+ #include "s2fft_kernels.h"
1616
1717namespace s2fft {
1818
@@ -81,15 +81,7 @@ HRESULT s2fftExec<Complex>::Initialize(const s2fftDescriptor &descriptor) {
8181 // Step 7e: Update overall maximum workspace size again.
8282 worksize = std::max (worksize, polar_worksize);
8383
84- // Step 7f: Allocate device memory for callback parameters and copy host parameters.
85- int64 params[2 ];
86- int64 *params_dev;
87- params[0 ] = n[0 ];
88- params[1 ] = idist;
89- cudaMalloc (&params_dev, 2 * sizeof (int64));
90- cudaMemcpy (params_dev, params, 2 * sizeof (int64), cudaMemcpyHostToDevice);
91-
92- // Step 7g: Store the created plans.
84+ // Step 7f: Store the created plans.
9385 m_polar_plans.push_back (plan);
9486 m_inverse_polar_plans.push_back (inverse_plan);
9587 }
@@ -117,34 +109,21 @@ HRESULT s2fftExec<Complex>::Initialize(const s2fftDescriptor &descriptor) {
117109 return S_OK;
118110}
119111
112+
120113template <typename Complex>
121114HRESULT s2fftExec<Complex>::Forward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data,
122- Complex *workspace, int64 *callback_params ) {
115+ Complex *workspace) {
123116 // Step 1: Determine the FFT direction (forward or inverse based on adjoint flag).
124117 const int DIRECTION = desc.adjoint ? CUFFT_INVERSE : CUFFT_FORWARD;
125118 // Step 2: Extract normalization and shift flags from the descriptor.
126119 const s2fftKernels::fft_norm &norm = desc.norm ;
127120 const bool &shift = desc.shift ;
128- const bool &isDouble = desc.double_precision ;
129121
130122 // Step 3: Execute FFTs for polar rings.
131123 for (int i = 0 ; i < m_nside - 1 ; i++) {
132124 // Step 3a: Get upper and lower ring offsets.
133125 int upper_ring_offset = m_upper_ring_offsets[i];
134- int lower_ring_offset = m_lower_ring_offsets[i];
135-
136- // Step 3b: Set parameters for the polar ring FFT callback.
137- int64 param_offset = 2 * i; // Offset for the parameters in the callback
138- int64 params[2 ];
139- params[0 ] = 4 * ((int64)i + 1 ); // Size of the ring
140- params[1 ] = lower_ring_offset - upper_ring_offset;
141126
142- // Step 3c: Copy callback parameters to device memory asynchronously.
143- int64 *params_device = callback_params + param_offset;
144- cudaMemcpyAsync (params_device, params, 2 * sizeof (int64), cudaMemcpyHostToDevice, stream);
145-
146- // Step 3d: Set the forward callback for the current polar plan.
147- s2fftKernels::setForwardCallback (m_polar_plans[i], params_device, shift, false , isDouble, norm);
148127 // Step 3e: Set the CUDA stream and work area for the cuFFT plan.
149128 CUFFT_CALL (cufftSetStream (m_polar_plans[i], stream));
150129 CUFFT_CALL (cufftSetWorkArea (m_polar_plans[i], workspace));
@@ -153,51 +132,49 @@ HRESULT s2fftExec<Complex>::Forward(const s2fftDescriptor &desc, cudaStream_t st
153132 cufftXtExec (m_polar_plans[i], data + upper_ring_offset, data + upper_ring_offset, DIRECTION));
154133 }
155134 // Step 4: Execute FFT for the equatorial ring.
156- // Step 4a: Set equator parameters for the callback.
157- int64 equator_size = (4 * m_nside);
158- int64 equator_offset = (m_nside - 1 ) * 2 ;
159- int64 *equator_params_device = callback_params + equator_offset;
160- // Step 4b: Copy equator parameters to device memory asynchronously.
161- cudaMemcpyAsync (equator_params_device, &equator_size, sizeof (int64), cudaMemcpyHostToDevice, stream);
162- // Step 4c: Set the forward callback for the equatorial plan.
163- s2fftKernels::setForwardCallback (m_equator_plan, equator_params_device, shift, true , isDouble, norm);
164135 // Step 4d: Set the CUDA stream and work area for the equatorial cuFFT plan.
165136 CUFFT_CALL (cufftSetStream (m_equator_plan, stream));
166137 CUFFT_CALL (cufftSetWorkArea (m_equator_plan, workspace));
167138 // Step 4e: Execute the cuFFT transform for the equator.
168139 CUFFT_CALL (cufftXtExec (m_equator_plan, data + m_equatorial_offset_start, data + m_equatorial_offset_start,
169140 DIRECTION));
170141
142+ // Step 5: Launch the custom kernel for normalization and shifting.
143+ switch (norm) {
144+ case s2fftKernels::fft_norm::NONE:
145+ case s2fftKernels::fft_norm::BACKWARD:
146+ // No normalization, only shift if required.
147+ s2fftKernels::launch_shift_normalize_kernel (stream, data, m_nside, shift, 2 );
148+ break ;
149+ case s2fftKernels::fft_norm::FORWARD:
150+ // Normalize by sqrt(Npix).
151+ std::cout << " Applying forward normalization." << std::endl;
152+ s2fftKernels::launch_shift_normalize_kernel (stream, data, m_nside, shift, 0 );
153+ break ;
154+ case s2fftKernels::fft_norm::ORTHO:
155+ // Normalize by Npix.
156+ s2fftKernels::launch_shift_normalize_kernel (stream, data, m_nside, shift, 1 );
157+ break ;
158+ default :
159+ return E_INVALIDARG; // Invalid normalization type.
160+ }
161+
162+
171163 return S_OK;
172164}
173165
174166template <typename Complex>
175167HRESULT s2fftExec<Complex>::Backward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data,
176- Complex *workspace, int64 *callback_params ) {
168+ Complex *workspace) {
177169 // Step 1: Determine the FFT direction (forward or inverse based on adjoint flag).
178170 const int DIRECTION = desc.adjoint ? CUFFT_FORWARD : CUFFT_INVERSE;
179171 // Step 2: Extract the normalization mode from the descriptor.
180172 const s2fftKernels::fft_norm &norm = desc.norm ;
181- const bool &shift = desc.shift ;
182- const bool &isDouble = desc.double_precision ;
183173
184174 // Step 3: Execute inverse FFTs for polar rings.
185175 for (int i = 0 ; i < m_nside - 1 ; i++) {
186176 // Step 3a: Get upper and lower ring offsets.
187177 int upper_ring_offset = m_upper_ring_offsets[i];
188- int lower_ring_offset = m_lower_ring_offsets[i];
189- // Step 3b: Set parameters for the polar ring inverse FFT callback.
190- int64 param_offset = 2 * i; // Offset for the parameters in the callback
191- int64 params[2 ];
192- params[0 ] = 4 * ((int64)i + 1 ); // Size of the ring
193- params[1 ] = lower_ring_offset - upper_ring_offset;
194-
195- // Step 3c: Copy callback parameters to device memory asynchronously.
196- int64 *params_device = callback_params + param_offset;
197- cudaMemcpyAsync (params_device, params, 2 * sizeof (int64), cudaMemcpyHostToDevice, stream);
198- // Step 3d: Set the backward callback for the current polar plan.
199- s2fftKernels::setBackwardCallback (m_inverse_polar_plans[i], params_device, shift, false , isDouble,
200- norm);
201178
202179 // Step 3e: Set the CUDA stream and work area for the cuFFT plan.
203180 CUFFT_CALL (cufftSetStream (m_inverse_polar_plans[i], stream));
@@ -207,22 +184,31 @@ HRESULT s2fftExec<Complex>::Backward(const s2fftDescriptor &desc, cudaStream_t s
207184 DIRECTION));
208185 }
209186 // Step 4: Execute inverse FFT for the equatorial ring.
210- // Step 4a: Set equator parameters for the callback.
211- int64 equator_size = (4 * m_nside);
212- int64 equator_offset = (m_nside - 1 ) * 2 ;
213- int64 *equator_params_device = callback_params + equator_offset;
214- // Step 4b: Copy equator parameters to device memory asynchronously.
215- cudaMemcpyAsync (equator_params_device, &equator_size, sizeof (int64), cudaMemcpyHostToDevice, stream);
216- // Step 4c: Set the backward callback for the equatorial plan.
217- s2fftKernels::setBackwardCallback (m_inverse_equator_plan, equator_params_device, shift, true , isDouble,
218- norm);
219187 // Step 4d: Set the CUDA stream and work area for the equatorial cuFFT plan.
220188 CUFFT_CALL (cufftSetStream (m_inverse_equator_plan, stream));
221189 CUFFT_CALL (cufftSetWorkArea (m_inverse_equator_plan, workspace));
222190 // Step 4e: Execute the cuFFT transform for the equator.
223191 CUFFT_CALL (cufftXtExec (m_inverse_equator_plan, data + m_equatorial_offset_start,
224192 data + m_equatorial_offset_start, DIRECTION));
225193
194+ // Step 5: Launch the custom kernel for normalization and shifting.
195+ switch (norm) {
196+ case s2fftKernels::fft_norm::NONE:
197+ case s2fftKernels::fft_norm::FORWARD:
198+ // No normalization, do nothing.
199+ break ;
200+ case s2fftKernels::fft_norm::BACKWARD:
201+ // Normalize by sqrt(Npix).
202+ s2fftKernels::launch_shift_normalize_kernel (stream, data, m_nside, false , 0 );
203+ break ;
204+ case s2fftKernels::fft_norm::ORTHO:
205+ // Normalize by Npix.
206+ s2fftKernels::launch_shift_normalize_kernel (stream, data, m_nside, false , 1 );
207+ break ;
208+ default :
209+ return E_INVALIDARG; // Invalid normalization type.
210+ }
211+
226212 return S_OK;
227213}
228214
0 commit comments