@@ -587,6 +587,7 @@ def _healpix_fft_cuda_abstract(f, L, nside, reality, fft_type, norm, adjoint):
 
     Returns:
         Tuple of ShapedArray objects for output, workspace, and callback parameters.
+
     """
     # Step 1: Get lowering information (double precision, forward/backward, normalize).
     is_double, forward, normalize = _get_lowering_info(fft_type, norm, f.dtype)
@@ -632,7 +633,7 @@ def _healpix_fft_cuda_abstract(f, L, nside, reality, fft_type, norm, adjoint):
         shape=batch_shape + workspace_shape, dtype=workspace_dtype
     )
     params_eval = ShapedArray(shape=batch_shape + params_shape, dtype=np.int64)
-
+
     # Step 6: Return the ShapedArray objects.
     return (
         f.update(shape=out_shape, dtype=f.dtype),
@@ -666,6 +667,7 @@ def _healpix_fft_cuda_lowering(ctx, f, *, L, nside, reality, fft_type, norm, adj
 
     Returns:
         The result of the FFI call.
+
     """
     # Step 1: Check if CUDA support is compiled in.
     if not _s2fft.COMPILED_WITH_CUDA:
@@ -715,6 +717,7 @@ def _healpix_fft_cuda_batching_rule(
 
     Returns:
         Tuple of (output, output_batch_axes).
+
     """
     # Step 1: Unpack batched arguments and batching axes.
     (x,) = batched_args
@@ -772,6 +775,7 @@ def _healpix_fft_cuda_transpose(
 
     Returns:
         The adjoint of the input.
+
     """
    # Step 1: Invert the FFT type and normalization for the adjoint operation.
    fft_type = "backward" if fft_type == "forward" else "forward"
@@ -901,4 +905,4 @@ def healpix_ifft_cuda(
     "jax": healpix_ifft_jax,
     "cuda": healpix_ifft_cuda,
     "torch": healpix_ifft_torch,
-}
+}
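
For context, the rules touched in this diff (abstract evaluation, CUDA lowering, batching, and transpose) are the standard hooks of a custom JAX primitive. The sketch below is a minimal, self-contained illustration of that registration machinery using a toy linear primitive; the scale_by_two primitive and all rule names are invented for illustration and are not s2fft code, and on newer JAX releases the Primitive import may need to come from jax.extend.core instead of jax.core.

# Minimal sketch (assumed names, not s2fft code) of registering abstract-eval,
# batching, JVP, and transpose rules for a custom JAX primitive.
import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import ad, batching

scale_p = core.Primitive("scale_by_two")  # toy linear primitive: y = 2 * x


def _scale_impl(x):
    # Eager implementation used when the primitive is bound outside of jit.
    return 2.0 * x


def _scale_abstract(x):
    # Shape/dtype-only evaluation, analogous in role to _healpix_fft_cuda_abstract.
    return core.ShapedArray(x.shape, x.dtype)


def _scale_batching(batched_args, batch_dims):
    # Apply the primitive to the batched operand and pass the batch axis through,
    # analogous in role to _healpix_fft_cuda_batching_rule.
    (x,) = batched_args
    (bdim,) = batch_dims
    return scale_p.bind(x), bdim


def _scale_transpose(ct, x):
    # The op is linear, so its transpose applies the same op to the cotangent,
    # analogous in spirit to _healpix_fft_cuda_transpose flipping fft_type.
    return (scale_p.bind(ct),)


scale_p.def_impl(_scale_impl)
scale_p.def_abstract_eval(_scale_abstract)
batching.primitive_batchers[scale_p] = _scale_batching
ad.defjvp(scale_p, lambda t, x: scale_p.bind(t))  # linear op: JVP applies the op to the tangent
ad.primitive_transposes[scale_p] = _scale_transpose

x = jnp.arange(4.0)
print(jax.vmap(scale_p.bind)(x[None, :]))            # exercises the batching rule
print(jax.grad(lambda v: scale_p.bind(v).sum())(x))   # exercises the JVP and transpose rules

The vmap and grad calls at the end exercise the batching and transpose rules respectively, mirroring how JAX dispatches to the analogous rules registered for the HEALPix FFT CUDA primitive in this file.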