Skip to content

Commit 928ea12

Browse files
committed
Fix race condition error and update notebook
1 parent ac1609d commit 928ea12

File tree

7 files changed

+698
-640
lines changed

7 files changed

+698
-640
lines changed

lib/include/s2fft.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,13 @@ class s2fftExec {
168168
* @param stream The CUDA stream to use for execution.
169169
* @param data Pointer to the input/output data on the device.
170170
* @param workspace Pointer to the workspace memory on the device.
171+
* @param shift_scratch Pointer to scratch buffer for out-of-place shifting (can be nullptr for in-place).
172+
* @param use_out_of_place If true, use out-of-place shifting with shift_scratch; if false, use in-place
173+
* with cooperative kernel.
171174
* @return HRESULT indicating success or failure.
172175
*/
173-
HRESULT Forward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data, Complex *workspace);
176+
HRESULT Forward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data, Complex *workspace,
177+
Complex *shift_scratch, bool use_out_of_place);
174178

175179
/**
176180
* @brief Executes the backward Spherical Harmonic Transform.
@@ -182,9 +186,13 @@ class s2fftExec {
182186
* @param stream The CUDA stream to use for execution.
183187
* @param data Pointer to the input/output data on the device.
184188
* @param workspace Pointer to the workspace memory on the device.
189+
* @param shift_scratch Pointer to scratch buffer for out-of-place shifting (can be nullptr for in-place).
190+
* @param use_out_of_place If true, use out-of-place shifting with shift_scratch; if false, use in-place
191+
* with cooperative kernel.
185192
* @return HRESULT indicating success or failure.
186193
*/
187-
HRESULT Backward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data, Complex *workspace);
194+
HRESULT Backward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data, Complex *workspace,
195+
Complex *shift_scratch, bool use_out_of_place);
188196

189197
public:
190198
// cuFFT handles for polar and equatorial FFT plans

lib/include/s2fft_kernels.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,20 +70,23 @@ HRESULT launch_spectral_extension(complex* data, complex* output, const int& nsi
7070
* This function configures and launches the shift_normalize_kernel with appropriate
7171
* grid and block dimensions. It handles both single and double precision complex
7272
* types and applies the requested normalization and shifting operations to HEALPix
73-
* pixel data on a per-ring basis.
73+
* pixel data. Supports both in-place (with cooperative kernel) and out-of-place
74+
* (with scratch buffer) modes to enable compatibility with JAX transforms.
7475
*
7576
* @tparam complex The complex type (cufftComplex or cufftDoubleComplex).
7677
* @param stream CUDA stream for kernel execution.
77-
* @param data Input/output array of HEALPix pixel data (in-place processing).
78+
* @param data Input/output array of HEALPix pixel data.
79+
* @param shift_buffer Scratch buffer for out-of-place shifting (can be nullptr for in-place).
7880
* @param nside The HEALPix Nside parameter.
7981
* @param apply_shift Flag indicating whether to apply FFT shifting.
8082
* @param norm Normalization type (0=by nphi, 1=by sqrt(nphi), 2=no normalization).
83+
* @param use_out_of_place If true, use out-of-place shifting with shift_buffer; if false, use in-place with
84+
* cooperative kernel.
8185
* @return HRESULT indicating success or failure.
8286
*/
8387
template <typename complex>
84-
HRESULT launch_shift_normalize_kernel(cudaStream_t stream,
85-
complex* data, // In-place data buffer
86-
int nside, bool apply_shift, int norm);
88+
HRESULT launch_shift_normalize_kernel(cudaStream_t stream, complex* data, complex* shift_buffer, int nside,
89+
bool apply_shift, int norm, bool use_out_of_place);
8790

8891
} // namespace s2fftKernels
8992

lib/src/extensions.cc

Lines changed: 75 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <cstddef>
33
#include <complex>
44
#include <type_traits>
5+
#include <cstdlib>
56

67
namespace nb = nanobind;
78

@@ -62,21 +63,44 @@ constexpr bool is_double_v = is_double<T>::value;
6263
*
6364
* @tparam T The XLA data type (F32, F64, etc).
6465
* @param stream CUDA stream to use.
66+
* @param scratch ScratchAllocator for temporary device memory.
6567
* @param input Input buffer containing HEALPix pixel-space data.
6668
* @param output Output buffer to store the FTM result.
6769
* @param workspace Output buffer for temporary workspace memory.
6870
* @param descriptor Descriptor containing transform parameters.
6971
* @return ffi::Error indicating success or failure.
7072
*/
7173
template <ffi::DataType T>
72-
ffi::Error healpix_forward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Result<ffi::Buffer<T>> output,
73-
ffi::Result<ffi::Buffer<T>> workspace, s2fftDescriptor descriptor) {
74+
ffi::Error healpix_forward(cudaStream_t stream, ffi::ScratchAllocator& scratch, ffi::Buffer<T> input,
75+
ffi::Result<ffi::Buffer<T>> output, ffi::Result<ffi::Buffer<T>> workspace,
76+
s2fftDescriptor descriptor) {
7477
// Step 1: Determine the complex type based on the XLA data type.
7578
using fft_complex_type = fft_complex_t<T>;
7679
const auto& dim_in = input.dimensions();
7780

81+
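// Step 1a: Read S2FFT_CUDA_SHIFT_STRATEGY once; the function-local static is initialized on first call and cached, so later changes to the variable are ignored. Any value other than "out_of_place" (including unset) selects the in-place path.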
82+
static const std::string shift_strategy = []() {
83+
const char* env = std::getenv("S2FFT_CUDA_SHIFT_STRATEGY");
84+
return env ? std::string(env) : "in_place";
85+
}();
86+
bool use_out_of_place = (shift_strategy == "out_of_place");
87+
bool is_batched = (dim_in.size() == 2);
88+
89+
// Step 1b: Allocate the scratch buffer only when out-of-place mode is selected and the descriptor requests a shift; otherwise shift_scratch stays nullptr.
90+
fft_complex_type* shift_scratch = nullptr;
91+
if (use_out_of_place && descriptor.shift) {
92+
int64_t Npix = descriptor.nside * descriptor.nside * 12;
93+
int batch_count = is_batched ? dim_in[0] : 1;
94+
size_t scratch_size = Npix * sizeof(fft_complex_type) * batch_count;
95+
auto scratch_result = scratch.Allocate(scratch_size);
96+
if (!scratch_result.has_value()) {
97+
return ffi::Error::Internal("Failed to allocate scratch buffer for shift operation");
98+
}
99+
shift_scratch = reinterpret_cast<fft_complex_type*>(scratch_result.value());
100+
}
101+
78102
// Step 2: Handle batched and non-batched cases separately.
79-
if (dim_in.size() == 2) {
103+
if (is_batched) {
80104
// Step 2a: Batched case.
81105
int batch_count = dim_in[0];
82106
// Step 2b: Compute offsets for input and output for each batch.
@@ -104,7 +128,12 @@ ffi::Error healpix_forward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resul
104128
reinterpret_cast<fft_complex_type*>(workspace->typed_data() + i * executor->m_work_size);
105129

106130
// Step 2g: Launch the forward transform on this sub-stream.
107-
executor->Forward(descriptor, sub_stream, data_c, workspace_c);
131+
fft_complex_type* shift_scratch_batch =
132+
use_out_of_place && shift_scratch
133+
? shift_scratch + i * (descriptor.nside * descriptor.nside * 12)
134+
: nullptr;
135+
executor->Forward(descriptor, sub_stream, data_c, workspace_c, shift_scratch_batch,
136+
use_out_of_place);
108137
// Step 2h: Launch spectral extension kernel.
109138
s2fftKernels::launch_spectral_extension(data_c, out_c, descriptor.nside,
110139
descriptor.harmonic_band_limit, sub_stream);
@@ -123,7 +152,7 @@ ffi::Error healpix_forward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resul
123152
auto executor = std::make_shared<s2fftExec<fft_complex_type>>();
124153
PlanCache::GetInstance().GetS2FFTExec(descriptor, executor);
125154
// Step 2m: Launch the forward transform.
126-
executor->Forward(descriptor, stream, data_c, workspace_c);
155+
executor->Forward(descriptor, stream, data_c, workspace_c, shift_scratch, use_out_of_place);
127156
// Step 2n: Launch spectral extension kernel.
128157
s2fftKernels::launch_spectral_extension(data_c, out_c, descriptor.nside,
129158
descriptor.harmonic_band_limit, stream);
@@ -141,22 +170,45 @@ ffi::Error healpix_forward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resul
141170
*
142171
* @tparam T The XLA data type.
143172
* @param stream CUDA stream to use.
173+
* @param scratch ScratchAllocator for temporary device memory.
144174
* @param input Input buffer containing FTM data.
145175
* @param output Output buffer to store HEALPix pixel-space data.
146176
* @param workspace Output buffer for temporary workspace memory.
147177
* @param descriptor Descriptor containing transform parameters.
148178
* @return ffi::Error indicating success or failure.
149179
*/
150180
template <ffi::DataType T>
151-
ffi::Error healpix_backward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Result<ffi::Buffer<T>> output,
152-
ffi::Result<ffi::Buffer<T>> workspace, s2fftDescriptor descriptor) {
181+
ffi::Error healpix_backward(cudaStream_t stream, ffi::ScratchAllocator& scratch, ffi::Buffer<T> input,
182+
ffi::Result<ffi::Buffer<T>> output, ffi::Result<ffi::Buffer<T>> workspace,
183+
s2fftDescriptor descriptor) {
153184
// Step 1: Determine the complex type based on the XLA data type.
154185
using fft_complex_type = fft_complex_t<T>;
155186
const auto& dim_in = input.dimensions();
156187
const auto& dim_out = output->dimensions();
157188

189+
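// Step 1a: Read S2FFT_CUDA_SHIFT_STRATEGY once; the function-local static is initialized on first call and cached, so later changes to the variable are ignored. Any value other than "out_of_place" (including unset) selects the in-place path.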
190+
static const std::string shift_strategy = []() {
191+
const char* env = std::getenv("S2FFT_CUDA_SHIFT_STRATEGY");
192+
return env ? std::string(env) : "in_place";
193+
}();
194+
bool use_out_of_place = (shift_strategy == "out_of_place");
195+
bool is_batched = (dim_in.size() == 3);
196+
197+
// Step 1b: Allocate the scratch buffer only when out-of-place mode is selected and the descriptor requests a shift; otherwise shift_scratch stays nullptr.
198+
fft_complex_type* shift_scratch = nullptr;
199+
if (use_out_of_place && descriptor.shift) {
200+
int64_t Npix = descriptor.nside * descriptor.nside * 12;
201+
int batch_count = is_batched ? dim_in[0] : 1;
202+
size_t scratch_size = Npix * sizeof(fft_complex_type) * batch_count;
203+
auto scratch_result = scratch.Allocate(scratch_size);
204+
if (!scratch_result.has_value()) {
205+
return ffi::Error::Internal("Failed to allocate scratch buffer for shift operation");
206+
}
207+
shift_scratch = reinterpret_cast<fft_complex_type*>(scratch_result.value());
208+
}
209+
158210
// Step 2: Handle batched and non-batched cases separately.
159-
if (dim_in.size() == 3) {
211+
if (is_batched) {
160212
// Step 2a: Batched case.
161213
// Assertions to ensure correct input/output dimensions for batched operations.
162214
assert(dim_out.size() == 2);
@@ -191,7 +243,12 @@ ffi::Error healpix_backward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resu
191243
descriptor.harmonic_band_limit, descriptor.shift,
192244
sub_stream);
193245
// Step 2h: Launch the backward transform on this sub-stream.
194-
executor->Backward(descriptor, sub_stream, out_c, workspace_c);
246+
fft_complex_type* shift_scratch_batch =
247+
use_out_of_place && shift_scratch
248+
? shift_scratch + i * (descriptor.nside * descriptor.nside * 12)
249+
: nullptr;
250+
executor->Backward(descriptor, sub_stream, out_c, workspace_c, shift_scratch_batch,
251+
use_out_of_place);
195252
}
196253
// Step 2i: Join all forked streams back to the main stream.
197254
handler.join(stream);
@@ -213,7 +270,7 @@ ffi::Error healpix_backward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resu
213270
s2fftKernels::launch_spectral_folding(data_c, out_c, descriptor.nside, descriptor.harmonic_band_limit,
214271
descriptor.shift, stream);
215272
// Step 2n: Launch the backward transform.
216-
executor->Backward(descriptor, stream, out_c, workspace_c);
273+
executor->Backward(descriptor, stream, out_c, workspace_c, shift_scratch, use_out_of_place);
217274
return ffi::Error::Success();
218275
}
219276
}
@@ -298,19 +355,20 @@ s2fftDescriptor build_descriptor(int64_t nside, int64_t harmonic_band_limit, boo
298355
* @return ffi::Error indicating success or failure.
299356
*/
300357
template <ffi::DataType T>
301-
ffi::Error healpix_fft_cuda(cudaStream_t stream, int64_t nside, int64_t harmonic_band_limit, bool reality,
302-
bool forward, bool normalize, bool adjoint, ffi::Buffer<T> input,
303-
ffi::Result<ffi::Buffer<T>> output, ffi::Result<ffi::Buffer<T>> workspace) {
358+
ffi::Error healpix_fft_cuda(cudaStream_t stream, ffi::ScratchAllocator scratch, int64_t nside,
359+
int64_t harmonic_band_limit, bool reality, bool forward, bool normalize,
360+
bool adjoint, ffi::Buffer<T> input, ffi::Result<ffi::Buffer<T>> output,
361+
ffi::Result<ffi::Buffer<T>> workspace) {
304362
// Step 1: Build the s2fftDescriptor based on the input parameters.
305363
size_t work_size = 0; // Variable to hold the workspace size
306364
s2fftDescriptor descriptor = build_descriptor<T>(nside, harmonic_band_limit, reality, forward, normalize,
307365
adjoint, true, work_size);
308366

309367
// Step 2: Dispatch to either forward or backward transform based on the 'forward' flag.
310368
if (forward) {
311-
return healpix_forward<T>(stream, input, output, workspace, descriptor);
369+
return healpix_forward<T>(stream, scratch, input, output, workspace, descriptor);
312370
} else {
313-
return healpix_backward<T>(stream, input, output, workspace, descriptor);
371+
return healpix_backward<T>(stream, scratch, input, output, workspace, descriptor);
314372
}
315373
}
316374

@@ -323,6 +381,7 @@ ffi::Error healpix_fft_cuda(cudaStream_t stream, int64_t nside, int64_t harmonic
323381
XLA_FFI_DEFINE_HANDLER_SYMBOL(healpix_fft_cuda_C64, healpix_fft_cuda<ffi::DataType::C64>,
324382
ffi::Ffi::Bind()
325383
.Ctx<ffi::PlatformStream<cudaStream_t>>()
384+
.Ctx<ffi::ScratchAllocator>()
326385
.Attr<int64_t>("nside")
327386
.Attr<int64_t>("harmonic_band_limit")
328387
.Attr<bool>("reality")
@@ -336,6 +395,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(healpix_fft_cuda_C64, healpix_fft_cuda<ffi::DataTy
336395
XLA_FFI_DEFINE_HANDLER_SYMBOL(healpix_fft_cuda_C128, healpix_fft_cuda<ffi::DataType::C128>,
337396
ffi::Ffi::Bind()
338397
.Ctx<ffi::PlatformStream<cudaStream_t>>()
398+
.Ctx<ffi::ScratchAllocator>()
339399
.Attr<int64_t>("nside")
340400
.Attr<int64_t>("harmonic_band_limit")
341401
.Attr<bool>("reality")

lib/src/s2fft.cu

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ HRESULT s2fftExec<Complex>::Initialize(const s2fftDescriptor &descriptor) {
111111

112112
template <typename Complex>
113113
HRESULT s2fftExec<Complex>::Forward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data,
114-
Complex *workspace) {
114+
Complex *workspace, Complex *shift_scratch, bool use_out_of_place) {
115115
// Step 1: Determine the FFT direction (forward or inverse based on adjoint flag).
116116
const int DIRECTION = desc.adjoint ? CUFFT_INVERSE : CUFFT_FORWARD;
117117
// Step 2: Extract normalization, shift, and double precision flags from the descriptor.
@@ -143,15 +143,18 @@ HRESULT s2fftExec<Complex>::Forward(const s2fftDescriptor &desc, cudaStream_t st
143143
case s2fftKernels::fft_norm::NONE:
144144
case s2fftKernels::fft_norm::BACKWARD:
145145
// No normalization, only shift if required.
146-
s2fftKernels::launch_shift_normalize_kernel(stream, data, m_nside, shift, 2);
146+
s2fftKernels::launch_shift_normalize_kernel(stream, data, shift_scratch, m_nside, shift, 2,
147+
use_out_of_place);
147148
break;
148149
case s2fftKernels::fft_norm::FORWARD:
149150
// Full normalization: norm code 0 divides by nphi (see launch_shift_normalize_kernel).
150-
s2fftKernels::launch_shift_normalize_kernel(stream, data, m_nside, shift, 0);
151+
s2fftKernels::launch_shift_normalize_kernel(stream, data, shift_scratch, m_nside, shift, 0,
152+
use_out_of_place);
151153
break;
152154
case s2fftKernels::fft_norm::ORTHO:
153155
// Orthonormal scaling: norm code 1 divides by sqrt(nphi) (see launch_shift_normalize_kernel).
154-
s2fftKernels::launch_shift_normalize_kernel(stream, data, m_nside, shift, 1);
156+
s2fftKernels::launch_shift_normalize_kernel(stream, data, shift_scratch, m_nside, shift, 1,
157+
use_out_of_place);
155158
break;
156159
default:
157160
return E_INVALIDARG; // Invalid normalization type.
@@ -162,7 +165,7 @@ HRESULT s2fftExec<Complex>::Forward(const s2fftDescriptor &desc, cudaStream_t st
162165

163166
template <typename Complex>
164167
HRESULT s2fftExec<Complex>::Backward(const s2fftDescriptor &desc, cudaStream_t stream, Complex *data,
165-
Complex *workspace) {
168+
Complex *workspace, Complex *shift_scratch, bool use_out_of_place) {
166169
// Step 1: Determine the FFT direction (forward or inverse based on adjoint flag).
167170
const int DIRECTION = desc.adjoint ? CUFFT_FORWARD : CUFFT_INVERSE;
168171
// Step 2: Extract normalization, shift, and double precision flags from the descriptor.
@@ -196,11 +199,13 @@ HRESULT s2fftExec<Complex>::Backward(const s2fftDescriptor &desc, cudaStream_t s
196199
break;
197200
case s2fftKernels::fft_norm::BACKWARD:
198201
// Full normalization: norm code 0 divides by nphi; no shift is applied here (apply_shift = false).
199-
s2fftKernels::launch_shift_normalize_kernel(stream, data, m_nside, false, 0);
202+
s2fftKernels::launch_shift_normalize_kernel(stream, data, shift_scratch, m_nside, false, 0,
203+
use_out_of_place);
200204
break;
201205
case s2fftKernels::fft_norm::ORTHO:
202206
// Orthonormal scaling: norm code 1 divides by sqrt(nphi); no shift is applied here (apply_shift = false).
203-
s2fftKernels::launch_shift_normalize_kernel(stream, data, m_nside, false, 1);
207+
s2fftKernels::launch_shift_normalize_kernel(stream, data, shift_scratch, m_nside, false, 1,
208+
use_out_of_place);
204209
break;
205210
default:
206211
return E_INVALIDARG; // Invalid normalization type.

0 commit comments

Comments
 (0)