Skip to content

Commit 8fe86c2

Browse files
committed
Update healpix_ffts to use new FFI lowered cuda healpix ffts
1 parent 2b591ca commit 8fe86c2

File tree

3 files changed

+80
-27
lines changed

3 files changed

+80
-27
lines changed

s2fft/transforms/spherical.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,25 @@ def inverse(
8282
recover acceleration by the number of devices.
8383
8484
"""
85-
if spin >= 8 and method in ["numpy", "jax"]:
85+
if spin >= 8 and method in ["numpy", "jax", "cuda"]:
8686
raise Warning("Recursive transform may provide lower precision beyond spin ~ 8")
8787

8888
if method == "numpy":
8989
return inverse_numpy(flm, L, spin, nside, sampling, reality, precomps, L_lower)
90-
elif method == "jax":
90+
elif method in ["jax", "cuda"]:
91+
use_healpix_custom_primitive = method == "cuda"
92+
method = "jax"
9193
return inverse_jax(
92-
flm, L, spin, nside, sampling, reality, precomps, spmd, L_lower
94+
flm,
95+
L,
96+
spin,
97+
nside,
98+
sampling,
99+
reality,
100+
precomps,
101+
spmd,
102+
L_lower,
103+
use_healpix_custom_primitive,
93104
)
94105
elif method == "jax_ssht":
95106
if sampling.lower() == "healpix":
@@ -205,7 +216,7 @@ def inverse_numpy(
205216
return np.fft.ifft(np.fft.ifftshift(ftm, axes=1), axis=1, norm="forward")
206217

207218

208-
@partial(jit, static_argnums=(1, 3, 4, 5, 7, 8))
219+
@partial(jit, static_argnums=(1, 3, 4, 5, 7, 8, 9))
209220
def inverse_jax(
210221
flm: jnp.ndarray,
211222
L: int,
@@ -216,6 +227,7 @@ def inverse_jax(
216227
precomps: List = None,
217228
spmd: bool = False,
218229
L_lower: int = 0,
230+
use_healpix_custom_primitive: bool = False,
219231
) -> jnp.ndarray:
220232
r"""
221233
Compute the inverse spin-spherical harmonic transform (JAX).
@@ -251,6 +263,12 @@ def inverse_jax(
251263
L_lower (int, optional): Harmonic lower-bound. Transform will only be computed
252264
for :math:`\texttt{L_lower} \leq \ell < \texttt{L}`. Defaults to 0.
253265
266+
use_healpix_custom_primitive (bool, optional): Whether to use a custom CUDA
267+
primitive for computing HEALPix fast fourier transform when `sampling =
268+
"healpix"` and running on a cuda compatible gpu device. using a custom
269+
primitive reduces long compilation times when jit compiling. defaults to
270+
`False`.
271+
254272
Returns:
255273
jnp.ndarray: Signal on the sphere.
256274
@@ -326,7 +344,10 @@ def f_bwd(res, gtm):
326344
jnp.flip(jnp.conj(ftm[:, L - 1 + m_offset + 1 :]), axis=-1)
327345
)
328346
if sampling.lower() == "healpix":
329-
return hp.healpix_ifft(ftm, L, nside, "jax")
347+
if use_healpix_custom_primitive:
348+
return hp.healpix_ifft(ftm, L, nside, "cuda")
349+
else:
350+
return hp.healpix_ifft(ftm, L, nside, "jax")
330351
else:
331352
ftm = jnp.conj(jnp.fft.ifftshift(ftm, axes=1))
332353
f = jnp.conj(jnp.fft.fft(ftm, axis=1, norm="backward"))
@@ -406,7 +427,7 @@ def forward(
406427
recover acceleration by the number of devices.
407428
408429
"""
409-
if spin >= 8 and method in ["numpy", "jax"]:
430+
if spin >= 8 and method in ["numpy", "jax", "cuda"]:
410431
raise Warning("Recursive transform may provide lower precision beyond spin ~ 8")
411432

412433
if iter is None:

s2fft/transforms/wigner.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,20 @@ def inverse(
8686

8787
if method == "numpy":
8888
return inverse_numpy(flmn, L, N, nside, sampling, reality, precomps, L_lower)
89-
elif method == "jax":
90-
return inverse_jax(flmn, L, N, nside, sampling, reality, precomps, L_lower)
89+
elif method in ["jax", "cuda"]:
90+
use_healpix_custom_primitive = method == "cuda"
91+
method = "jax"
92+
return inverse_jax(
93+
flmn,
94+
L,
95+
N,
96+
nside,
97+
sampling,
98+
reality,
99+
precomps,
100+
L_lower,
101+
use_healpix_custom_primitive,
102+
)
91103
elif method == "jax_ssht":
92104
if sampling.lower() == "healpix":
93105
raise ValueError("SSHT does not support healpix sampling.")

s2fft/utils/healpix_ffts.py

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from functools import partial
22

3+
import jax
34
import jax.numpy as jnp
45
import jaxlib.mlir.ir as ir
56
import numpy as np
@@ -8,8 +9,6 @@
89

910
# did not find promote_dtypes_complex outside _src
1011
from jax._src.numpy.util import promote_dtypes_complex
11-
from jax.lib import xla_client
12-
from jaxlib.hlo_helpers import custom_call
1312
from s2fft_lib import _s2fft
1413

1514
from s2fft.sampling import s2_samples as samples
@@ -703,16 +702,18 @@ def _healpix_fft_cuda_abstract(f, L, nside, reality, fft_type, norm):
703702
assert f.shape == healpix_size
704703
return f.update(shape=ftm_size, dtype=f.dtype)
705704
elif fft_type == "backward":
706-
print(f"f.shape {f.shape}")
707705
assert f.shape == ftm_size
708706
return f.update(shape=healpix_size, dtype=f.dtype)
709707
else:
710708
raise ValueError(f"fft_type {fft_type} not recognised.")
711709

712710

713711
def _healpix_fft_cuda_lowering(ctx, f, *, L, nside, reality, fft_type, norm):
712+
assert _s2fft.COMPILED_WITH_CUDA, """
713+
S2FFT was compiled without CUDA support. Cuda functions are not supported.
714+
Please make sure that nvcc is in your path and $CUDA_HOME is set then reinstall s2fft using pip.
715+
"""
714716
(aval_out,) = ctx.avals_out
715-
a_type = ir.RankedTensorType(f.type)
716717

717718
out_dtype = aval_out.dtype
718719
if out_dtype == np.complex64:
@@ -734,34 +735,53 @@ def _healpix_fft_cuda_lowering(ctx, f, *, L, nside, reality, fft_type, norm):
734735
else:
735736
raise ValueError(f"Unknown norm {norm}")
736737

737-
descriptor = _s2fft.build_healpix_fft_descriptor(
738-
nside, L, reality, forward, normalize, is_double
738+
if is_double:
739+
ffi_lowered = jax.ffi.ffi_lowering("healpix_fft_cuda_c128")
740+
else:
741+
ffi_lowered = jax.ffi.ffi_lowering("healpix_fft_cuda_c64")
742+
743+
return ffi_lowered(
744+
ctx,
745+
f,
746+
nside=nside,
747+
harmonic_band_limit=L,
748+
reality=reality,
749+
normalize=normalize,
750+
forward=forward,
739751
)
740752

741-
layout = tuple(range(len(a_type.shape) - 1, -1, -1))
742-
out_layout = tuple(range(len(out_type.shape) - 1, -1, -1))
743-
744-
result = custom_call(
745-
"healpix_fft_cuda",
746-
result_types=[out_type],
747-
operands=[f],
748-
operand_layouts=[layout],
749-
result_layouts=[out_layout],
750-
has_side_effect=True,
751-
backend_config=descriptor,
753+
754+
def _healpix_fft_cuda_transpose(
755+
df: jnp.ndarray, L: int, nside: int, reality: bool, fft_type: str, norm: str
756+
) -> jnp.ndarray:
757+
scale_factors = (
758+
jnp.concatenate((jnp.ones(L), 2 * jnp.ones(L * (L - 1) // 2)))
759+
* (3 * nside**2)
760+
/ jnp.pi
752761
)
753-
return result.results
762+
if fft_type == "forward":
763+
return (
764+
scale_factors
765+
* jnp.conj(healpix_ifft_cuda(jnp.conj(df), L, nside, reality, norm)),
766+
)
767+
elif fft_type == "backward":
768+
return (
769+
scale_factors
770+
* jnp.conj(healpix_fft_cuda(jnp.conj(df), L, nside, reality, norm)),
771+
)
754772

755773

756774
# Register healpfix_fft_cuda custom call target
757775
for name, fn in _s2fft.registration().items():
758-
xla_client.register_custom_call_target(name, fn, platform="gpu")
776+
jax.ffi.register_ffi_target(name, fn, platform="CUDA")
759777

760778
_healpix_fft_cuda_primitive = register_primitive(
761779
"healpix_fft_cuda",
762780
multiple_results=False,
763781
abstract_evaluation=_healpix_fft_cuda_abstract,
764782
lowering_per_platform={None: _healpix_fft_cuda_lowering},
783+
transpose=_healpix_fft_cuda_transpose,
784+
is_linear=True,
765785
)
766786

767787

0 commit comments

Comments
 (0)