66from numpy .testing import assert_allclose
77from packaging .version import Version as _Version
88
9+ import s2fft
910from s2fft .sampling import s2_samples as samples
1011from s2fft .utils .healpix_ffts import (
1112 healpix_fft_cuda ,
@@ -103,8 +104,9 @@ def test_healpix_fft_cuda_transforms(flm_generator, nside):
103104 # Generate a random bandlimited signal
104105 def generate_flm ():
105106 flm = flm_generator (L = L , reality = False )
106- flm_hp = samples .flm_2d_to_hp (flm , L )
107- f = hp .sphtfunc .alm2map (flm_hp , nside , lmax = L - 1 )
107+ f = s2fft .inverse (
108+ flm , L = L , nside = nside , reality = False , method = "jax" , sampling = "healpix"
109+ )
108110 return f
109111
110112 f_stacked = jnp .stack ([generate_flm () for _ in range (10 )], axis = 0 )
@@ -125,15 +127,15 @@ def healpix_cuda(f):
125127 )
126128 # test jacfwd
127129 assert_allclose (
128- jax .jacfwd (healpix_jax )(f ),
129- jax .jacfwd (healpix_cuda )(f ),
130+ jax .jacfwd (healpix_jax )(f . real ),
131+ jax .jacfwd (healpix_cuda )(f . real ),
130132 atol = 1e-7 ,
131133 rtol = 1e-7 ,
132134 )
133135 # test jacrev
134136 assert_allclose (
135- jax .jacrev (healpix_jax )(f ),
136- jax .jacrev (healpix_cuda )(f ),
137+ jax .jacrev (healpix_jax )(f . real ),
138+ jax .jacrev (healpix_cuda )(f . real ),
137139 atol = 1e-7 ,
138140 rtol = 1e-7 ,
139141 )
@@ -147,8 +149,9 @@ def test_healpix_ifft_cuda_transforms(flm_generator, nside):
147149 # Generate a random bandlimited signal
148150 def generate_flm ():
149151 flm = flm_generator (L = L , reality = False )
150- flm_hp = samples .flm_2d_to_hp (flm , L )
151- f = hp .sphtfunc .alm2map (flm_hp , nside , lmax = L - 1 )
152+ f = s2fft .inverse (
153+ flm , L = L , nside = nside , reality = False , method = "jax" , sampling = "healpix"
154+ )
152155 ftm = healpix_fft_jax (f , L , nside , False )
153156 return ftm
154157
@@ -164,23 +167,23 @@ def healpix_inv_cuda(f):
164167 # Test VMAP
165168 assert_allclose (
166169 jax .vmap (healpix_inv_jax )(ftm_stacked ).flatten (),
167- jax .vmap (healpix_inv_jax )(ftm_stacked ).flatten (),
170+ jax .vmap (healpix_inv_cuda )(ftm_stacked ).flatten (),
168171 atol = 1e-7 ,
169172 rtol = 1e-7 ,
170173 )
171174
172175 # test jacfwd
173176 assert_allclose (
174- jax .jacfwd (healpix_inv_jax )(ftm ).flatten (),
175- jax .jacfwd (healpix_inv_cuda )(ftm ).flatten (),
177+ jax .jacfwd (healpix_inv_jax )(ftm . real ).flatten (),
178+ jax .jacfwd (healpix_inv_cuda )(ftm . real ).flatten (),
176179 atol = 1e-7 ,
177180 rtol = 1e-7 ,
178181 )
179182
180183 # test jacrev
181184 assert_allclose (
182- jax .jacrev (healpix_inv_jax )(ftm ).flatten (),
183- jax .jacrev (healpix_inv_cuda )(ftm ).flatten (),
185+ jax .jacrev (healpix_inv_jax )(ftm . real ).flatten (),
186+ jax .jacrev (healpix_inv_cuda )(ftm . real ).flatten (),
184187 atol = 1e-7 ,
185188 rtol = 1e-7 ,
186189 )
0 commit comments