from functools import partial

+import jax
import jax.numpy as jnp
import jaxlib.mlir.ir as ir
import numpy as np

# did not find promote_dtypes_complex outside _src
from jax._src.numpy.util import promote_dtypes_complex
-from jax.lib import xla_client
-from jaxlib.hlo_helpers import custom_call
from s2fft_lib import _s2fft

from s2fft.sampling import s2_samples as samples
@@ -703,16 +702,18 @@ def _healpix_fft_cuda_abstract(f, L, nside, reality, fft_type, norm):
        assert f.shape == healpix_size
        return f.update(shape=ftm_size, dtype=f.dtype)
    elif fft_type == "backward":
-        print(f"f.shape {f.shape}")
        assert f.shape == ftm_size
        return f.update(shape=healpix_size, dtype=f.dtype)
    else:
        raise ValueError(f"fft_type {fft_type} not recognised.")


def _healpix_fft_cuda_lowering(ctx, f, *, L, nside, reality, fft_type, norm):
+    assert _s2fft.COMPILED_WITH_CUDA, """
+    S2FFT was compiled without CUDA support, so CUDA functions are not available.
+    Please make sure that nvcc is in your path and $CUDA_HOME is set, then reinstall s2fft using pip.
+    """
    (aval_out,) = ctx.avals_out
-    a_type = ir.RankedTensorType(f.type)

    out_dtype = aval_out.dtype
    if out_dtype == np.complex64:
@@ -734,34 +735,53 @@ def _healpix_fft_cuda_lowering(ctx, f, *, L, nside, reality, fft_type, norm):
    else:
        raise ValueError(f"Unknown norm {norm}")

-    descriptor = _s2fft.build_healpix_fft_descriptor(
-        nside, L, reality, forward, normalize, is_double
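+    # Select the FFI target registered for the operand precision; the keyword
+    # arguments below are forwarded to the CUDA handler as FFI call attributes.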
+    if is_double:
+        ffi_lowered = jax.ffi.ffi_lowering("healpix_fft_cuda_c128")
+    else:
+        ffi_lowered = jax.ffi.ffi_lowering("healpix_fft_cuda_c64")
+
+    return ffi_lowered(
+        ctx,
+        f,
+        nside=nside,
+        harmonic_band_limit=L,
+        reality=reality,
+        normalize=normalize,
+        forward=forward,
    )

-    layout = tuple(range(len(a_type.shape) - 1, -1, -1))
-    out_layout = tuple(range(len(out_type.shape) - 1, -1, -1))
-
-    result = custom_call(
-        "healpix_fft_cuda",
-        result_types=[out_type],
-        operands=[f],
-        operand_layouts=[layout],
-        result_layouts=[out_layout],
-        has_side_effect=True,
-        backend_config=descriptor,
+
+def _healpix_fft_cuda_transpose(
+    df: jnp.ndarray, L: int, nside: int, reality: bool, fft_type: str, norm: str
+) -> jnp.ndarray:
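+    # Transpose rule for the linear primitive: conjugate the cotangent, apply the
+    # opposite-direction transform, conjugate back, and rescale. The factor
+    # 3 * nside**2 / pi is the reciprocal of the HEALPix pixel area; the doubled
+    # entries are assumed here to account for the reality symmetry used by the
+    # CUDA kernels. A one-element tuple is returned, one cotangent per operand.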
+    scale_factors = (
+        jnp.concatenate((jnp.ones(L), 2 * jnp.ones(L * (L - 1) // 2)))
+        * (3 * nside**2)
+        / jnp.pi
    )
-    return result.results
+    if fft_type == "forward":
+        return (
+            scale_factors
+            * jnp.conj(healpix_ifft_cuda(jnp.conj(df), L, nside, reality, norm)),
+        )
+    elif fft_type == "backward":
+        return (
+            scale_factors
+            * jnp.conj(healpix_fft_cuda(jnp.conj(df), L, nside, reality, norm)),
+        )


# Register healpix_fft_cuda custom call targets
for name, fn in _s2fft.registration().items():
-    xla_client.register_custom_call_target(name, fn, platform="gpu")
+    jax.ffi.register_ffi_target(name, fn, platform="CUDA")

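+# Assumption: register_primitive wires the transpose rule and is_linear flag into
+# JAX's autodiff machinery, so reverse-mode differentiation of this linear
+# primitive falls back on _healpix_fft_cuda_transpose.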
_healpix_fft_cuda_primitive = register_primitive(
    "healpix_fft_cuda",
    multiple_results=False,
    abstract_evaluation=_healpix_fft_cuda_abstract,
    lowering_per_platform={None: _healpix_fft_cuda_lowering},
+    transpose=_healpix_fft_cuda_transpose,
+    is_linear=True,
)