Skip to content

Commit d29af9b

Browse files
committed
format
1 parent fd7860e commit d29af9b

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

lib/src/extensions.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ ffi::Error healpix_backward(cudaStream_t stream, ffi::Buffer<T> input, ffi::Resu
251251
*/
252252
template <ffi::DataType T>
253253
s2fftDescriptor build_descriptor(int64_t nside, int64_t harmonic_band_limit, bool reality, bool forward,
254-
bool normalize, bool adjoint, bool must_exist , size_t& work_size) {
254+
bool normalize, bool adjoint, bool must_exist, size_t& work_size) {
255255
using fft_complex_type = fft_complex_t<T>;
256256
// Step 1: Determine FFT normalization type based on forward/normalize flags.
257257
s2fftKernels::fft_norm norm = s2fftKernels::fft_norm::NONE;
@@ -285,7 +285,7 @@ s2fftDescriptor build_descriptor(int64_t nside, int64_t harmonic_band_limit, boo
285285
if (hr == S_OK) {
286286
executor->Initialize(descriptor);
287287
}
288-
// Make sure workspace is set
288+
// Make sure workspace is set
289289
assert(executor->m_work_size > 0 && "S2FFT INTERNAL ERROR: Workspace size is zero after initialization.");
290290
work_size = executor->m_work_size;
291291
// Step 7: Return the created descriptor.
@@ -320,8 +320,8 @@ ffi::Error healpix_fft_cuda(cudaStream_t stream, int64_t nside, int64_t harmonic
320320
ffi::Result<ffi::Buffer<ffi::DataType::S64>> callback_params) {
321321
// Step 1: Build the s2fftDescriptor based on the input parameters.
322322
size_t work_size = 0; // Variable to hold the workspace size
323-
s2fftDescriptor descriptor =
324-
build_descriptor<T>(nside, harmonic_band_limit, reality, forward, normalize, adjoint, true , work_size);
323+
s2fftDescriptor descriptor = build_descriptor<T>(nside, harmonic_band_limit, reality, forward, normalize,
324+
adjoint, true, work_size);
325325

326326
// Step 2: Dispatch to either forward or backward transform based on the 'forward' flag.
327327
if (forward) {

s2fft/utils/healpix_ffts.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,7 @@ def _healpix_fft_cuda_abstract(f, L, nside, reality, fft_type, norm, adjoint):
587587
588588
Returns:
589589
Tuple of ShapedArray objects for output, workspace, and callback parameters.
590+
590591
"""
591592
# Step 1: Get lowering information (double precision, forward/backward, normalize).
592593
is_double, forward, normalize = _get_lowering_info(fft_type, norm, f.dtype)
@@ -632,7 +633,7 @@ def _healpix_fft_cuda_abstract(f, L, nside, reality, fft_type, norm, adjoint):
632633
shape=batch_shape + workspace_shape, dtype=workspace_dtype
633634
)
634635
params_eval = ShapedArray(shape=batch_shape + params_shape, dtype=np.int64)
635-
636+
636637
# Step 6: Return the ShapedArray objects.
637638
return (
638639
f.update(shape=out_shape, dtype=f.dtype),
@@ -666,6 +667,7 @@ def _healpix_fft_cuda_lowering(ctx, f, *, L, nside, reality, fft_type, norm, adj
666667
667668
Returns:
668669
The result of the FFI call.
670+
669671
"""
670672
# Step 1: Check if CUDA support is compiled in.
671673
if not _s2fft.COMPILED_WITH_CUDA:
@@ -715,6 +717,7 @@ def _healpix_fft_cuda_batching_rule(
715717
716718
Returns:
717719
Tuple of (output, output_batch_axes).
720+
718721
"""
719722
# Step 1: Unpack batched arguments and batching axes.
720723
(x,) = batched_args
@@ -772,6 +775,7 @@ def _healpix_fft_cuda_transpose(
772775
773776
Returns:
774777
The adjoint of the input.
778+
775779
"""
776780
# Step 1: Invert the FFT type and normalization for the adjoint operation.
777781
fft_type = "backward" if fft_type == "forward" else "forward"
@@ -901,4 +905,4 @@ def healpix_ifft_cuda(
901905
"jax": healpix_ifft_jax,
902906
"cuda": healpix_ifft_cuda,
903907
"torch": healpix_ifft_torch,
904-
}
908+
}

0 commit comments

Comments
 (0)