Skip to content

Commit 2fd3c8a

Browse files
committed
Update JAX Binding to use FFI
1 parent 6d55715 commit 2fd3c8a

File tree

6 files changed

+222
-245
lines changed

6 files changed

+222
-245
lines changed

.pre-commit-config.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,10 @@ repos:
44
hooks:
55
- id: ruff
66
- id: ruff-format
7+
- repo: https://github.com/pre-commit/mirrors-clang-format
8+
rev: v18.1.4
9+
hooks:
10+
- id: clang-format
11+
files: '\.(c|cc|cpp|h|hpp|cxx|hh|cu|cuh)$'
12+
exclude: '^third_party/|/pybind11/'
13+
name: clang-format

CMakeLists.txt

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,15 @@ if(CMAKE_CUDA_COMPILER)
2828
else()
2929
find_package(CUDAToolkit REQUIRED)
3030

31-
find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED)
31+
# Add the executable
32+
find_package(Python 3.8
33+
REQUIRED COMPONENTS Interpreter Development.Module
34+
OPTIONAL_COMPONENTS Development.SABIModule)
35+
set(XLA_DIR ${Python_SITELIB}/jaxlib/include)
36+
message(STATUS "XLA_DIR: ${XLA_DIR}")
3237

3338
# Detect the installed nanobind package and import it into CMake
34-
execute_process(
35-
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
36-
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_ROOT)
37-
find_package(nanobind CONFIG REQUIRED)
39+
find_package(nanobind CONFIG REQUIRED)
3840

3941
nanobind_add_module(_s2fft STABLE_ABI
4042
${CMAKE_CURRENT_LIST_DIR}/lib/src/extensions.cc
@@ -45,7 +47,10 @@ if(CMAKE_CUDA_COMPILER)
4547
)
4648

4749
target_link_libraries(_s2fft PRIVATE CUDA::cudart_static CUDA::cufft_static CUDA::culibos)
48-
target_include_directories(_s2fft PUBLIC ${CMAKE_CURRENT_LIST_DIR}/lib/include)
50+
target_include_directories(_s2fft PUBLIC
51+
${CMAKE_CURRENT_LIST_DIR}/lib/include
52+
${XLA_DIR}
53+
)
4954
set_target_properties(_s2fft PROPERTIES
5055
LINKER_LANGUAGE CUDA
5156
CUDA_SEPARABLE_COMPILATION ON)

lib/include/kernel_helpers.h

Lines changed: 0 additions & 76 deletions
This file was deleted.

lib/include/kernel_nanobind_helpers.h

Lines changed: 0 additions & 51 deletions
This file was deleted.

lib/include/s2fft.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,16 @@ void s2fft_nphi_2_rings(float *data, int nside);
2828

2929
class s2fftDescriptor {
3030
public:
31-
int nside;
32-
int harmonic_band_limit;
31+
int64_t nside;
32+
int64_t harmonic_band_limit;
3333
bool reality;
3434

3535
bool forward = true;
3636
s2fftKernels::fft_norm norm = s2fftKernels::BACKWARD;
3737
bool shift = true;
3838
bool double_precision = false;
3939

40-
s2fftDescriptor(int nside, int harmonic_band_limit, bool reality, bool forward = true,
40+
s2fftDescriptor(int64_t nside, int64_t harmonic_band_limit, bool reality, bool forward = true,
4141
s2fftKernels::fft_norm norm = s2fftKernels::BACKWARD, bool shift = true,
4242
bool double_precision = false)
4343
: nside(nside),
@@ -95,7 +95,7 @@ namespace std {
9595
template <>
9696
struct hash<s2fft::s2fftDescriptor> {
9797
std::size_t operator()(const s2fft::s2fftDescriptor &k) const {
98-
size_t hash = std::hash<int>()(k.nside) ^ (std::hash<int>()(k.harmonic_band_limit) << 1) ^
98+
size_t hash = std::hash<int64_t>()(k.nside) ^ (std::hash<int64_t>()(k.harmonic_band_limit) << 1) ^
9999
(std::hash<bool>()(k.reality) << 2) ^ (std::hash<int>()(k.norm) << 3) ^
100100
(std::hash<bool>()(k.shift) << 4) ^ (std::hash<bool>()(k.double_precision) << 5);
101101
return hash;

0 commit comments

Comments
 (0)