Skip to content

Commit 0e03787

Browse files
committed
Implement requested changes
1 parent a70b262 commit 0e03787

File tree

2 files changed

+27
-17
lines changed

2 files changed

+27
-17
lines changed

s2fft/utils/healpix_ffts.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -711,11 +711,18 @@ def _healpix_fft_cuda_abstract(f, L, nside, reality, fft_type, norm, adjoint):
711711
raise ValueError(f"fft_type {fft_type} not recognised.")
712712

713713

714+
class MissingCUDASupport(Exception): # noqa : D107
715+
def __init__(self): # noqa : D107
716+
super().__init__("""
717+
S2FFT was compiled without CUDA support. Cuda functions are not supported.
718+
Please make sure that nvcc is in your path and $CUDA_HOME is set then reinstall s2fft using pip.
719+
""")
720+
721+
714722
def _healpix_fft_cuda_lowering(ctx, f, *, L, nside, reality, fft_type, norm, adjoint):
715-
assert _s2fft.COMPILED_WITH_CUDA, """
716-
S2FFT was compiled without CUDA support. Cuda functions are not supported.
717-
Please make sure that nvcc is in your path and $CUDA_HOME is set then reinstall s2fft using pip.
718-
"""
723+
if not _s2fft.COMPILED_WITH_CUDA:
724+
raise MissingCUDASupport()
725+
719726
(aval_out,) = ctx.avals_out
720727

721728
out_dtype = aval_out.dtype

tests/test_healpix_ffts.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from numpy.testing import assert_allclose
77
from packaging.version import Version as _Version
88

9+
import s2fft
910
from s2fft.sampling import s2_samples as samples
1011
from s2fft.utils.healpix_ffts import (
1112
healpix_fft_cuda,
@@ -103,8 +104,9 @@ def test_healpix_fft_cuda_transforms(flm_generator, nside):
103104
# Generate a random bandlimited signal
104105
def generate_flm():
105106
flm = flm_generator(L=L, reality=False)
106-
flm_hp = samples.flm_2d_to_hp(flm, L)
107-
f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
107+
f = s2fft.inverse(
108+
flm, L=L, nside=nside, reality=False, method="jax", sampling="healpix"
109+
)
108110
return f
109111

110112
f_stacked = jnp.stack([generate_flm() for _ in range(10)], axis=0)
@@ -125,15 +127,15 @@ def healpix_cuda(f):
125127
)
126128
# test jacfwd
127129
assert_allclose(
128-
jax.jacfwd(healpix_jax)(f),
129-
jax.jacfwd(healpix_cuda)(f),
130+
jax.jacfwd(healpix_jax)(f.real),
131+
jax.jacfwd(healpix_cuda)(f.real),
130132
atol=1e-7,
131133
rtol=1e-7,
132134
)
133135
# test jacrev
134136
assert_allclose(
135-
jax.jacrev(healpix_jax)(f),
136-
jax.jacrev(healpix_cuda)(f),
137+
jax.jacrev(healpix_jax)(f.real),
138+
jax.jacrev(healpix_cuda)(f.real),
137139
atol=1e-7,
138140
rtol=1e-7,
139141
)
@@ -147,8 +149,9 @@ def test_healpix_ifft_cuda_transforms(flm_generator, nside):
147149
# Generate a random bandlimited signal
148150
def generate_flm():
149151
flm = flm_generator(L=L, reality=False)
150-
flm_hp = samples.flm_2d_to_hp(flm, L)
151-
f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
152+
f = s2fft.inverse(
153+
flm, L=L, nside=nside, reality=False, method="jax", sampling="healpix"
154+
)
152155
ftm = healpix_fft_jax(f, L, nside, False)
153156
return ftm
154157

@@ -164,23 +167,23 @@ def healpix_inv_cuda(f):
164167
# Test VMAP
165168
assert_allclose(
166169
jax.vmap(healpix_inv_jax)(ftm_stacked).flatten(),
167-
jax.vmap(healpix_inv_jax)(ftm_stacked).flatten(),
170+
jax.vmap(healpix_inv_cuda)(ftm_stacked).flatten(),
168171
atol=1e-7,
169172
rtol=1e-7,
170173
)
171174

172175
# test jacfwd
173176
assert_allclose(
174-
jax.jacfwd(healpix_inv_jax)(ftm).flatten(),
175-
jax.jacfwd(healpix_inv_cuda)(ftm).flatten(),
177+
jax.jacfwd(healpix_inv_jax)(ftm.real).flatten(),
178+
jax.jacfwd(healpix_inv_cuda)(ftm.real).flatten(),
176179
atol=1e-7,
177180
rtol=1e-7,
178181
)
179182

180183
# test jacrev
181184
assert_allclose(
182-
jax.jacrev(healpix_inv_jax)(ftm).flatten(),
183-
jax.jacrev(healpix_inv_cuda)(ftm).flatten(),
185+
jax.jacrev(healpix_inv_jax)(ftm.real).flatten(),
186+
jax.jacrev(healpix_inv_cuda)(ftm.real).flatten(),
184187
atol=1e-7,
185188
rtol=1e-7,
186189
)

0 commit comments

Comments
 (0)