Skip to content

Commit f8a9a6d

Browse files
committed
Merge remote-tracking branch 'origin/main' into ASKabalan
2 parents 6f6c07e + 0de6f11 commit f8a9a6d

28 files changed

+1584
-1910
lines changed

.all-contributorsrc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,15 @@
115115
"review",
116116
"test"
117117
]
118+
},
119+
{
120+
"login": "mdavezac",
121+
"name": "Mayeul d'Avezac",
122+
"avatar_url": "https://avatars.githubusercontent.com/u/2745737?v=4",
123+
"profile": "https://github.com/mdavezac",
124+
"contributions": [
125+
"infra"
126+
]
118127
}
119128
],
120129
"contributorsPerLine": 7,

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ jobs:
5555
fetch-depth: 0
5656
fetch-tags: true
5757
- name: Build wheels
58-
uses: pypa/cibuildwheel@v2.23.2
58+
uses: pypa/cibuildwheel@v2.23.3
5959
env:
6060
CIBW_SKIP: pp*-macosx_arm64
6161
- uses: actions/upload-artifact@v4

.github/workflows/tests.yml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,15 @@ jobs:
3333
- os: macos-latest
3434
python-version: "3.8"
3535
fail-fast: false
36+
env:
37+
CMAKE_POLICY_VERSION_MINIMUM: 3.5
3638

3739
steps:
3840
- name: Checkout Source
3941
uses: actions/checkout@v4.2.2
42+
with:
43+
fetch-depth: 0
44+
fetch-tags: true
4045

4146
- name: Set up Python ${{ matrix.python-version }}
4247
uses: actions/setup-python@v5
@@ -49,7 +54,13 @@ jobs:
4954
python -m pip install --upgrade pip
5055
pip install .[tests]
5156
57+
- name: Run tests (skipping slow)
58+
if: github.event_name == 'pull_request'
59+
run: |
60+
pytest -v --cov-report=xml --cov=s2fft --cov-config=.coveragerc -m "not slow"
61+
5262
- name: Run tests
63+
if: github.event_name != 'pull_request'
5364
run: |
5465
pytest -v --cov-report=xml --cov=s2fft --cov-config=.coveragerc
5566

README.md

Lines changed: 73 additions & 43 deletions
Large diffs are not rendered by default.

benchmarks/precompute_spherical.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Benchmarks for precompute spherical transforms."""
22

3+
import jax
34
import numpy as np
45
from benchmarking import (
56
BenchmarkSetup,
@@ -10,6 +11,7 @@
1011

1112
import s2fft
1213
import s2fft.precompute_transforms
14+
from s2fft.utils import torch_wrapper
1315

1416
L_VALUES = [8, 16, 32, 64, 128, 256]
1517
SPIN_VALUES = [0]
@@ -31,11 +33,17 @@ def setup_forward(method, L, sampling, spin, reality, recursion):
3133
sampling=sampling,
3234
reality=reality,
3335
)
34-
kernel_function = (
35-
s2fft.precompute_transforms.construct.spin_spherical_kernel_jax
36-
if method == "jax"
37-
else s2fft.precompute_transforms.construct.spin_spherical_kernel
38-
)
36+
# As torch method wraps JAX functions and converting NumPy array to torch tensor
37+
# causes warning 'DLPack buffer is not aligned' about byte aligment on subsequently
38+
# converting to JAX array using mutual DLPack support we first convert the NumPy
39+
# arrays to a JAX arrays before converting to torch tensors which eliminates this
40+
# warning
41+
if method.startswith("jax") or method.startswith("torch"):
42+
flm = jax.numpy.asarray(flm)
43+
f = jax.numpy.asarray(f)
44+
if method.startswith("torch"):
45+
flm, f = torch_wrapper.tree_map_jax_array_to_torch_tensor((flm, f))
46+
kernel_function = s2fft.precompute_transforms.spherical._kernel_functions[method]
3947
kernel = kernel_function(
4048
L=L,
4149
spin=spin,
@@ -73,11 +81,16 @@ def setup_inverse(method, L, sampling, spin, reality, recursion):
7381
skip("Reality only valid for scalar fields (spin=0).")
7482
rng = np.random.default_rng()
7583
flm = s2fft.utils.signal_generator.generate_flm(rng, L, spin=spin, reality=reality)
76-
kernel_function = (
77-
s2fft.precompute_transforms.construct.spin_spherical_kernel_jax
78-
if method == "jax"
79-
else s2fft.precompute_transforms.construct.spin_spherical_kernel
80-
)
84+
# As torch method wraps JAX functions and converting NumPy array to torch tensor
85+
# causes warning 'DLPack buffer is not aligned' about byte aligment on subsequently
86+
# converting to JAX array using mutual DLPack support we first convert the NumPy
87+
# array to a JAX array before converting to a torch tensor which eliminates this
88+
# warning
89+
if method.startswith("jax") or method.startswith("torch"):
90+
flm = jax.numpy.asarray(flm)
91+
if method.startswith("torch"):
92+
flm = torch_wrapper.jax_array_to_torch_tensor(flm)
93+
kernel_function = s2fft.precompute_transforms.spherical._kernel_functions[method]
8194
kernel = kernel_function(
8295
L=L,
8396
spin=spin,

benchmarks/precompute_wigner.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Benchmarks for precompute Wigner-d transforms."""
22

3+
import jax
34
import numpy as np
45
from benchmarking import (
56
BenchmarkSetup,
@@ -10,6 +11,7 @@
1011
import s2fft
1112
import s2fft.precompute_transforms
1213
from s2fft.base_transforms import wigner as base_wigner
14+
from s2fft.utils import torch_wrapper
1315

1416
L_VALUES = [16, 32, 64, 128, 256]
1517
N_VALUES = [2]
@@ -31,11 +33,17 @@ def setup_forward(method, L, N, L_lower, sampling, reality, mode):
3133
sampling=sampling,
3234
reality=reality,
3335
)
34-
kernel_function = (
35-
s2fft.precompute_transforms.construct.wigner_kernel_jax
36-
if "jax" in method
37-
else s2fft.precompute_transforms.construct.wigner_kernel
38-
)
36+
# As torch method wraps JAX functions and converting NumPy array to torch tensor
37+
# causes warning 'DLPack buffer is not aligned' about byte aligment on subsequently
38+
# converting to JAX array using mutual DLPack support we first convert the NumPy
39+
# arrays to a JAX arrays before converting to torch tensors which eliminates this
40+
# warning
41+
if method.startswith("jax") or method.startswith("torch"):
42+
flmn = jax.numpy.asarray(flmn)
43+
f = jax.numpy.asarray(f)
44+
if method.startswith("torch"):
45+
flmn, f = torch_wrapper.tree_map_jax_array_to_torch_tensor((flmn, f))
46+
kernel_function = s2fft.precompute_transforms.wigner._kernel_functions[method]
3947
kernel = kernel_function(
4048
L=L, N=N, reality=reality, sampling=sampling, forward=True, mode=mode
4149
)
@@ -67,11 +75,16 @@ def forward(f, kernel, method, L, N, L_lower, sampling, reality, mode):
6775
def setup_inverse(method, L, N, L_lower, sampling, reality, mode):
6876
rng = np.random.default_rng()
6977
flmn = s2fft.utils.signal_generator.generate_flmn(rng, L, N, reality=reality)
70-
kernel_function = (
71-
s2fft.precompute_transforms.construct.wigner_kernel_jax
72-
if method == "jax"
73-
else s2fft.precompute_transforms.construct.wigner_kernel
74-
)
78+
# As torch method wraps JAX functions and converting NumPy array to torch tensor
79+
# causes warning 'DLPack buffer is not aligned' about byte aligment on subsequently
80+
# converting to JAX array using mutual DLPack support we first convert the NumPy
81+
# arrays to a JAX arrays before converting to torch tensors which eliminates this
82+
# warning
83+
if method.startswith("jax") or method.startswith("torch"):
84+
flmn = jax.numpy.asarray(flmn)
85+
if method.startswith("torch"):
86+
flmn = torch_wrapper.jax_array_to_torch_tensor(flmn)
87+
kernel_function = s2fft.precompute_transforms.wigner._kernel_functions[method]
7588
kernel = kernel_function(
7689
L=L, N=N, reality=reality, sampling=sampling, forward=False, mode=mode
7790
)

notebooks/torch_frontend.ipynb

Lines changed: 76 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -28,125 +28,171 @@
2828
"cell_type": "markdown",
2929
"metadata": {},
3030
"source": [
31-
"This minimal tutorial demonstrates how to use the torch frontend for `S2FFT` to compute spherical harmonic transforms. Though `S2FFT` is primarily designed for JAX, this torch functionality is fully unit tested (including gradients) and can be used straightforwardly as a learnable layer within existing models."
31+
"This minimal tutorial demonstrates how to use the torch frontend for `S2FFT` to compute spherical harmonic transforms. Though `S2FFT` is primarily designed for JAX, this torch functionality is fully unit tested (including gradients) and can be used straightforwardly as a learnable layer within existing models. As the torch functions wrap the JAX implementations we need to configure JAX to use 64-bit precision floating point types by default to ensure sufficient precision for the transforms - `S2FFT` will emit a warning if this has not been done."
3232
]
3333
},
3434
{
3535
"cell_type": "code",
3636
"execution_count": 2,
3737
"metadata": {},
38+
"outputs": [],
39+
"source": [
40+
"import jax\n",
41+
"jax.config.update(\"jax_enable_x64\", True)\n",
42+
"import torch \n",
43+
"import numpy as np\n",
44+
"from s2fft.transforms.spherical import inverse, forward\n",
45+
"from s2fft.precompute_transforms.spherical import (\n",
46+
" inverse as precompute_inverse, forward as precompute_forward\n",
47+
")\n",
48+
"from s2fft.precompute_transforms.construct import spin_spherical_kernel_torch\n",
49+
"from s2fft.utils import signal_generator"
50+
]
51+
},
52+
{
53+
"cell_type": "markdown",
54+
"metadata": {},
55+
"source": [
56+
"Lets set up a mock problem by specifiying a bandlimit $L$ and generating some arbitrary harmonic coefficients."
57+
]
58+
},
59+
{
60+
"cell_type": "code",
61+
"execution_count": 3,
62+
"metadata": {},
63+
"outputs": [],
64+
"source": [
65+
"L = 64 \n",
66+
"rng = np.random.default_rng(1234951510)\n",
67+
"flm = torch.from_numpy(signal_generator.generate_flm(rng, L))"
68+
]
69+
},
70+
{
71+
"cell_type": "markdown",
72+
"metadata": {},
73+
"source": [
74+
"Now lets calculate the signal on the sphere by applying the inverse spherical harmonic transform"
75+
]
76+
},
77+
{
78+
"cell_type": "code",
79+
"execution_count": 4,
80+
"metadata": {},
3881
"outputs": [
3982
{
4083
"name": "stderr",
4184
"output_type": "stream",
4285
"text": [
43-
"JAX is not using 64-bit precision. This will dramatically affect numerical precision at even moderate L.\n"
86+
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
4487
]
4588
}
4689
],
4790
"source": [
48-
"import torch \n",
49-
"import numpy as np \n",
50-
"from s2fft.precompute_transforms.spherical import inverse, forward\n",
51-
"from s2fft.precompute_transforms.construct import spin_spherical_kernel\n",
52-
"from s2fft.utils import signal_generator"
91+
"f = inverse(flm, L, method=\"torch\")"
5392
]
5493
},
5594
{
5695
"cell_type": "markdown",
5796
"metadata": {},
5897
"source": [
59-
"Lets set up a mock problem by specifiying a bandlimit $L$ and generating some arbitrary harmonic coefficients."
98+
"To calculate the corresponding spherical harmonic representation execute"
6099
]
61100
},
62101
{
63102
"cell_type": "code",
64-
"execution_count": 3,
103+
"execution_count": 5,
65104
"metadata": {},
66105
"outputs": [],
67106
"source": [
68-
"L = 64 # Spherical harmonic bandlimit\n",
69-
"rng = np.random.default_rng(1234951510) # Random seed for signal generator\n",
70-
"flm = signal_generator.generate_flm(rng, L, using_torch=True) # Random set of spherical harmonic coefficients"
107+
"flm_check = forward(f, L, method=\"torch\")"
71108
]
72109
},
73110
{
74111
"cell_type": "markdown",
75112
"metadata": {},
76113
"source": [
77-
"For the fully precompute transform we must also generate the precompute kernels which we store as a torch tensors."
114+
"Finally, lets check the error on the round trip is as expected for 64 bit machine precision floating point arithmetic"
78115
]
79116
},
80117
{
81118
"cell_type": "code",
82-
"execution_count": 4,
119+
"execution_count": 6,
83120
"metadata": {},
84-
"outputs": [],
121+
"outputs": [
122+
{
123+
"name": "stdout",
124+
"output_type": "stream",
125+
"text": [
126+
"Mean absolute error = 2.8915048238993476e-14\n"
127+
]
128+
}
129+
],
85130
"source": [
86-
"inverse_kernel = spin_spherical_kernel(L, using_torch=True, forward=False) \n",
87-
"forward_kernel = spin_spherical_kernel(L, using_torch=True, forward=True) "
131+
"print(f\"Mean absolute error = {np.nanmean(np.abs(flm_check - flm))}\")"
88132
]
89133
},
90134
{
91135
"cell_type": "markdown",
92136
"metadata": {},
93137
"source": [
94-
"Now lets calculate the signal on the sphere by applying the inverse spherical harmonic transform"
138+
"For the fully precompute transform we must also generate the precompute kernels which we store as a torch tensors."
95139
]
96140
},
97141
{
98142
"cell_type": "code",
99-
"execution_count": 5,
143+
"execution_count": 7,
100144
"metadata": {},
101145
"outputs": [],
102146
"source": [
103-
"f = inverse(flm, L, 0, inverse_kernel, method=\"torch\")"
147+
"inverse_kernel = spin_spherical_kernel_torch(L, forward=False) \n",
148+
"forward_kernel = spin_spherical_kernel_torch(L, forward=True) "
104149
]
105150
},
106151
{
107152
"cell_type": "markdown",
108153
"metadata": {},
109154
"source": [
110-
"To calculate the corresponding spherical harmonic representation execute"
155+
"We then pass the kernels as additional arguments to the transform functions"
111156
]
112157
},
113158
{
114159
"cell_type": "code",
115-
"execution_count": 6,
160+
"execution_count": 8,
116161
"metadata": {},
117162
"outputs": [],
118163
"source": [
119-
"flm_check = forward(f, L, 0, forward_kernel, method=\"torch\")"
164+
"precompute_f = precompute_inverse(flm, L, kernel=inverse_kernel, method=\"torch\")\n",
165+
"precompute_flm_check = precompute_forward(f, L, kernel=forward_kernel, method=\"torch\")"
120166
]
121167
},
122168
{
123169
"cell_type": "markdown",
124170
"metadata": {},
125171
"source": [
126-
"Finally, lets check the error on the roundtrip is at 64bit machine precision"
172+
"Again, we check the error on the round trip is as expected"
127173
]
128174
},
129175
{
130176
"cell_type": "code",
131-
"execution_count": 7,
177+
"execution_count": 9,
132178
"metadata": {},
133179
"outputs": [
134180
{
135181
"name": "stdout",
136182
"output_type": "stream",
137183
"text": [
138-
"Mean absolute error = 1.1866908936078849e-14\n"
184+
"Mean absolute error = 2.904741595325594e-14\n"
139185
]
140186
}
141187
],
142188
"source": [
143-
"print(f\"Mean absolute error = {np.nanmean(np.abs(flm_check - flm))}\")"
189+
"print(f\"Mean absolute error = {np.nanmean(np.abs(precompute_flm_check - flm))}\")"
144190
]
145191
}
146192
],
147193
"metadata": {
148194
"kernelspec": {
149-
"display_name": "Python 3.10.4 ('s2fft')",
195+
"display_name": "s2fft",
150196
"language": "python",
151197
"name": "python3"
152198
},
@@ -160,14 +206,9 @@
160206
"name": "python",
161207
"nbconvert_exporter": "python",
162208
"pygments_lexer": "ipython3",
163-
"version": "3.10.0"
209+
"version": "3.11.10"
164210
},
165-
"orig_nbformat": 4,
166-
"vscode": {
167-
"interpreter": {
168-
"hash": "3425e24474cbe920550266ea26b478634978cc419579f9dbcf479231067df6a3"
169-
}
170-
}
211+
"orig_nbformat": 4
171212
},
172213
"nbformat": 4,
173214
"nbformat_minor": 2

0 commit comments

Comments
 (0)