Skip to content

Commit 92fe6a0

Browse files
committed
add vmap jacrev and jacfwd tests
1 parent 9e0f121 commit 92fe6a0

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed

tests/test_healpix_ffts.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import healpy as hp
22
import jax
3+
import jax.numpy as jnp
34
import numpy as np
45
import pytest
56
from numpy.testing import assert_allclose
@@ -92,3 +93,94 @@ def test_healpix_ifft_cuda(flm_generator, nside):
9293
atol=1e-7,
9394
rtol=1e-7,
9495
)
96+
97+
98+
@pytest.mark.skipif(not gpu_available, reason="GPU not available")
99+
@pytest.mark.parametrize("nside", nside_to_test)
100+
def test_healpix_fft_cuda_transforms(flm_generator, nside):
101+
L = 2 * nside
102+
103+
# Generate a random bandlimited signal
104+
def generate_flm():
105+
flm = flm_generator(L=L, reality=False)
106+
flm_hp = samples.flm_2d_to_hp(flm, L)
107+
f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
108+
return f
109+
110+
f_stacked = jnp.stack([generate_flm() for _ in range(10)], axis=0)
111+
112+
def healpix_jax(f):
113+
return healpix_fft_jax(f, L, nside, False).real
114+
115+
def healpix_cuda(f):
116+
return healpix_fft_cuda(f, L, nside, False).real
117+
118+
f = f_stacked[0]
119+
# Test VMAP
120+
assert_allclose(
121+
jax.vmap(healpix_jax)(f_stacked),
122+
jax.vmap(healpix_cuda)(f_stacked),
123+
atol=1e-7,
124+
rtol=1e-7,
125+
)
126+
# test jacfwd
127+
assert_allclose(
128+
jax.jacfwd(healpix_jax)(f),
129+
jax.jacfwd(healpix_cuda)(f),
130+
atol=1e-7,
131+
rtol=1e-7,
132+
)
133+
# test jacrev
134+
assert_allclose(
135+
jax.jacrev(healpix_jax)(f),
136+
jax.jacrev(healpix_cuda)(f),
137+
atol=1e-7,
138+
rtol=1e-7,
139+
)
140+
141+
142+
@pytest.mark.skipif(not gpu_available, reason="GPU not available")
143+
@pytest.mark.parametrize("nside", nside_to_test)
144+
def test_healpix_ifft_cuda_transforms(flm_generator, nside):
145+
L = 2 * nside
146+
147+
# Generate a random bandlimited signal
148+
def generate_flm():
149+
flm = flm_generator(L=L, reality=False)
150+
flm_hp = samples.flm_2d_to_hp(flm, L)
151+
f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
152+
ftm = healpix_fft_jax(f, L, nside, False)
153+
return ftm
154+
155+
ftm_stacked = jnp.stack([generate_flm() for _ in range(10)], axis=0)
156+
ftm = ftm_stacked[0].real
157+
158+
def healpix_inv_jax(f):
159+
return healpix_ifft_jax(f, L, nside, False).real
160+
161+
def healpix_inv_cuda(f):
162+
return healpix_ifft_cuda(f, L, nside, False).real
163+
164+
# Test VMAP
165+
assert_allclose(
166+
jax.vmap(healpix_inv_jax)(ftm_stacked).flatten(),
167+
jax.vmap(healpix_inv_jax)(ftm_stacked).flatten(),
168+
atol=1e-7,
169+
rtol=1e-7,
170+
)
171+
172+
# test jacfwd
173+
assert_allclose(
174+
jax.jacfwd(healpix_inv_jax)(ftm).flatten(),
175+
jax.jacfwd(healpix_inv_cuda)(ftm).flatten(),
176+
atol=1e-7,
177+
rtol=1e-7,
178+
)
179+
180+
# test jacrev
181+
assert_allclose(
182+
jax.jacrev(healpix_inv_jax)(ftm).flatten(),
183+
jax.jacrev(healpix_inv_cuda)(ftm).flatten(),
184+
atol=1e-7,
185+
rtol=1e-7,
186+
)

0 commit comments

Comments
 (0)