 
 # did not find promote_dtypes_complex outside _src
 from jax._src.numpy.util import promote_dtypes_complex
+from jax.interpreters import batching
 from s2fft_lib import _s2fft
 
 from s2fft.sampling import s2_samples as samples
@@ -692,23 +693,25 @@ def ring_phase_shifts_hp_jax(
 # Custom healpix_fft_cuda primitive
 
 
-def _healpix_fft_cuda_abstract(f, L, nside, reality, fft_type, norm):
+def _healpix_fft_cuda_abstract(f, L, nside, reality, fft_type, norm, adjoint):
     # For the forward pass, the input is a HEALPix pixel-space array of size nside^2 *
     # 12 and the output is a FTM array of shape (number of rings, width of FTM slice)
     # which is (4 * nside - 1, 2 * L)
     healpix_size = (nside**2 * 12,)
     ftm_size = (4 * nside - 1, 2 * L)
     if fft_type == "forward":
-        assert f.shape == healpix_size
-        return f.update(shape=ftm_size, dtype=f.dtype)
+        batch_shape = (f.shape[0],) if f.ndim == 2 else ()
+        assert (f.shape[-1],) == healpix_size
+        return f.update(shape=batch_shape + ftm_size, dtype=f.dtype)
     elif fft_type == "backward":
-        assert f.shape == ftm_size
-        return f.update(shape=healpix_size, dtype=f.dtype)
+        batch_shape = (f.shape[0],) if f.ndim == 3 else ()
+        assert f.shape[-2:] == ftm_size
+        return f.update(shape=batch_shape + healpix_size, dtype=f.dtype)
     else:
         raise ValueError(f"fft_type {fft_type} not recognised.")
 
 
-def _healpix_fft_cuda_lowering(ctx, f, *, L, nside, reality, fft_type, norm):
+def _healpix_fft_cuda_lowering(ctx, f, *, L, nside, reality, fft_type, norm, adjoint):
     assert _s2fft.COMPILED_WITH_CUDA, """
     S2FFT was compiled without CUDA support. Cuda functions are not supported.
     Please make sure that nvcc is in your path and $CUDA_HOME is set then reinstall s2fft using pip.
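
Note (not from the diff): a minimal sketch of the shape contract that the abstract evaluation above enforces. The nside and L values are hypothetical, and jax.eval_shape only runs abstract evaluation, so no CUDA build is needed to check shapes; healpix_fft_cuda is the public wrapper defined later in this file.

    import jax
    import jax.numpy as jnp

    nside, L = 4, 8  # hypothetical example values
    f = jnp.zeros(12 * nside**2, dtype=jnp.complex64)  # HEALPix pixel-space map

    # forward: (12 * nside^2,) -> (4 * nside - 1, 2 * L), here (192,) -> (15, 16)
    ftm_spec = jax.eval_shape(lambda x: healpix_fft_cuda(x, L, nside, False, "backward"), f)
    assert ftm_spec.shape == (4 * nside - 1, 2 * L)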
@@ -748,27 +751,57 @@ def _healpix_fft_cuda_lowering(ctx, f, *, L, nside, reality, fft_type, norm):
         reality=reality,
         normalize=normalize,
         forward=forward,
+        adjoint=adjoint,
     )
 
 
+def _healpix_fft_cuda_batching_rule(
+    batched_args, batched_axis, L, nside, reality, fft_type, norm, adjoint
+):
+    (x,) = batched_args
+    (bd,) = batched_axis
+
+    if fft_type == "forward":
+        assert x.ndim == 2
+    elif fft_type == "backward":
+        assert x.ndim == 3
+    else:
+        raise ValueError(f"fft_type {fft_type} not recognised.")
+
+    x = batching.moveaxis(x, bd, 0)
+    return _healpix_fft_cuda_primitive.bind(
+        x,
+        L=L,
+        nside=nside,
+        reality=reality,
+        fft_type=fft_type,
+        norm=norm,
+        adjoint=adjoint,
+    ), 0
+
+
 def _healpix_fft_cuda_transpose(
-    df: jnp.ndarray, L: int, nside: int, reality: bool, fft_type: str, norm: str
+    df: jnp.ndarray,
+    L: int,
+    nside: int,
+    reality: bool,
+    fft_type: str,
+    norm: str,
+    adjoint: bool,
 ) -> jnp.ndarray:
-    scale_factors = (
-        jnp.concatenate((jnp.ones(L), 2 * jnp.ones(L * (L - 1) // 2)))
-        * (3 * nside**2)
-        / jnp.pi
+    fft_type = "backward" if fft_type == "forward" else "forward"
+    norm = "backward" if norm == "forward" else "forward"
+    return (
+        _healpix_fft_cuda_primitive.bind(
+            df,
+            L=L,
+            nside=nside,
+            reality=reality,
+            fft_type=fft_type,
+            norm=norm,
+            adjoint=not adjoint,
+        ),
     )
-    if fft_type == "forward":
-        return (
-            scale_factors
-            * jnp.conj(healpix_ifft_cuda(jnp.conj(df), L, nside, reality, norm)),
-        )
-    elif fft_type == "backward":
-        return (
-            scale_factors
-            * jnp.conj(healpix_fft_cuda(jnp.conj(df), L, nside, reality, norm)),
-        )
 
 
 # Register healpfix_fft_cuda custom call target
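
Note (not from the diff): once the batching rule above is registered as the primitive's batcher (see the registration hunk below), the CUDA transform can be mapped over a leading batch axis. A sketch under hypothetical nside/L values; actually executing it requires a CUDA build of s2fft.

    nside, L = 4, 8  # hypothetical example values
    maps = jnp.zeros((8, 12 * nside**2), dtype=jnp.complex64)  # stack of 8 HEALPix maps

    # vmap dispatches to _healpix_fft_cuda_batching_rule, which moves the mapped axis
    # to the front and rebinds the primitive once on the stacked input
    ftm_batch = jax.vmap(lambda x: healpix_fft_cuda(x, L, nside, False, "backward"))(maps)
    # ftm_batch.shape == (8, 4 * nside - 1, 2 * L)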
@@ -781,6 +814,7 @@ def _healpix_fft_cuda_transpose(
     abstract_evaluation=_healpix_fft_cuda_abstract,
     lowering_per_platform={None: _healpix_fft_cuda_lowering},
     transpose=_healpix_fft_cuda_transpose,
+    batcher=_healpix_fft_cuda_batching_rule,
     is_linear=True,
 )
 
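
Note (not from the diff): because the primitive is registered as linear and its transpose rule now rebinds the primitive with fft_type and norm swapped and adjoint flipped (rather than the previous explicit conjugate-and-scale formulation), reverse-mode differentiation goes back through the same CUDA kernel. A sketch with hypothetical values; running it needs a CUDA build.

    nside, L = 4, 8  # hypothetical example values
    f = jnp.zeros(12 * nside**2, dtype=jnp.complex64)

    ftm, vjp_fn = jax.vjp(lambda x: healpix_fft_cuda(x, L, nside, False, "backward"), f)
    (df,) = vjp_fn(jnp.ones_like(ftm))  # cotangent pulled back via _healpix_fft_cuda_transpose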
@@ -811,7 +845,13 @@ def healpix_fft_cuda(
     """
     (f,) = promote_dtypes_complex(f)
     return _healpix_fft_cuda_primitive.bind(
-        f, L=L, nside=nside, reality=reality, fft_type="forward", norm=norm
+        f,
+        L=L,
+        nside=nside,
+        reality=reality,
+        fft_type="forward",
+        norm=norm,
+        adjoint=False,
     )
 
 
@@ -841,5 +881,11 @@ def healpix_ifft_cuda(
     """
     (ftm,) = promote_dtypes_complex(ftm)
     return _healpix_fft_cuda_primitive.bind(
-        ftm, L=L, nside=nside, reality=reality, fft_type="backward", norm=norm
+        ftm,
+        L=L,
+        nside=nside,
+        reality=reality,
+        fft_type="backward",
+        norm=norm,
+        adjoint=False,
     )
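
Note (not from the diff): both public wrappers now bind the primitive with adjoint=False; adjoint=True only appears when the transpose rule flips it. A minimal call sketch with hypothetical values (CUDA build required).

    nside, L = 4, 8  # hypothetical example values
    f = jnp.zeros(12 * nside**2, dtype=jnp.complex64)

    ftm = healpix_fft_cuda(f, L, nside, False, "backward")   # pixel space -> FTM, shape (15, 16)
    f2 = healpix_ifft_cuda(ftm, L, nside, False, "forward")  # FTM -> pixel space, shape (192,)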