1+ from __future__ import annotations
2+
13from functools import partial
24
35import jax .numpy as jnp
46import numpy as np
57from jax import jit
68
9+ from s2fft import recursions
10+ from s2fft .utils import quadrature , quadrature_jax
11+
712
813def inverse_transform (
914 flmn : np .ndarray ,
10- DW : np .ndarray ,
1115 L : int ,
1216 N : int ,
17+ precomps : tuple [np .ndarray , np .ndarray ] | None = None ,
1318 reality : bool = False ,
1419 sampling : str = "mw" ,
1520) -> np .ndarray :
@@ -18,10 +23,11 @@ def inverse_transform(
1823
1924 Args:
2025 flmn (np.ndarray): Wigner coefficients.
21- DW (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced
22- Wigner d-functions and the corresponding upsampled quadrature weights.
2326 L (int): Harmonic band-limit.
2427 N (int): Azimuthal band-limit.
28+ precomps (tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the
29+ reduced Wigner d-functions and the corresponding upsampled quadrature
30+ weights. Defaults to None.
2531 reality (bool, optional): Whether the signal on the sphere is real. If so,
2632 conjugate symmetry is exploited to reduce computational costs.
2733 Defaults to False.
@@ -37,9 +43,6 @@ def inverse_transform(
3743 f"Fourier-Wigner algorithm does not support { sampling } sampling."
3844 )
3945
40- # EXTRACT VARIOUS PRECOMPUTES
41- Delta , _ = DW
42-
4346 # INDEX VALUES
4447 n_start_ind = N - 1 if reality else 0
4548 n_dim = N if reality else 2 * N - 1
@@ -52,15 +55,29 @@ def inverse_transform(
5255 m = np .arange (- L + 1 - m_offset , L )
5356 n = np .arange (n_start_ind - N + 1 , N )
5457
55- # Calculate fmna = i^(n-m)\sum_L Delta ^l_am Delta ^l_an f^l_mn(2l+1)/(8pi^2)
58+ # Calculate fmna = i^(n-m)\sum_L delta ^l_am delta ^l_an f^l_mn(2l+1)/(8pi^2)
5659 x = np .zeros ((xnlm_size , xnlm_size , n_dim ), dtype = flmn .dtype )
57- x [m_offset :, m_offset :] = np .einsum (
58- "nlm,lam,lan,l->amn" ,
59- flmn [n_start_ind :],
60- Delta ,
61- Delta [:, :, L - 1 + n ],
62- (2 * np .arange (L ) + 1 ) / (8 * np .pi ** 2 ),
63- )
60+ flmn = np .einsum ("nlm,l->nlm" , flmn , (2 * np .arange (L ) + 1 ) / (8 * np .pi ** 2 ))
61+
62+ # PRECOMPUTE TRANSFORM
63+ if precomps is not None :
64+ delta , _ = precomps
65+ x = np .zeros ((xnlm_size , xnlm_size , n_dim ), dtype = flmn .dtype )
66+ x [m_offset :, m_offset :] = np .einsum (
67+ "nlm,lam,lan->amn" , flmn [n_start_ind :], delta , delta [:, :, L - 1 + n ]
68+ )
69+
70+ # OTF TRANSFORM
71+ else :
72+ delta_el = np .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = np .float64 )
73+ for el in range (L ):
74+ delta_el = recursions .risbo .compute_full (delta_el , np .pi / 2 , L , el )
75+ x [m_offset :, m_offset :] += np .einsum (
76+ "nm,am,an->amn" ,
77+ flmn [n_start_ind :, el ],
78+ delta_el ,
79+ delta_el [:, L - 1 + n ],
80+ )
6481
6582 # APPLY SIGN FUNCTION AND PHASE SHIFT
6683 x = np .einsum ("amn,m,n,a->nam" , x , 1j ** (- m ), 1j ** (n ), np .exp (1j * m * theta0 ))
@@ -77,12 +94,12 @@ def inverse_transform(
7794 return np .fft .ifft2 (x , axes = (0 , 2 ), norm = "forward" )
7895
7996
80- @partial (jit , static_argnums = (2 , 3 , 4 , 5 ))
97+ @partial (jit , static_argnums = (1 , 2 , 4 , 5 ))
8198def inverse_transform_jax (
8299 flmn : jnp .ndarray ,
83- DW : jnp .ndarray ,
84100 L : int ,
85101 N : int ,
102+ precomps : tuple [jnp .ndarray , jnp .ndarray ] | None = None ,
86103 reality : bool = False ,
87104 sampling : str = "mw" ,
88105) -> jnp .ndarray :
@@ -91,10 +108,11 @@ def inverse_transform_jax(
91108
92109 Args:
93110 flmn (jnp.ndarray): Wigner coefficients.
94- DW (Tuple[np.ndarray, np.ndarray]): Fourier coefficients of the reduced
95- Wigner d-functions and the corresponding upsampled quadrature weights.
96111 L (int): Harmonic band-limit.
97112 N (int): Azimuthal band-limit.
113+ precomps (tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the
114+ reduced Wigner d-functions and the corresponding upsampled quadrature
115+ weights. Defaults to None.
98116 reality (bool, optional): Whether the signal on the sphere is real. If so,
99117 conjugate symmetry is exploited to reduce computational costs.
100118 Defaults to False.
@@ -110,9 +128,6 @@ def inverse_transform_jax(
110128 f"Fourier-Wigner algorithm does not support { sampling } sampling."
111129 )
112130
113- # EXTRACT VARIOUS PRECOMPUTES
114- Delta , _ = DW
115-
116131 # INDEX VALUES
117132 n_start_ind = N - 1 if reality else 0
118133 n_dim = N if reality else 2 * N - 1
@@ -125,14 +140,32 @@ def inverse_transform_jax(
125140 m = jnp .arange (- L + 1 - m_offset , L )
126141 n = jnp .arange (n_start_ind - N + 1 , N )
127142
128- # Calculate fmna = i^(n-m)\sum_L Delta ^l_am Delta ^l_an f^l_mn(2l+1)/(8pi^2)
143+ # Calculate fmna = i^(n-m)\sum_L delta ^l_am delta ^l_an f^l_mn(2l+1)/(8pi^2)
129144 x = jnp .zeros ((xnlm_size , xnlm_size , n_dim ), dtype = jnp .complex128 )
130145 flmn = jnp .einsum ("nlm,l->nlm" , flmn , (2 * jnp .arange (L ) + 1 ) / (8 * jnp .pi ** 2 ))
131- x = x .at [m_offset :, m_offset :].set (
132- jnp .einsum (
133- "nlm,lam,lan->amn" , flmn [n_start_ind :], Delta , Delta [:, :, L - 1 + n ]
146+
147+ # PRECOMPUTE TRANSFORM
148+ if precomps is not None :
149+ delta , _ = precomps
150+ x = x .at [m_offset :, m_offset :].set (
151+ jnp .einsum (
152+ "nlm,lam,lan->amn" , flmn [n_start_ind :], delta , delta [:, :, L - 1 + n ]
153+ )
134154 )
135- )
155+
156+ # OTF TRANSFORM
157+ else :
158+ delta_el = jnp .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64 )
159+ for el in range (L ):
160+ delta_el = recursions .risbo_jax .compute_full (delta_el , jnp .pi / 2 , L , el )
161+ x = x .at [m_offset :, m_offset :].add (
162+ jnp .einsum (
163+ "nm,am,an->amn" ,
164+ flmn [n_start_ind :, el ],
165+ delta_el ,
166+ delta_el [:, L - 1 + n ],
167+ )
168+ )
136169
137170 # APPLY SIGN FUNCTION AND PHASE SHIFT
138171 x = jnp .einsum ("amn,m,n,a->nam" , x , 1j ** (- m ), 1j ** (n ), jnp .exp (1j * m * theta0 ))
@@ -151,9 +184,9 @@ def inverse_transform_jax(
151184
152185def forward_transform (
153186 f : np .ndarray ,
154- DW : np .ndarray ,
155187 L : int ,
156188 N : int ,
189+ precomps : tuple [np .ndarray , np .ndarray ] | None = None ,
157190 reality : bool = False ,
158191 sampling : str = "mw" ,
159192) -> np .ndarray :
@@ -162,10 +195,11 @@ def forward_transform(
162195
163196 Args:
164197 f (np.ndarray): Function sampled on the rotation group.
165- DW (Tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the reduced
166- Wigner d-functions and the corresponding upsampled quadrature weights.
167198 L (int): Harmonic band-limit.
168199 N (int): Azimuthal band-limit.
200+ precomps (tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the
201+ reduced Wigner d-functions and the corresponding upsampled quadrature
202+ weights. Defaults to None.
169203 reality (bool, optional): Whether the signal on the sphere is real. If so,
170204 conjugate symmetry is exploited to reduce computational costs.
171205 Defaults to False.
@@ -181,9 +215,6 @@ def forward_transform(
181215 f"Fourier-Wigner algorithm does not support { sampling } sampling."
182216 )
183217
184- # EXTRACT VARIOUS PRECOMPUTES
185- Delta , Quads = DW
186-
187218 # INDEX VALUES
188219 n_start_ind = N - 1 if reality else 0
189220 m_offset = 1 if sampling .lower () == "mwss" else 0
@@ -223,14 +254,39 @@ def forward_transform(
223254 # NB: Our convention here is conjugate to that of SSHT, in which
224255 # the weights are conjugate but applied flipped and therefore are
225256 # equivalent. To avoid flipping here we simply conjugate the weights.
226- x = np .einsum ("nbm,b->nbm" , x , Quads )
257+
258+ if precomps is not None :
259+ # PRECOMPUTE TRANSFORM
260+ delta , quads = precomps
261+ else :
262+ # OTF TRANSFORM
263+ delta = None
264+ # COMPUTE QUADRATURE WEIGHTS
265+ quads = np .zeros (4 * L - 3 , dtype = np .complex128 )
266+ for mm in range (- 2 * (L - 1 ), 2 * (L - 1 ) + 1 ):
267+ quads [mm + 2 * (L - 1 )] = quadrature .mw_weights (- mm )
268+ quads = np .fft .ifft (np .fft .ifftshift (quads ), norm = "forward" )
269+
270+ # APPLY QUADRATURE
271+ x = np .einsum ("nbm,b->nbm" , x , quads )
227272
228273 # COMPUTE GMM BY FFT
229274 x = np .fft .fft (x , axis = 1 , norm = "forward" )
230275 x = np .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
231276
232- # Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
233- x = np .einsum ("nam,lam,lan->nlm" , x , Delta , Delta [:, :, L - 1 + n ])
277+ # CALCULATE flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
278+ if delta is not None :
279+ # PRECOMPUTE TRANSFORM
280+ x = np .einsum ("nam,lam,lan->nlm" , x , delta , delta [:, :, L - 1 + n ])
281+ else :
282+ # OTF TRANSFORM
283+ delta_el = np .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = np .float64 )
284+ xx = np .zeros ((x .shape [0 ], L , x .shape [- 1 ]), dtype = x .dtype )
285+ for el in range (L ):
286+ delta_el = recursions .risbo .compute_full (delta_el , np .pi / 2 , L , el )
287+ xx [:, el ] = np .einsum ("nam,am,an->nm" , x , delta_el , delta_el [:, L - 1 + n ])
288+ x = xx
289+
234290 x = np .einsum ("nbm,m,n->nbm" , x , 1j ** (m ), 1j ** (- n ))
235291
236292 # SYMMETRY REFLECT FOR N < 0
@@ -246,12 +302,12 @@ def forward_transform(
246302 return x * (2.0 * np .pi ) ** 2
247303
248304
249- @partial (jit , static_argnums = (2 , 3 , 4 , 5 ))
305+ @partial (jit , static_argnums = (1 , 2 , 4 , 5 ))
250306def forward_transform_jax (
251307 f : jnp .ndarray ,
252- DW : jnp .ndarray ,
253308 L : int ,
254309 N : int ,
310+ precomps : tuple [jnp .ndarray , jnp .ndarray ] | None = None ,
255311 reality : bool = False ,
256312 sampling : str = "mw" ,
257313) -> jnp .ndarray :
@@ -260,10 +316,11 @@ def forward_transform_jax(
260316
261317 Args:
262318 f (jnp.ndarray): Function sampled on the rotation group.
263- DW (Tuple[jnp.ndarray, jnp.ndarray]): Fourier coefficients of the reduced
264- Wigner d-functions and the corresponding upsampled quadrature weights.
265319 L (int): Harmonic band-limit.
266320 N (int): Azimuthal band-limit.
321+ precomps (tuple[np.ndarray, np.ndarray], optional): Fourier coefficients of the
322+ reduced Wigner d-functions and the corresponding upsampled quadrature
323+ weights. Defaults to None.
267324 reality (bool, optional): Whether the signal on the sphere is real. If so,
268325 conjugate symmetry is exploited to reduce computational costs.
269326 Defaults to False.
@@ -279,9 +336,6 @@ def forward_transform_jax(
279336 f"Fourier-Wigner algorithm does not support { sampling } sampling."
280337 )
281338
282- # EXTRACT VARIOUS PRECOMPUTES
283- Delta , Quads = DW
284-
285339 # INDEX VALUES
286340 n_start_ind = N - 1 if reality else 0
287341 m_offset = 1 if sampling .lower () == "mwss" else 0
@@ -321,14 +375,41 @@ def forward_transform_jax(
321375 # NB: Our convention here is conjugate to that of SSHT, in which
322376 # the weights are conjugate but applied flipped and therefore are
323377 # equivalent. To avoid flipping here we simply conjugate the weights.
324- x = jnp .einsum ("nbm,b->nbm" , x , Quads )
378+
379+ if precomps is not None :
380+ # PRECOMPUTE TRANSFORM
381+ delta , quads = precomps
382+ else :
383+ # OTF TRANSFORM
384+ delta = None
385+ # COMPUTE QUADRATURE WEIGHTS
386+ quads = jnp .zeros (4 * L - 3 , dtype = jnp .complex128 )
387+ for mm in range (- 2 * (L - 1 ), 2 * (L - 1 ) + 1 ):
388+ quads = quads .at [mm + 2 * (L - 1 )].set (quadrature_jax .mw_weights (- mm ))
389+ quads = jnp .fft .ifft (jnp .fft .ifftshift (quads ), norm = "forward" )
390+
391+ # APPLY QUADRATURE
392+ x = jnp .einsum ("nbm,b->nbm" , x , quads )
325393
326394 # COMPUTE GMM BY FFT
327395 x = jnp .fft .fft (x , axis = 1 , norm = "forward" )
328396 x = jnp .fft .fftshift (x , axes = 1 )[:, L - 1 : 3 * L - 2 ]
329397
330- # Calculate flmn = i^(n-m)\sum_t Delta^l_tm Delta^l_tn G_mnt
331- x = jnp .einsum ("nam,lam,lan->nlm" , x , Delta , Delta [:, :, L - 1 + n ])
398+ # Calculate flmn = i^(n-m)\sum_t delta^l_tm delta^l_tn G_mnt
399+ if delta is not None :
400+ # PRECOMPUTE TRANSFORM
401+ x = jnp .einsum ("nam,lam,lan->nlm" , x , delta , delta [:, :, L - 1 + n ])
402+ else :
403+ # OTF TRANSFORM
404+ delta_el = jnp .zeros ((2 * L - 1 , 2 * L - 1 ), dtype = jnp .float64 )
405+ xx = jnp .zeros ((x .shape [0 ], L , x .shape [- 1 ]), dtype = x .dtype )
406+ for el in range (L ):
407+ delta_el = recursions .risbo_jax .compute_full (delta_el , jnp .pi / 2 , L , el )
408+ xx = xx .at [:, el ].set (
409+ jnp .einsum ("nam,am,an->nm" , x , delta_el , delta_el [:, L - 1 + n ])
410+ )
411+ x = xx
412+
332413 x = jnp .einsum ("nbm,m,n->nbm" , x , 1j ** (m ), 1j ** (- n ))
333414
334415 # SYMMETRY REFLECT FOR N < 0
0 commit comments