22#include < cstddef>
33#include < complex>
44#include < type_traits>
5+ #include < cstdlib>
56
67namespace nb = nanobind;
78
@@ -62,21 +63,44 @@ constexpr bool is_double_v = is_double<T>::value;
6263 *
6364 * @tparam T The XLA data type (F32, F64, etc).
6465 * @param stream CUDA stream to use.
66+ * @param scratch ScratchAllocator for temporary device memory.
6567 * @param input Input buffer containing HEALPix pixel-space data.
6668 * @param output Output buffer to store the FTM result.
6769 * @param workspace Output buffer for temporary workspace memory.
6870 * @param descriptor Descriptor containing transform parameters.
6971 * @return ffi::Error indicating success or failure.
7072 */
7173template <ffi::DataType T>
72- ffi::Error healpix_forward (cudaStream_t stream, ffi::Buffer<T> input, ffi::Result<ffi::Buffer<T>> output,
73- ffi::Result<ffi::Buffer<T>> workspace, s2fftDescriptor descriptor) {
74+ ffi::Error healpix_forward (cudaStream_t stream, ffi::ScratchAllocator& scratch, ffi::Buffer<T> input,
75+ ffi::Result<ffi::Buffer<T>> output, ffi::Result<ffi::Buffer<T>> workspace,
76+ s2fftDescriptor descriptor) {
7477 // Step 1: Determine the complex type based on the XLA data type.
7578 using fft_complex_type = fft_complex_t <T>;
7679 const auto & dim_in = input.dimensions ();
7780
81+ // Step 1a: Parse environment variable for shift strategy (static for thread safety).
82+ static const std::string shift_strategy = []() {
83+ const char * env = std::getenv (" S2FFT_CUDA_SHIFT_STRATEGY" );
84+ return env ? std::string (env) : " in_place" ;
85+ }();
86+ bool use_out_of_place = (shift_strategy == " out_of_place" );
87+ bool is_batched = (dim_in.size () == 2 );
88+
89+ // Step 1b: Allocate scratch buffer if using out-of-place mode.
90+ fft_complex_type* shift_scratch = nullptr ;
91+ if (use_out_of_place && descriptor.shift ) {
92+ int64_t Npix = descriptor.nside * descriptor.nside * 12 ;
93+ int batch_count = is_batched ? dim_in[0 ] : 1 ;
94+ size_t scratch_size = Npix * sizeof (fft_complex_type) * batch_count;
95+ auto scratch_result = scratch.Allocate (scratch_size);
96+ if (!scratch_result.has_value ()) {
97+ return ffi::Error::Internal (" Failed to allocate scratch buffer for shift operation" );
98+ }
99+ shift_scratch = reinterpret_cast <fft_complex_type*>(scratch_result.value ());
100+ }
101+
78102 // Step 2: Handle batched and non-batched cases separately.
79- if (dim_in. size () == 2 ) {
103+ if (is_batched ) {
80104 // Step 2a: Batched case.
81105 int batch_count = dim_in[0 ];
82106 // Step 2b: Compute offsets for input and output for each batch.
@@ -104,7 +128,12 @@ ffi::Error healpix_forward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resul
104128 reinterpret_cast <fft_complex_type*>(workspace->typed_data () + i * executor->m_work_size );
105129
106130 // Step 2g: Launch the forward transform on this sub-stream.
107- executor->Forward (descriptor, sub_stream, data_c, workspace_c);
131+ fft_complex_type* shift_scratch_batch =
132+ use_out_of_place && shift_scratch
133+ ? shift_scratch + i * (descriptor.nside * descriptor.nside * 12 )
134+ : nullptr ;
135+ executor->Forward (descriptor, sub_stream, data_c, workspace_c, shift_scratch_batch,
136+ use_out_of_place);
108137 // Step 2h: Launch spectral extension kernel.
109138 s2fftKernels::launch_spectral_extension (data_c, out_c, descriptor.nside ,
110139 descriptor.harmonic_band_limit , sub_stream);
@@ -123,7 +152,7 @@ ffi::Error healpix_forward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resul
123152 auto executor = std::make_shared<s2fftExec<fft_complex_type>>();
124153 PlanCache::GetInstance ().GetS2FFTExec (descriptor, executor);
125154 // Step 2m: Launch the forward transform.
126- executor->Forward (descriptor, stream, data_c, workspace_c);
155+ executor->Forward (descriptor, stream, data_c, workspace_c, shift_scratch, use_out_of_place );
127156 // Step 2n: Launch spectral extension kernel.
128157 s2fftKernels::launch_spectral_extension (data_c, out_c, descriptor.nside ,
129158 descriptor.harmonic_band_limit , stream);
@@ -141,22 +170,45 @@ ffi::Error healpix_forward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resul
141170 *
142171 * @tparam T The XLA data type.
143172 * @param stream CUDA stream to use.
173+ * @param scratch ScratchAllocator for temporary device memory.
144174 * @param input Input buffer containing FTM data.
145175 * @param output Output buffer to store HEALPix pixel-space data.
146176 * @param workspace Output buffer for temporary workspace memory.
147177 * @param descriptor Descriptor containing transform parameters.
148178 * @return ffi::Error indicating success or failure.
149179 */
150180template <ffi::DataType T>
151- ffi::Error healpix_backward (cudaStream_t stream, ffi::Buffer<T> input, ffi::Result<ffi::Buffer<T>> output,
152- ffi::Result<ffi::Buffer<T>> workspace, s2fftDescriptor descriptor) {
181+ ffi::Error healpix_backward (cudaStream_t stream, ffi::ScratchAllocator& scratch, ffi::Buffer<T> input,
182+ ffi::Result<ffi::Buffer<T>> output, ffi::Result<ffi::Buffer<T>> workspace,
183+ s2fftDescriptor descriptor) {
153184 // Step 1: Determine the complex type based on the XLA data type.
154185 using fft_complex_type = fft_complex_t <T>;
155186 const auto & dim_in = input.dimensions ();
156187 const auto & dim_out = output->dimensions ();
157188
189+ // Step 1a: Parse environment variable for shift strategy (static for thread safety).
190+ static const std::string shift_strategy = []() {
191+ const char * env = std::getenv (" S2FFT_CUDA_SHIFT_STRATEGY" );
192+ return env ? std::string (env) : " in_place" ;
193+ }();
194+ bool use_out_of_place = (shift_strategy == " out_of_place" );
195+ bool is_batched = (dim_in.size () == 3 );
196+
197+ // Step 1b: Allocate scratch buffer if using out-of-place mode.
198+ fft_complex_type* shift_scratch = nullptr ;
199+ if (use_out_of_place && descriptor.shift ) {
200+ int64_t Npix = descriptor.nside * descriptor.nside * 12 ;
201+ int batch_count = is_batched ? dim_in[0 ] : 1 ;
202+ size_t scratch_size = Npix * sizeof (fft_complex_type) * batch_count;
203+ auto scratch_result = scratch.Allocate (scratch_size);
204+ if (!scratch_result.has_value ()) {
205+ return ffi::Error::Internal (" Failed to allocate scratch buffer for shift operation" );
206+ }
207+ shift_scratch = reinterpret_cast <fft_complex_type*>(scratch_result.value ());
208+ }
209+
158210 // Step 2: Handle batched and non-batched cases separately.
159- if (dim_in. size () == 3 ) {
211+ if (is_batched ) {
160212 // Step 2a: Batched case.
161213 // Assertions to ensure correct input/output dimensions for batched operations.
162214 assert (dim_out.size () == 2 );
@@ -191,7 +243,12 @@ ffi::Error healpix_backward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resu
191243 descriptor.harmonic_band_limit , descriptor.shift ,
192244 sub_stream);
193245 // Step 2h: Launch the backward transform on this sub-stream.
194- executor->Backward (descriptor, sub_stream, out_c, workspace_c);
246+ fft_complex_type* shift_scratch_batch =
247+ use_out_of_place && shift_scratch
248+ ? shift_scratch + i * (descriptor.nside * descriptor.nside * 12 )
249+ : nullptr ;
250+ executor->Backward (descriptor, sub_stream, out_c, workspace_c, shift_scratch_batch,
251+ use_out_of_place);
195252 }
196253 // Step 2i: Join all forked streams back to the main stream.
197254 handler.join (stream);
@@ -213,7 +270,7 @@ ffi::Error healpix_backward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resu
213270 s2fftKernels::launch_spectral_folding (data_c, out_c, descriptor.nside , descriptor.harmonic_band_limit ,
214271 descriptor.shift , stream);
215272 // Step 2n: Launch the backward transform.
216- executor->Backward (descriptor, stream, out_c, workspace_c);
273+ executor->Backward (descriptor, stream, out_c, workspace_c, shift_scratch, use_out_of_place );
217274 return ffi::Error::Success ();
218275 }
219276}
@@ -298,19 +355,20 @@ s2fftDescriptor build_descriptor(int64_t nside, int64_t harmonic_band_limit, boo
298355 * @return ffi::Error indicating success or failure.
299356 */
300357template <ffi::DataType T>
301- ffi::Error healpix_fft_cuda (cudaStream_t stream, int64_t nside, int64_t harmonic_band_limit, bool reality,
302- bool forward, bool normalize, bool adjoint, ffi::Buffer<T> input,
303- ffi::Result<ffi::Buffer<T>> output, ffi::Result<ffi::Buffer<T>> workspace) {
358+ ffi::Error healpix_fft_cuda (cudaStream_t stream, ffi::ScratchAllocator scratch, int64_t nside,
359+ int64_t harmonic_band_limit, bool reality, bool forward, bool normalize,
360+ bool adjoint, ffi::Buffer<T> input, ffi::Result<ffi::Buffer<T>> output,
361+ ffi::Result<ffi::Buffer<T>> workspace) {
304362 // Step 1: Build the s2fftDescriptor based on the input parameters.
305363 size_t work_size = 0 ; // Variable to hold the workspace size
306364 s2fftDescriptor descriptor = build_descriptor<T>(nside, harmonic_band_limit, reality, forward, normalize,
307365 adjoint, true , work_size);
308366
309367 // Step 2: Dispatch to either forward or backward transform based on the 'forward' flag.
310368 if (forward) {
311- return healpix_forward<T>(stream, input, output, workspace, descriptor);
369+ return healpix_forward<T>(stream, scratch, input, output, workspace, descriptor);
312370 } else {
313- return healpix_backward<T>(stream, input, output, workspace, descriptor);
371+ return healpix_backward<T>(stream, scratch, input, output, workspace, descriptor);
314372 }
315373}
316374
@@ -323,6 +381,7 @@ ffi::Error healpix_fft_cuda(cudaStream_t stream, int64_t nside, int64_t harmonic
323381XLA_FFI_DEFINE_HANDLER_SYMBOL (healpix_fft_cuda_C64, healpix_fft_cuda<ffi::DataType::C64>,
324382 ffi::Ffi::Bind ()
325383 .Ctx<ffi::PlatformStream<cudaStream_t>>()
384+ .Ctx<ffi::ScratchAllocator>()
326385 .Attr<int64_t>(" nside" )
327386 .Attr<int64_t>(" harmonic_band_limit" )
328387 .Attr<bool>(" reality" )
@@ -336,6 +395,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(healpix_fft_cuda_C64, healpix_fft_cuda<ffi::DataTy
336395XLA_FFI_DEFINE_HANDLER_SYMBOL (healpix_fft_cuda_C128, healpix_fft_cuda<ffi::DataType::C128>,
337396 ffi::Ffi::Bind ()
338397 .Ctx<ffi::PlatformStream<cudaStream_t>>()
398+ .Ctx<ffi::ScratchAllocator>()
339399 .Attr<int64_t>(" nside" )
340400 .Attr<int64_t>(" harmonic_band_limit" )
341401 .Attr<bool>(" reality" )
0 commit comments