@@ -65,14 +65,12 @@ constexpr bool is_double_v = is_double<T>::value;
6565 * @param input Input buffer containing HEALPix pixel-space data.
6666 * @param output Output buffer to store the FTM result.
6767 * @param workspace Output buffer for temporary workspace memory.
68- * @param callback_params Output buffer for callback parameters.
6968 * @param descriptor Descriptor containing transform parameters.
7069 * @return ffi::Error indicating success or failure.
7170 */
7271template <ffi::DataType T>
7372ffi::Error healpix_forward (cudaStream_t stream, ffi::Buffer<T> input, ffi::Result<ffi::Buffer<T>> output,
7473 ffi::Result<ffi::Buffer<T>> workspace,
75- ffi::Result<ffi::Buffer<ffi::DataType::S64>> callback_params,
7674 s2fftDescriptor descriptor) {
7775 // Step 1: Determine the complex type based on the XLA data type.
7876 using fft_complex_type = fft_complex_t <T>;
@@ -82,10 +80,9 @@ ffi::Error healpix_forward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resul
8280 if (dim_in.size () == 2 ) {
8381 // Step 2a: Batched case.
8482 int batch_count = dim_in[0 ];
85- // Step 2b: Compute offsets for input, output, and callback parameters for each batch.
83+ // Step 2b: Compute offsets for input and output for each batch.
8684 int64_t input_offset = descriptor.nside * descriptor.nside * 12 ;
8785 int64_t output_offset = (4 * descriptor.nside - 1 ) * (2 * descriptor.harmonic_band_limit );
88- int64_t params_offset = 2 * (descriptor.nside - 1 ) + 1 ;
8986
9087 // Step 2c: Fork CUDA streams for parallel processing of batches.
9188 CudaStreamHandler handler;
@@ -99,16 +96,13 @@ ffi::Error healpix_forward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resul
9996 auto executor = std::make_shared<s2fftExec<fft_complex_type>>();
10097 PlanCache::GetInstance ().GetS2FFTExec (descriptor, executor);
10198
102- // Step 2f: Calculate device pointers for the current batch's data, output, workspace, and
103- // callback parameters.
99+ // Step 2f: Calculate device pointers for the current batch's data, output, and workspace.
104100 fft_complex_type* data_c =
105101 reinterpret_cast <fft_complex_type*>(input.typed_data () + i * input_offset);
106102 fft_complex_type* out_c =
107103 reinterpret_cast <fft_complex_type*>(output->typed_data () + i * output_offset);
108104 fft_complex_type* workspace_c =
109105 reinterpret_cast <fft_complex_type*>(workspace->typed_data () + i * executor->m_work_size );
110- int64* callback_params_c =
111- reinterpret_cast <int64*>(callback_params->typed_data () + i * params_offset);
112106
113107 // Step 2g: Launch the forward transform on this sub-stream.
114108 executor->Forward (descriptor, sub_stream, data_c, workspace_c);
@@ -121,11 +115,10 @@ ffi::Error healpix_forward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resul
121115 return ffi::Error::Success ();
122116 } else {
123117 // Step 2j: Non-batched case.
124- // Step 2k: Get device pointers for data, output, workspace, and callback parameters .
118+ // Step 2k: Get device pointers for data, output, and workspace .
125119 fft_complex_type* data_c = reinterpret_cast <fft_complex_type*>(input.typed_data ());
126120 fft_complex_type* out_c = reinterpret_cast <fft_complex_type*>(output->typed_data ());
127121 fft_complex_type* workspace_c = reinterpret_cast <fft_complex_type*>(workspace->typed_data ());
128- int64* callback_params_c = reinterpret_cast <int64*>(callback_params->typed_data ());
129122
130123 // Step 2l: Get or create an s2fftExec instance from the PlanCache.
131124 auto executor = std::make_shared<s2fftExec<fft_complex_type>>();
@@ -152,14 +145,12 @@ ffi::Error healpix_forward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resul
152145 * @param input Input buffer containing FTM data.
153146 * @param output Output buffer to store HEALPix pixel-space data.
154147 * @param workspace Output buffer for temporary workspace memory.
155- * @param callback_params Output buffer for callback parameters.
156148 * @param descriptor Descriptor containing transform parameters.
157149 * @return ffi::Error indicating success or failure.
158150 */
159151template <ffi::DataType T>
160152ffi::Error healpix_backward (cudaStream_t stream, ffi::Buffer<T> input, ffi::Result<ffi::Buffer<T>> output,
161153 ffi::Result<ffi::Buffer<T>> workspace,
162- ffi::Result<ffi::Buffer<ffi::DataType::S64>> callback_params,
163154 s2fftDescriptor descriptor) {
164155 // Step 1: Determine the complex type based on the XLA data type.
165156 using fft_complex_type = fft_complex_t <T>;
@@ -189,16 +180,13 @@ ffi::Error healpix_backward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resu
189180 auto executor = std::make_shared<s2fftExec<fft_complex_type>>();
190181 PlanCache::GetInstance ().GetS2FFTExec (descriptor, executor);
191182
192- // Step 2f: Calculate device pointers for the current batch's data, output, workspace, and
193- // callback parameters.
183+ // Step 2f: Calculate device pointers for the current batch's data, output, and workspace.
194184 fft_complex_type* data_c =
195185 reinterpret_cast <fft_complex_type*>(input.typed_data () + i * input_offset);
196186 fft_complex_type* out_c =
197187 reinterpret_cast <fft_complex_type*>(output->typed_data () + i * output_offset);
198188 fft_complex_type* workspace_c =
199189 reinterpret_cast <fft_complex_type*>(workspace->typed_data () + i * executor->m_work_size );
200- int64* callback_params_c =
201- reinterpret_cast <int64*>(callback_params->typed_data () + i * sizeof (int64) * 2 );
202190
203191 // Step 2g: Launch spectral folding kernel.
204192 s2fftKernels::launch_spectral_folding (data_c, out_c, descriptor.nside ,
@@ -215,11 +203,10 @@ ffi::Error healpix_backward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resu
215203 // Assertions to ensure correct input/output dimensions for non-batched operations.
216204 assert (dim_in.size () == 2 );
217205 assert (dim_out.size () == 1 );
218- // Step 2k: Get device pointers for data, output, workspace, and callback parameters .
206+ // Step 2k: Get device pointers for data, output, and workspace .
219207 fft_complex_type* data_c = reinterpret_cast <fft_complex_type*>(input.typed_data ());
220208 fft_complex_type* out_c = reinterpret_cast <fft_complex_type*>(output->typed_data ());
221209 fft_complex_type* workspace_c = reinterpret_cast <fft_complex_type*>(workspace->typed_data ());
222- int64* callback_params_c = reinterpret_cast <int64*>(callback_params->typed_data ());
223210
224211 // Step 2l: Get or create an s2fftExec instance from the PlanCache.
225212 auto executor = std::make_shared<s2fftExec<fft_complex_type>>();
@@ -310,24 +297,22 @@ s2fftDescriptor build_descriptor(int64_t nside, int64_t harmonic_band_limit, boo
310297 * @param input Input buffer.
311298 * @param output Output buffer.
312299 * @param workspace Output buffer for temporary workspace memory.
313- * @param callback_params Output buffer for callback parameters.
314300 * @return ffi::Error indicating success or failure.
315301 */
316302template <ffi::DataType T>
317303ffi::Error healpix_fft_cuda (cudaStream_t stream, int64_t nside, int64_t harmonic_band_limit, bool reality,
318304 bool forward, bool normalize, bool adjoint, ffi::Buffer<T> input,
319- ffi::Result<ffi::Buffer<T>> output, ffi::Result<ffi::Buffer<T>> workspace,
320- ffi::Result<ffi::Buffer<ffi::DataType::S64>> callback_params) {
305+ ffi::Result<ffi::Buffer<T>> output, ffi::Result<ffi::Buffer<T>> workspace) {
321306 // Step 1: Build the s2fftDescriptor based on the input parameters.
322307 size_t work_size = 0 ; // Variable to hold the workspace size
323308 s2fftDescriptor descriptor = build_descriptor<T>(nside, harmonic_band_limit, reality, forward, normalize,
324309 adjoint, true , work_size);
325310
326311 // Step 2: Dispatch to either forward or backward transform based on the 'forward' flag.
327312 if (forward) {
328- return healpix_forward<T>(stream, input, output, workspace, callback_params, descriptor);
313+ return healpix_forward<T>(stream, input, output, workspace, descriptor);
329314 } else {
330- return healpix_backward<T>(stream, input, output, workspace, callback_params, descriptor);
315+ return healpix_backward<T>(stream, input, output, workspace, descriptor);
331316 }
332317}
333318
@@ -348,8 +333,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(healpix_fft_cuda_C64, healpix_fft_cuda<ffi::DataTy
348333 .Attr<bool>(" adjoint" )
349334 .Arg<ffi::Buffer<ffi::DataType::C64>>()
350335 .Ret<ffi::Buffer<ffi::DataType::C64>>()
351- .Ret<ffi::Buffer<ffi::DataType::C64>>()
352- .Ret<ffi::Buffer<ffi::DataType::S64>>());
336+ .Ret<ffi::Buffer<ffi::DataType::C64>>());
353337
354338XLA_FFI_DEFINE_HANDLER_SYMBOL (healpix_fft_cuda_C128, healpix_fft_cuda<ffi::DataType::C128>,
355339 ffi::Ffi::Bind ()
@@ -362,8 +346,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(healpix_fft_cuda_C128, healpix_fft_cuda<ffi::DataT
362346 .Attr<bool>(" adjoint" )
363347 .Arg<ffi::Buffer<ffi::DataType::C128>>()
364348 .Ret<ffi::Buffer<ffi::DataType::C128>>()
365- .Ret<ffi::Buffer<ffi::DataType::C128>>()
366- .Ret<ffi::Buffer<ffi::DataType::S64>>());
349+ .Ret<ffi::Buffer<ffi::DataType::C128>>());
367350
368351/* *
369352 * @brief Encapsulates an FFI handler into a nanobind capsule.
0 commit comments