Commit 9e0f121

Update JAX binding layer
1 parent b5cbeac commit 9e0f121

File tree

1 file changed (+69, -23 lines)
s2fft/utils/healpix_ffts.py

Lines changed: 69 additions & 23 deletions
@@ -9,6 +9,7 @@
 
 # did not find promote_dtypes_complex outside _src
 from jax._src.numpy.util import promote_dtypes_complex
+from jax.interpreters import batching
 from s2fft_lib import _s2fft
 
 from s2fft.sampling import s2_samples as samples
@@ -692,23 +693,25 @@ def ring_phase_shifts_hp_jax(
 # Custom healpix_fft_cuda primitive
 
 
-def _healpix_fft_cuda_abstract(f, L, nside, reality, fft_type, norm):
+def _healpix_fft_cuda_abstract(f, L, nside, reality, fft_type, norm, adjoint):
     # For the forward pass, the input is a HEALPix pixel-space array of size nside^2 *
     # 12 and the output is a FTM array of shape (number of rings , width of FTM slice)
     # which is (4 * nside - 1 , 2 * L )
     healpix_size = (nside**2 * 12,)
     ftm_size = (4 * nside - 1, 2 * L)
     if fft_type == "forward":
-        assert f.shape == healpix_size
-        return f.update(shape=ftm_size, dtype=f.dtype)
+        batch_shape = (f.shape[0],) if f.ndim == 2 else ()
+        assert (f.shape[-1],) == healpix_size
+        return f.update(shape=batch_shape + ftm_size, dtype=f.dtype)
     elif fft_type == "backward":
-        assert f.shape == ftm_size
-        return f.update(shape=healpix_size, dtype=f.dtype)
+        batch_shape = (f.shape[0],) if f.ndim == 3 else ()
+        assert f.shape[-2:] == ftm_size
+        return f.update(shape=batch_shape + healpix_size, dtype=f.dtype)
     else:
         raise ValueError(f"fft_type {fft_type} not recognised.")
 
 
-def _healpix_fft_cuda_lowering(ctx, f, *, L, nside, reality, fft_type, norm):
+def _healpix_fft_cuda_lowering(ctx, f, *, L, nside, reality, fft_type, norm, adjoint):
     assert _s2fft.COMPILED_WITH_CUDA, """
     S2FFT was compiled without CUDA support. Cuda functions are not supported.
     Please make sure that nvcc is in your path and $CUDA_HOME is set then reinstall s2fft using pip.
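The abstract rule now tolerates an optional leading batch axis: an unbatched forward input of size 12 * nside**2 yields an FTM array of shape (4 * nside - 1, 2 * L), while a 2-D input keeps its leading dimension. A minimal shape-only sketch of that contract, assuming the public wrapper keeps the healpix_fft_cuda(f, L, nside, reality, norm) calling convention implied by the bind call later in this diff and that the compiled s2fft_lib extension is importable (jax.eval_shape only exercises abstract evaluation, so no GPU execution is needed):

import jax
import jax.numpy as jnp

from s2fft.utils.healpix_ffts import healpix_fft_cuda

L, nside = 8, 4  # illustrative parameters, not taken from this commit

# Unbatched: 12 * nside**2 pixels in, (4 * nside - 1, 2 * L) FTM slice out.
f = jnp.zeros(12 * nside**2, dtype=jnp.complex64)
ftm_shape = jax.eval_shape(lambda x: healpix_fft_cuda(x, L, nside, False, "backward"), f)
print(ftm_shape.shape)  # (15, 16) for these parameters

# Batched: a leading axis is carried through unchanged by the abstract rule.
f_batch = jnp.zeros((3, 12 * nside**2), dtype=jnp.complex64)
ftm_batch_shape = jax.eval_shape(lambda x: healpix_fft_cuda(x, L, nside, False, "backward"), f_batch)
print(ftm_batch_shape.shape)  # (3, 15, 16)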
@@ -748,27 +751,57 @@ def _healpix_fft_cuda_lowering(ctx, f, *, L, nside, reality, fft_type, norm):
         reality=reality,
         normalize=normalize,
         forward=forward,
+        adjoint=adjoint,
     )
 
 
+def _healpix_fft_cuda_batching_rule(
+    batched_args, batched_axis, L, nside, reality, fft_type, norm, adjoint
+):
+    (x,) = batched_args
+    (bd,) = batched_axis
+
+    if fft_type == "forward":
+        assert x.ndim == 2
+    elif fft_type == "backward":
+        assert x.ndim == 3
+    else:
+        raise ValueError(f"fft_type {fft_type} not recognised.")
+
+    x = batching.moveaxis(x, bd, 0)
+    return _healpix_fft_cuda_primitive.bind(
+        x,
+        L=L,
+        nside=nside,
+        reality=reality,
+        fft_type=fft_type,
+        norm=norm,
+        adjoint=adjoint,
+    ), 0
+
+
 def _healpix_fft_cuda_transpose(
-    df: jnp.ndarray, L: int, nside: int, reality: bool, fft_type: str, norm: str
+    df: jnp.ndarray,
+    L: int,
+    nside: int,
+    reality: bool,
+    fft_type: str,
+    norm: str,
+    adjoint: bool,
 ) -> jnp.ndarray:
-    scale_factors = (
-        jnp.concatenate((jnp.ones(L), 2 * jnp.ones(L * (L - 1) // 2)))
-        * (3 * nside**2)
-        / jnp.pi
+    fft_type = "backward" if fft_type == "forward" else "forward"
+    norm = "backward" if norm == "forward" else "forward"
+    return (
+        _healpix_fft_cuda_primitive.bind(
+            df,
+            L=L,
+            nside=nside,
+            reality=reality,
+            fft_type=fft_type,
+            norm=norm,
+            adjoint=not adjoint,
+        ),
     )
-    if fft_type == "forward":
-        return (
-            scale_factors
-            * jnp.conj(healpix_ifft_cuda(jnp.conj(df), L, nside, reality, norm)),
-        )
-    elif fft_type == "backward":
-        return (
-            scale_factors
-            * jnp.conj(healpix_fft_cuda(jnp.conj(df), L, nside, reality, norm)),
-        )
 
 
 # Register healpfix_fft_cuda custom call target
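The rewritten transpose rule no longer builds the explicit scale-factor expression: it rebinds the same primitive with fft_type and norm swapped and the new adjoint flag toggled. Since the primitive is registered as linear (is_linear=True in the next hunk), this is the rule reverse-mode autodiff falls back on. A hedged sketch of what that enables, assuming a CUDA-enabled build of s2fft and the same wrapper signature assumed above:

import jax
import jax.numpy as jnp

from s2fft.utils.healpix_ffts import healpix_fft_cuda

L, nside = 8, 4  # illustrative parameters

f = jnp.zeros(12 * nside**2, dtype=jnp.complex64)

# For a linear primitive, jax.vjp reduces to the transpose rule above:
# one extra pass through the same CUDA kernel with fft_type/norm flipped
# and adjoint=True, instead of the old conjugate-and-rescale formula.
ftm, vjp_fn = jax.vjp(lambda x: healpix_fft_cuda(x, L, nside, False, "backward"), f)
cotangent = jnp.ones_like(ftm)   # illustrative cotangent, shape (15, 16)
(f_bar,) = vjp_fn(cotangent)     # shape (12 * nside**2,)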
@@ -781,6 +814,7 @@ def _healpix_fft_cuda_transpose(
     abstract_evaluation=_healpix_fft_cuda_abstract,
     lowering_per_platform={None: _healpix_fft_cuda_lowering},
     transpose=_healpix_fft_cuda_transpose,
+    batcher=_healpix_fft_cuda_batching_rule,
     is_linear=True,
 )
 
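Registering _healpix_fft_cuda_batching_rule as the primitive's batcher is what allows jax.vmap to map the transform over a stack of HEALPix maps instead of failing with a missing batching rule. A small usage sketch under the same assumptions as the examples above (CUDA-enabled build, assumed wrapper signature):

import jax
import jax.numpy as jnp

from s2fft.utils.healpix_ffts import healpix_fft_cuda

L, nside, batch = 8, 4, 16  # illustrative parameters

f_stack = jnp.zeros((batch, 12 * nside**2), dtype=jnp.complex64)

# vmap moves the mapped axis to position 0 (batching.moveaxis in the rule
# above) and binds the primitive once on the whole batch.
ftm_stack = jax.vmap(lambda x: healpix_fft_cuda(x, L, nside, False, "backward"))(f_stack)
# ftm_stack.shape == (batch, 4 * nside - 1, 2 * L) == (16, 15, 16)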

@@ -811,7 +845,13 @@ def healpix_fft_cuda(
     """
     (f,) = promote_dtypes_complex(f)
     return _healpix_fft_cuda_primitive.bind(
-        f, L=L, nside=nside, reality=reality, fft_type="forward", norm=norm
+        f,
+        L=L,
+        nside=nside,
+        reality=reality,
+        fft_type="forward",
+        norm=norm,
+        adjoint=False,
     )
 
 
@@ -841,5 +881,11 @@ def healpix_ifft_cuda(
     """
     (ftm,) = promote_dtypes_complex(ftm)
     return _healpix_fft_cuda_primitive.bind(
-        ftm, L=L, nside=nside, reality=reality, fft_type="backward", norm=norm
+        ftm,
+        L=L,
+        nside=nside,
+        reality=reality,
+        fft_type="backward",
+        norm=norm,
+        adjoint=False,
     )
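Both public wrappers now pass adjoint=False explicitly, so the flag only changes behaviour when the transpose rule flips it. A round-trip sketch of the two entry points, checking shapes only, under the same signature and CUDA-build assumptions as above:

import jax.numpy as jnp

from s2fft.utils.healpix_ffts import healpix_fft_cuda, healpix_ifft_cuda

L, nside = 8, 4  # illustrative parameters

f = jnp.ones(12 * nside**2, dtype=jnp.complex64)

ftm = healpix_fft_cuda(f, L, nside, False, "backward")       # shape (4 * nside - 1, 2 * L)
f_rec = healpix_ifft_cuda(ftm, L, nside, False, "backward")  # shape (12 * nside**2,)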
