
Commit 83296b9

Provide PyTorch implementations by wrapping JAX functions (#277)
* Make spherical precompute benchmarks compatible with PyTorch
* Add utilities for wrapping JAX functions for use in PyTorch
* Add initial wrapped versions of torch precompute transforms
* Update array conversion in benchmarks to avoid byte-alignment warning
* Remove previous torch precompute spherical transform implementations
* Make precompute Wigner benchmarks compatible with torch
* Add torch wrapper precompute Wigner transforms
* Remove docstring for now-removed using_torch arg
* Correct typo in docstring
* Remove previous torch precompute Wigner transform implementations
* Update references to JAX in docstrings when wrapping for torch use
* Try to infer differentiable args from annotations
* Update annotations of wrapped functions
* Remove explicit differentiable argument name definitions
* Add helper function for wrapping all JAX functions in a module
* Use wrappers for torch resampling and quadrature modules
* Remove using_torch option from signal generator functions
* Update torch demo notebook to follow new wrapper interface
* Use wrappers for torch HEALPix FFT functions
* Make torch an optional dependency
* Use backwards-compatible tree_map import
* Copy annotations in wrapped function and check for existence of doc
* Start on torch wrapper tests
* Add annotations futures import
* Add additional torch wrapper tests
* Make type alias Python 3.8 compatible
* Account for differing complex derivative conventions between torch and JAX
* Add additional complex test cases and gradient checks to wrapper tests
* Ignore complex warning due to casts in tests rather than erroring
* Reduce max number of iterations tested for HEALPix to reduce test times
* Maintain compatibility with older JAX versions
* More maintaining compatibility with older JAX versions
* Correct typo in comment
* Explicitly cast kernels in einsum ops to avoid ComplexWarning causing test failures
* Force JAX double-precision mode in Wigner precompute tests
* Add test for function checking torch is available
* Fix torch optional import logic to avoid errors when not installed
* Refactor method dispatch logic in inverse transform
* Refactor method dispatch logic in HEALPix FFTs
* Expose option to use HEALPix custom primitive in inverse transform
* Pass through method to select HEALPix (I)FFT function
* Expose jax_cuda method in top-level spherical inverse function
* Refactor method dispatch logic in forward transform
* Add OTF spherical transform torch wrappers
* Make torch wrapper diff-arg inferring robust to non-type annotations
* Mark use_healpix_custom_primitive arg as static
* Include torch wrappers in tests
* Refactor method dispatch logic in Wigner transforms
* Add torch wrappers for Wigner OTF transforms
* Update README and notebook to indicate wider torch support
* Pin JAX version to less than v0.6.0 due to breaking changes
1 parent 351fc94 commit 83296b9

22 files changed: +941 −1412 lines
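
The central technique in this commit is wrapping JAX functions so they plug into PyTorch autograd. The actual utilities live in `s2fft.utils.torch_wrapper` (not shown in this diff); the following is a minimal independent sketch of the general pattern for a single real-valued argument, not the library's implementation. As the commit message notes, complex inputs need extra care because torch and JAX use differing complex derivative conventions, which this sketch ignores.

```python
import jax
import numpy as np
import torch

jax.config.update("jax_enable_x64", True)  # match torch's float64 tensors below


def wrap_jax_function(jax_func):
    """Sketch only: expose a one-argument, real-valued JAX function to torch."""

    class _Wrapped(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            # Evaluate under JAX and record the vector-Jacobian product closure.
            y, ctx.vjp = jax.vjp(jax_func, jax.numpy.asarray(x.detach().numpy()))
            return torch.from_numpy(np.array(y))

        @staticmethod
        def backward(ctx, grad_y):
            # Pull the cotangent back through JAX, then hand it to torch.
            (grad_x,) = ctx.vjp(jax.numpy.asarray(grad_y.detach().numpy()))
            return torch.from_numpy(np.array(grad_x))

    return _Wrapped.apply


# Usage: gradients flow from torch, through JAX, and back.
sin = wrap_jax_function(jax.numpy.sin)
x = torch.linspace(0.0, 1.0, 5, dtype=torch.float64, requires_grad=True)
sin(x).sum().backward()  # x.grad now equals cos(x)
```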

README.md

Lines changed: 1 addition & 1 deletion

@@ -195,7 +195,7 @@ f = fft.wigner.inverse_jax(flmn, L, N, method="jax")
 For further details on usage see the [documentation](https://astro-informatics.github.io/s2fft/) and associated [notebooks](https://astro-informatics.github.io/s2fft/tutorials/spherical_harmonic/spherical_harmonic_transform.html).
 
 > [!NOTE]
-> We also provide PyTorch support for the precompute version of our transforms, as demonstrated in the [_Torch frontend_ tutorial notebook](https://astro-informatics.github.io/s2fft/tutorials/torch_frontend/torch_frontend.html).
+> We also provide PyTorch support for our transforms, as demonstrated in the [_Torch frontend_ tutorial notebook](https://astro-informatics.github.io/s2fft/tutorials/torch_frontend/torch_frontend.html). This wraps the JAX implementations so JAX will need to be installed in addition to PyTorch.
 
 ## SSHT & HEALPix wrappers 💡
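
For context on the note above, the wrapped torch interface is exercised in the updated tutorial notebook (diffed below); condensed, the on-the-fly usage looks roughly like this, with JAX put into 64-bit mode before any transforms:

```python
import jax

jax.config.update("jax_enable_x64", True)  # required for adequate precision

import numpy as np
import torch
from s2fft.transforms.spherical import forward, inverse
from s2fft.utils import signal_generator

L = 64
rng = np.random.default_rng(1234951510)
flm = torch.from_numpy(signal_generator.generate_flm(rng, L))
f = inverse(flm, L, method="torch")        # signal on the sphere
flm_check = forward(f, L, method="torch")  # round trip back to harmonic space
print(np.nanmean(np.abs((flm_check - flm).numpy())))
```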

benchmarks/precompute_spherical.py

Lines changed: 23 additions & 10 deletions

@@ -1,5 +1,6 @@
 """Benchmarks for precompute spherical transforms."""
 
+import jax
 import numpy as np
 from benchmarking import (
     BenchmarkSetup,
@@ -10,6 +11,7 @@
 
 import s2fft
 import s2fft.precompute_transforms
+from s2fft.utils import torch_wrapper
 
 L_VALUES = [8, 16, 32, 64, 128, 256]
 SPIN_VALUES = [0]
@@ -31,11 +33,17 @@ def setup_forward(method, L, sampling, spin, reality, recursion):
         sampling=sampling,
         reality=reality,
     )
-    kernel_function = (
-        s2fft.precompute_transforms.construct.spin_spherical_kernel_jax
-        if method == "jax"
-        else s2fft.precompute_transforms.construct.spin_spherical_kernel
-    )
+    # As the torch method wraps JAX functions, and converting a NumPy array to a torch
+    # tensor causes a 'DLPack buffer is not aligned' byte-alignment warning on
+    # subsequently converting to a JAX array using mutual DLPack support, we first
+    # convert the NumPy arrays to JAX arrays before converting to torch tensors,
+    # which eliminates this warning
+    if method.startswith("jax") or method.startswith("torch"):
+        flm = jax.numpy.asarray(flm)
+        f = jax.numpy.asarray(f)
+    if method.startswith("torch"):
+        flm, f = torch_wrapper.tree_map_jax_array_to_torch_tensor((flm, f))
+    kernel_function = s2fft.precompute_transforms.spherical._kernel_functions[method]
     kernel = kernel_function(
         L=L,
         spin=spin,
@@ -73,11 +81,16 @@ def setup_inverse(method, L, sampling, spin, reality, recursion):
         skip("Reality only valid for scalar fields (spin=0).")
     rng = np.random.default_rng()
     flm = s2fft.utils.signal_generator.generate_flm(rng, L, spin=spin, reality=reality)
-    kernel_function = (
-        s2fft.precompute_transforms.construct.spin_spherical_kernel_jax
-        if method == "jax"
-        else s2fft.precompute_transforms.construct.spin_spherical_kernel
-    )
+    # As the torch method wraps JAX functions, and converting a NumPy array to a torch
+    # tensor causes a 'DLPack buffer is not aligned' byte-alignment warning on
+    # subsequently converting to a JAX array using mutual DLPack support, we first
+    # convert the NumPy array to a JAX array before converting to a torch tensor,
+    # which eliminates this warning
+    if method.startswith("jax") or method.startswith("torch"):
+        flm = jax.numpy.asarray(flm)
+    if method.startswith("torch"):
+        flm = torch_wrapper.jax_array_to_torch_tensor(flm)
+    kernel_function = s2fft.precompute_transforms.spherical._kernel_functions[method]
     kernel = kernel_function(
         L=L,
         spin=spin,
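
The comment added in both setup functions above describes a conversion-order workaround. As a standalone illustration of the same pattern (using the `torch_wrapper` helper named in the diff; array shape and contents are arbitrary):

```python
import jax.numpy as jnp
import numpy as np
from s2fft.utils import torch_wrapper

flm = np.random.default_rng(0).standard_normal((64, 127))
# Converting NumPy -> torch and then torch -> JAX via their mutual DLPack
# support can warn 'DLPack buffer is not aligned', so go through JAX first ...
flm_jax = jnp.asarray(flm)
# ... and only then produce the torch tensor from the (aligned) JAX array.
flm_torch = torch_wrapper.jax_array_to_torch_tensor(flm_jax)
```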

benchmarks/precompute_wigner.py

Lines changed: 23 additions & 10 deletions

@@ -1,5 +1,6 @@
 """Benchmarks for precompute Wigner-d transforms."""
 
+import jax
 import numpy as np
 from benchmarking import (
     BenchmarkSetup,
@@ -10,6 +11,7 @@
 import s2fft
 import s2fft.precompute_transforms
 from s2fft.base_transforms import wigner as base_wigner
+from s2fft.utils import torch_wrapper
 
 L_VALUES = [16, 32, 64, 128, 256]
 N_VALUES = [2]
@@ -31,11 +33,17 @@ def setup_forward(method, L, N, L_lower, sampling, reality, mode):
         sampling=sampling,
         reality=reality,
     )
-    kernel_function = (
-        s2fft.precompute_transforms.construct.wigner_kernel_jax
-        if "jax" in method
-        else s2fft.precompute_transforms.construct.wigner_kernel
-    )
+    # As the torch method wraps JAX functions, and converting a NumPy array to a torch
+    # tensor causes a 'DLPack buffer is not aligned' byte-alignment warning on
+    # subsequently converting to a JAX array using mutual DLPack support, we first
+    # convert the NumPy arrays to JAX arrays before converting to torch tensors,
+    # which eliminates this warning
+    if method.startswith("jax") or method.startswith("torch"):
+        flmn = jax.numpy.asarray(flmn)
+        f = jax.numpy.asarray(f)
+    if method.startswith("torch"):
+        flmn, f = torch_wrapper.tree_map_jax_array_to_torch_tensor((flmn, f))
+    kernel_function = s2fft.precompute_transforms.wigner._kernel_functions[method]
     kernel = kernel_function(
         L=L, N=N, reality=reality, sampling=sampling, forward=True, mode=mode
     )
@@ -67,11 +75,16 @@ def forward(f, kernel, method, L, N, L_lower, sampling, reality, mode):
 def setup_inverse(method, L, N, L_lower, sampling, reality, mode):
     rng = np.random.default_rng()
     flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality)
-    kernel_function = (
-        s2fft.precompute_transforms.construct.wigner_kernel_jax
-        if method == "jax"
-        else s2fft.precompute_transforms.construct.wigner_kernel
-    )
+    # As the torch method wraps JAX functions, and converting a NumPy array to a torch
+    # tensor causes a 'DLPack buffer is not aligned' byte-alignment warning on
+    # subsequently converting to a JAX array using mutual DLPack support, we first
+    # convert the NumPy array to a JAX array before converting to a torch tensor,
+    # which eliminates this warning
+    if method.startswith("jax") or method.startswith("torch"):
+        flmn = jax.numpy.asarray(flmn)
+    if method.startswith("torch"):
+        flmn = torch_wrapper.jax_array_to_torch_tensor(flmn)
+    kernel_function = s2fft.precompute_transforms.wigner._kernel_functions[method]
     kernel = kernel_function(
         L=L, N=N, reality=reality, sampling=sampling, forward=False, mode=mode
     )
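
Both benchmark files now look kernels up via `_kernel_functions[method]` rather than chained conditionals, matching the commit's several "Refactor method dispatch logic" bullets. A schematic of that dispatch-dict pattern, using the two constructor names visible in the removed lines (the torch entry and the exact table in `s2fft` may differ):

```python
from s2fft.precompute_transforms import construct

# Map method names to kernel constructors; adding a backend becomes a
# one-line change and unknown methods surface as a clear error.
_kernel_functions = {
    "numpy": construct.wigner_kernel,
    "jax": construct.wigner_kernel_jax,
}


def build_kernel(method: str, **kwargs):
    try:
        kernel_function = _kernel_functions[method]
    except KeyError:
        raise ValueError(f"Unknown method: {method}") from None
    return kernel_function(**kwargs)
```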

notebooks/torch_frontend.ipynb

Lines changed: 89 additions & 36 deletions

@@ -28,27 +28,24 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "This minimal tutorial demonstrates how to use the torch frontend for `S2FFT` to compute spherical harmonic transforms. Though `S2FFT` is primarily designed for JAX, this torch functionality is fully unit tested (including gradients) and can be used straightforwardly as a learnable layer within existing models."
+    "This minimal tutorial demonstrates how to use the torch frontend for `S2FFT` to compute spherical harmonic transforms. Though `S2FFT` is primarily designed for JAX, this torch functionality is fully unit tested (including gradients) and can be used straightforwardly as a learnable layer within existing models. As the torch functions wrap the JAX implementations, we need to configure JAX to use 64-bit precision floating point types by default to ensure sufficient precision for the transforms - `S2FFT` will emit a warning if this has not been done."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 2,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "JAX is not using 64-bit precision. This will dramatically affect numerical precision at even moderate L.\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
+    "import jax\n",
+    "jax.config.update(\"jax_enable_x64\", True)\n",
     "import torch \n",
-    "import numpy as np \n",
-    "from s2fft.precompute_transforms.spherical import inverse, forward\n",
-    "from s2fft.precompute_transforms.construct import spin_spherical_kernel\n",
+    "import numpy as np\n",
+    "from s2fft.transforms.spherical import inverse, forward\n",
+    "from s2fft.precompute_transforms.spherical import (\n",
+    "    inverse as precompute_inverse, forward as precompute_forward\n",
+    ")\n",
+    "from s2fft.precompute_transforms.construct import spin_spherical_kernel_torch\n",
     "from s2fft.utils import signal_generator"
    ]
   },
@@ -65,33 +62,40 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "L = 64 # Spherical harmonic bandlimit\n",
-    "rng = np.random.default_rng(1234951510) # Random seed for signal generator\n",
-    "flm = signal_generator.generate_flm(rng, L, using_torch=True) # Random set of spherical harmonic coefficients"
+    "L = 64 \n",
+    "rng = np.random.default_rng(1234951510)\n",
+    "flm = torch.from_numpy(signal_generator.generate_flm(rng, L))"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "For the fully precompute transform we must also generate the precompute kernels which we store as a torch tensors."
+    "Now let's calculate the signal on the sphere by applying the inverse spherical harmonic transform"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 4,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
+     ]
+    }
+   ],
    "source": [
-    "inverse_kernel = spin_spherical_kernel(L, using_torch=True, forward=False) \n",
-    "forward_kernel = spin_spherical_kernel(L, using_torch=True, forward=True) "
+    "f = inverse(flm, L, method=\"torch\")"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Now lets calculate the signal on the sphere by applying the inverse spherical harmonic transform"
+    "To calculate the corresponding spherical harmonic representation execute"
    ]
   },
   {
@@ -100,53 +104,107 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "f = inverse(flm, L, 0, inverse_kernel, method=\"torch\")"
+    "flm_check = forward(f, L, method=\"torch\")"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "To calculate the corresponding spherical harmonic representation execute"
+    "Finally, let's check the error on the round trip is as expected for 64-bit machine precision floating point arithmetic"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 6,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Mean absolute error = 2.8915048238993476e-14\n"
+     ]
+    }
+   ],
    "source": [
-    "flm_check = forward(f, L, 0, forward_kernel, method=\"torch\")"
+    "print(f\"Mean absolute error = {np.nanmean(np.abs(flm_check - flm))}\")"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Finally, lets check the error on the roundtrip is at 64bit machine precision"
+    "For the fully precompute transform we must also generate the precompute kernels, which we store as torch tensors."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 7,
    "metadata": {},
+   "outputs": [],
+   "source": [
+    "inverse_kernel = spin_spherical_kernel_torch(L, forward=False) \n",
+    "forward_kernel = spin_spherical_kernel_torch(L, forward=True) "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We then pass the kernels as additional arguments to the transform functions"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "precompute_f = precompute_inverse(flm, L, kernel=inverse_kernel, method=\"torch\")\n",
+    "precompute_flm_check = precompute_forward(f, L, kernel=forward_kernel, method=\"torch\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Again, we check the error on the round trip is as expected"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
    "outputs": [
     {
     "name": "stdout",
     "output_type": "stream",
     "text": [
-      "Mean absolute error = 1.1866908936078849e-14\n"
+      "Mean absolute error = 2.8472981477378884e-14\n"
     ]
    }
   ],
   "source": [
-    "print(f\"Mean absolute error = {np.nanmean(np.abs(flm_check - flm))}\")"
+    "print(f\"Mean absolute error = {np.nanmean(np.abs(precompute_flm_check - flm))}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
-   "display_name": "Python 3.10.4 ('s2fft')",
+   "display_name": "s2fft",
    "language": "python",
    "name": "python3"
  },
@@ -160,14 +218,9 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.0"
+   "version": "3.11.10"
   },
-  "orig_nbformat": 4,
-  "vscode": {
-   "interpreter": {
-    "hash": "3425e24474cbe920550266ea26b478634978cc419579f9dbcf479231067df6a3"
-   }
-  }
+  "orig_nbformat": 4
  },
 "nbformat": 4,
 "nbformat_minor": 2

pyproject.toml

Lines changed: 4 additions & 1 deletion

@@ -30,7 +30,6 @@ dependencies = [
     "numpy>=1.20",
     "jax>=0.3.13,<0.6.0",
     "jaxlib",
-    "torch",
 ]
 dynamic = [
     "version",
@@ -74,6 +73,10 @@ tests = [
     "pytest-cov",
     "so3",
     "pyssht",
+    "torch",
+]
+torch = [
+    "torch",
 ]
 
 [tool.scikit-build]
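
With `torch` moved from the core dependencies to a `torch` extra (and the tests extra) as shown above, users opt in with e.g. `pip install s2fft[torch]`, and imports inside the library must be guarded. The commit's "Fix torch optional import logic" bullet suggests a pattern along these lines - a generic sketch, not `s2fft`'s actual helper:

```python
import importlib.util


def torch_available() -> bool:
    """Report whether the optional torch dependency can be imported."""
    return importlib.util.find_spec("torch") is not None


if torch_available():
    import torch  # noqa: F401  # only imported when the optional extra is present
```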
