Skip to content

Commit 1ac3541

Browse files
committed
removubg s2fft callbacks
1 parent d29af9b commit 1ac3541

File tree

6 files changed

+221
-132
lines changed

6 files changed

+221
-132
lines changed

CMakeLists.txt

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,15 @@ if(CMAKE_CUDA_COMPILER)
5353
STABLE_ABI
5454
${CMAKE_CURRENT_LIST_DIR}/lib/src/extensions.cc
5555
${CMAKE_CURRENT_LIST_DIR}/lib/src/s2fft.cu
56-
${CMAKE_CURRENT_LIST_DIR}/lib/src/s2fft_callbacks.cu
5756
${CMAKE_CURRENT_LIST_DIR}/lib/src/plan_cache.cc
5857
${CMAKE_CURRENT_LIST_DIR}/lib/src/s2fft_kernels.cu)
5958

60-
target_link_libraries(_s2fft PRIVATE CUDA::cudart_static CUDA::cufft_static
61-
CUDA::culibos)
59+
target_link_libraries(_s2fft PRIVATE CUDA::cudart_static CUDA::cufft_static CUDA::culibos)
6260
target_include_directories(
63-
_s2fft PUBLIC ${CMAKE_CURRENT_LIST_DIR}/lib/include ${XLA_DIR})
61+
_s2fft PUBLIC ${CMAKE_CURRENT_LIST_DIR}/lib/include ${XLA_DIR} ${CUDAToolkit_INCLUDE_DIRS})
6462
set_target_properties(_s2fft PROPERTIES LINKER_LANGUAGE CUDA
6563
CUDA_SEPARABLE_COMPILATION ON)
64+
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -rdc=true")
6665
set(CMAKE_CUDA_ARCHITECTURES
6766
"70;80;89"
6867
CACHE STRING "List of CUDA compute capabilities to build cuDecomp for.")
@@ -85,7 +84,7 @@ else()
8584
# Add the executable
8685
execute_process(
8786
COMMAND "${Python_EXECUTABLE}" "-c"
88-
"from jax.extend import ffi; print(ffi.include_dir())"
87+
"from jax import ffi; print(ffi.include_dir())"
8988
OUTPUT_STRIP_TRAILING_WHITESPACE
9089
OUTPUT_VARIABLE XLA_DIR)
9190
message(STATUS "XLA include directory: ${XLA_DIR}")

lib/include/s2fft.h

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
#include "cufft.h"
1515
#include "cufftXt.h"
1616
#include "thrust/device_vector.h"
17-
#include "s2fft_callbacks.h"
17+
#include "s2fft_kernels.h"
18+
1819

1920
namespace s2fft {
2021

@@ -168,11 +169,9 @@ class s2fftExec {
168169
* @param stream The CUDA stream to use for execution.
169170
* @param data Pointer to the input/output data on the device.
170171
* @param workspace Pointer to the workspace memory on the device.
171-
* @param callback_params Pointer to device memory containing callback parameters.
172172
* @return HRESULT indicating success or failure.
173173
*/
174-
HRESULT Forward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data, Complex *workspace,
175-
int64 *callback_params);
174+
HRESULT Forward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data, Complex *workspace);
176175

177176
/**
178177
* @brief Executes the backward Spherical Harmonic Transform.
@@ -184,11 +183,9 @@ class s2fftExec {
184183
* @param stream The CUDA stream to use for execution.
185184
* @param data Pointer to the input/output data on the device.
186185
* @param workspace Pointer to the workspace memory on the device.
187-
* @param callback_params Pointer to device memory containing callback parameters.
188186
* @return HRESULT indicating success or failure.
189187
*/
190-
HRESULT Backward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data, Complex *workspace,
191-
int64 *callback_params);
188+
HRESULT Backward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data, Complex *workspace);
192189

193190
public:
194191
// cuFFT handles for polar and equatorial FFT plans

lib/include/s2fft_kernels.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,29 @@ typedef long long int int64;
1111

1212
namespace s2fftKernels {
1313

14+
enum fft_norm {
15+
FORWARD = 1,
16+
BACKWARD = 2,
17+
ORTHO = 3,
18+
NONE = 4
19+
};
20+
1421
template <typename complex>
1522
HRESULT launch_spectral_folding(complex* data, complex* output, const int& nside, const int& L,
1623
const bool& shift, cudaStream_t stream);
1724
template <typename complex>
1825
HRESULT launch_spectral_extension(complex* data, complex* output, const int& nside, const int& L,
1926
cudaStream_t stream);
27+
28+
template <typename complex>
29+
HRESULT launch_shift_normalize_kernel(
30+
cudaStream_t stream,
31+
complex* data, // In-place data buffer
32+
int nside,
33+
bool apply_shift,
34+
int norm
35+
);
36+
2037
} // namespace s2fftKernels
2138

2239
#endif // _S2FFT_KERNELS_H

lib/src/extensions.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ ffi::Error healpix_forward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resul
111111
reinterpret_cast<int64*>(callback_params->typed_data() + i * params_offset);
112112

113113
// Step 2g: Launch the forward transform on this sub-stream.
114-
executor->Forward(descriptor, sub_stream, data_c, workspace_c, callback_params_c);
114+
executor->Forward(descriptor, sub_stream, data_c, workspace_c);
115115
// Step 2h: Launch spectral extension kernel.
116116
s2fftKernels::launch_spectral_extension(data_c, out_c, descriptor.nside,
117117
descriptor.harmonic_band_limit, sub_stream);
@@ -131,7 +131,7 @@ ffi::Error healpix_forward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resul
131131
auto executor = std::make_shared<s2fftExec<fft_complex_type>>();
132132
PlanCache::GetInstance().GetS2FFTExec(descriptor, executor);
133133
// Step 2m: Launch the forward transform.
134-
executor->Forward(descriptor, stream, data_c, workspace_c, callback_params_c);
134+
executor->Forward(descriptor, stream, data_c, workspace_c);
135135
// Step 2n: Launch spectral extension kernel.
136136
s2fftKernels::launch_spectral_extension(data_c, out_c, descriptor.nside,
137137
descriptor.harmonic_band_limit, stream);
@@ -205,7 +205,7 @@ ffi::Error healpix_backward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resu
205205
descriptor.harmonic_band_limit, descriptor.shift,
206206
sub_stream);
207207
// Step 2h: Launch the backward transform on this sub-stream.
208-
executor->Backward(descriptor, sub_stream, out_c, workspace_c, callback_params_c);
208+
executor->Backward(descriptor, sub_stream, out_c, workspace_c);
209209
}
210210
// Step 2i: Join all forked streams back to the main stream.
211211
handler.join(stream);
@@ -228,7 +228,7 @@ ffi::Error healpix_backward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resu
228228
s2fftKernels::launch_spectral_folding(data_c, out_c, descriptor.nside, descriptor.harmonic_band_limit,
229229
descriptor.shift, stream);
230230
// Step 2n: Launch the backward transform.
231-
executor->Backward(descriptor, stream, out_c, workspace_c, callback_params_c);
231+
executor->Backward(descriptor, stream, out_c, workspace_c);
232232
return ffi::Error::Success();
233233
}
234234
}

lib/src/s2fft.cu

Lines changed: 44 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#include <numeric>
1313

1414
#include <vector>
15-
#include "s2fft_callbacks.h"
15+
#include "s2fft_kernels.h"
1616

1717
namespace s2fft {
1818

@@ -81,15 +81,7 @@ HRESULT s2fftExec<Complex>::Initialize(const s2fftDescriptor &descriptor) {
8181
// Step 7e: Update overall maximum workspace size again.
8282
worksize = std::max(worksize, polar_worksize);
8383

84-
// Step 7f: Allocate device memory for callback parameters and copy host parameters.
85-
int64 params[2];
86-
int64 *params_dev;
87-
params[0] = n[0];
88-
params[1] = idist;
89-
cudaMalloc(&params_dev, 2 * sizeof(int64));
90-
cudaMemcpy(params_dev, params, 2 * sizeof(int64), cudaMemcpyHostToDevice);
91-
92-
// Step 7g: Store the created plans.
84+
// Step 7f: Store the created plans.
9385
m_polar_plans.push_back(plan);
9486
m_inverse_polar_plans.push_back(inverse_plan);
9587
}
@@ -117,34 +109,21 @@ HRESULT s2fftExec<Complex>::Initialize(const s2fftDescriptor &descriptor) {
117109
return S_OK;
118110
}
119111

112+
120113
template <typename Complex>
121114
HRESULT s2fftExec<Complex>::Forward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data,
122-
Complex *workspace, int64 *callback_params) {
115+
Complex *workspace) {
123116
// Step 1: Determine the FFT direction (forward or inverse based on adjoint flag).
124117
const int DIRECTION = desc.adjoint ? CUFFT_INVERSE : CUFFT_FORWARD;
125118
// Step 2: Extract normalization, shift, and double precision flags from the descriptor.
126119
const s2fftKernels::fft_norm &norm = desc.norm;
127120
const bool &shift = desc.shift;
128-
const bool &isDouble = desc.double_precision;
129121

130122
// Step 3: Execute FFTs for polar rings.
131123
for (int i = 0; i < m_nside - 1; i++) {
132124
// Step 3a: Get upper and lower ring offsets.
133125
int upper_ring_offset = m_upper_ring_offsets[i];
134-
int lower_ring_offset = m_lower_ring_offsets[i];
135-
136-
// Step 3b: Set parameters for the polar ring FFT callback.
137-
int64 param_offset = 2 * i; // Offset for the parameters in the callback
138-
int64 params[2];
139-
params[0] = 4 * ((int64)i + 1); // Size of the ring
140-
params[1] = lower_ring_offset - upper_ring_offset;
141126

142-
// Step 3c: Copy callback parameters to device memory asynchronously.
143-
int64 *params_device = callback_params + param_offset;
144-
cudaMemcpyAsync(params_device, params, 2 * sizeof(int64), cudaMemcpyHostToDevice, stream);
145-
146-
// Step 3d: Set the forward callback for the current polar plan.
147-
s2fftKernels::setForwardCallback(m_polar_plans[i], params_device, shift, false, isDouble, norm);
148127
// Step 3e: Set the CUDA stream and work area for the cuFFT plan.
149128
CUFFT_CALL(cufftSetStream(m_polar_plans[i], stream));
150129
CUFFT_CALL(cufftSetWorkArea(m_polar_plans[i], workspace));
@@ -153,51 +132,49 @@ HRESULT s2fftExec<Complex>::Forward(const s2fftDescriptor &desc, cudaStream_t st
153132
cufftXtExec(m_polar_plans[i], data + upper_ring_offset, data + upper_ring_offset, DIRECTION));
154133
}
155134
// Step 4: Execute FFT for the equatorial ring.
156-
// Step 4a: Set equator parameters for the callback.
157-
int64 equator_size = (4 * m_nside);
158-
int64 equator_offset = (m_nside - 1) * 2;
159-
int64 *equator_params_device = callback_params + equator_offset;
160-
// Step 4b: Copy equator parameters to device memory asynchronously.
161-
cudaMemcpyAsync(equator_params_device, &equator_size, sizeof(int64), cudaMemcpyHostToDevice, stream);
162-
// Step 4c: Set the forward callback for the equatorial plan.
163-
s2fftKernels::setForwardCallback(m_equator_plan, equator_params_device, shift, true, isDouble, norm);
164135
// Step 4d: Set the CUDA stream and work area for the equatorial cuFFT plan.
165136
CUFFT_CALL(cufftSetStream(m_equator_plan, stream));
166137
CUFFT_CALL(cufftSetWorkArea(m_equator_plan, workspace));
167138
// Step 4e: Execute the cuFFT transform for the equator.
168139
CUFFT_CALL(cufftXtExec(m_equator_plan, data + m_equatorial_offset_start, data + m_equatorial_offset_start,
169140
DIRECTION));
170141

142+
// Step 5: Launch the custom kernel for normalization and shifting.
143+
switch (norm) {
144+
case s2fftKernels::fft_norm::NONE:
145+
case s2fftKernels::fft_norm::BACKWARD:
146+
// No normalization, only shift if required.
147+
s2fftKernels::launch_shift_normalize_kernel(stream, data, m_nside, shift, 2);
148+
break;
149+
case s2fftKernels::fft_norm::FORWARD:
150+
// Normalize by sqrt(Npix).
151+
std::cout << "Applying forward normalization." << std::endl;
152+
s2fftKernels::launch_shift_normalize_kernel(stream, data, m_nside, shift, 0);
153+
break;
154+
case s2fftKernels::fft_norm::ORTHO:
155+
// Normalize by Npix.
156+
s2fftKernels::launch_shift_normalize_kernel(stream, data, m_nside, shift, 1);
157+
break;
158+
default:
159+
return E_INVALIDARG; // Invalid normalization type.
160+
}
161+
162+
171163
return S_OK;
172164
}
173165

174166
template <typename Complex>
175167
HRESULT s2fftExec<Complex>::Backward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data,
176-
Complex *workspace, int64 *callback_params) {
168+
Complex *workspace) {
177169
// Step 1: Determine the FFT direction (forward or inverse based on adjoint flag).
178170
const int DIRECTION = desc.adjoint ? CUFFT_FORWARD : CUFFT_INVERSE;
179171
// Step 2: Extract normalization, shift, and double precision flags from the descriptor.
180172
const s2fftKernels::fft_norm &norm = desc.norm;
181-
const bool &shift = desc.shift;
182-
const bool &isDouble = desc.double_precision;
183173

184174
// Step 3: Execute inverse FFTs for polar rings.
185175
for (int i = 0; i < m_nside - 1; i++) {
186176
// Step 3a: Get upper and lower ring offsets.
187177
int upper_ring_offset = m_upper_ring_offsets[i];
188-
int lower_ring_offset = m_lower_ring_offsets[i];
189-
// Step 3b: Set parameters for the polar ring inverse FFT callback.
190-
int64 param_offset = 2 * i; // Offset for the parameters in the callback
191-
int64 params[2];
192-
params[0] = 4 * ((int64)i + 1); // Size of the ring
193-
params[1] = lower_ring_offset - upper_ring_offset;
194-
195-
// Step 3c: Copy callback parameters to device memory asynchronously.
196-
int64 *params_device = callback_params + param_offset;
197-
cudaMemcpyAsync(params_device, params, 2 * sizeof(int64), cudaMemcpyHostToDevice, stream);
198-
// Step 3d: Set the backward callback for the current polar plan.
199-
s2fftKernels::setBackwardCallback(m_inverse_polar_plans[i], params_device, shift, false, isDouble,
200-
norm);
201178

202179
// Step 3e: Set the CUDA stream and work area for the cuFFT plan.
203180
CUFFT_CALL(cufftSetStream(m_inverse_polar_plans[i], stream));
@@ -207,22 +184,31 @@ HRESULT s2fftExec<Complex>::Backward(const s2fftDescriptor &desc, cudaStream_t s
207184
DIRECTION));
208185
}
209186
// Step 4: Execute inverse FFT for the equatorial ring.
210-
// Step 4a: Set equator parameters for the callback.
211-
int64 equator_size = (4 * m_nside);
212-
int64 equator_offset = (m_nside - 1) * 2;
213-
int64 *equator_params_device = callback_params + equator_offset;
214-
// Step 4b: Copy equator parameters to device memory asynchronously.
215-
cudaMemcpyAsync(equator_params_device, &equator_size, sizeof(int64), cudaMemcpyHostToDevice, stream);
216-
// Step 4c: Set the backward callback for the equatorial plan.
217-
s2fftKernels::setBackwardCallback(m_inverse_equator_plan, equator_params_device, shift, true, isDouble,
218-
norm);
219187
// Step 4d: Set the CUDA stream and work area for the equatorial cuFFT plan.
220188
CUFFT_CALL(cufftSetStream(m_inverse_equator_plan, stream));
221189
CUFFT_CALL(cufftSetWorkArea(m_inverse_equator_plan, workspace));
222190
// Step 4e: Execute the cuFFT transform for the equator.
223191
CUFFT_CALL(cufftXtExec(m_inverse_equator_plan, data + m_equatorial_offset_start,
224192
data + m_equatorial_offset_start, DIRECTION));
225193

194+
// Step 5: Launch the custom kernel for normalization and shifting.
195+
switch (norm) {
196+
case s2fftKernels::fft_norm::NONE:
197+
case s2fftKernels::fft_norm::FORWARD:
198+
// No normalization, do nothing.
199+
break;
200+
case s2fftKernels::fft_norm::BACKWARD:
201+
// Normalize by sqrt(Npix).
202+
s2fftKernels::launch_shift_normalize_kernel(stream, data, m_nside, false, 0);
203+
break;
204+
case s2fftKernels::fft_norm::ORTHO:
205+
// Normalize by Npix.
206+
s2fftKernels::launch_shift_normalize_kernel(stream, data, m_nside, false, 1);
207+
break;
208+
default:
209+
return E_INVALIDARG; // Invalid normalization type.
210+
}
211+
226212
return S_OK;
227213
}
228214

0 commit comments

Comments
 (0)