1- from __future__ import annotations
2-
31from collections .abc import Sequence
4- from typing import Union , Optional , Literal
2+ from typing import Literal , TypeAlias
3+
4+ from ._typing import Array , Device , DType , Namespace
55
6- from . _typing import Device , Array , DType , Namespace
6+ _Norm : TypeAlias = Literal [ "backward" , "ortho" , "forward" ]
77
88# Note: NumPy fft functions improperly upcast float32 and complex64 to
99# complex128, which is why we require wrapping them all here.
@@ -13,9 +13,9 @@ def fft(
1313 / ,
1414 xp : Namespace ,
1515 * ,
16- n : Optional [ int ] = None ,
16+ n : int | None = None ,
1717 axis : int = - 1 ,
18- norm : Literal [ "backward" , "ortho" , "forward" ] = "backward" ,
18+ norm : _Norm = "backward" ,
1919) -> Array :
2020 res = xp .fft .fft (x , n = n , axis = axis , norm = norm )
2121 if x .dtype in [xp .float32 , xp .complex64 ]:
@@ -27,9 +27,9 @@ def ifft(
2727 / ,
2828 xp : Namespace ,
2929 * ,
30- n : Optional [ int ] = None ,
30+ n : int | None = None ,
3131 axis : int = - 1 ,
32- norm : Literal [ "backward" , "ortho" , "forward" ] = "backward" ,
32+ norm : _Norm = "backward" ,
3333) -> Array :
3434 res = xp .fft .ifft (x , n = n , axis = axis , norm = norm )
3535 if x .dtype in [xp .float32 , xp .complex64 ]:
@@ -41,9 +41,9 @@ def fftn(
4141 / ,
4242 xp : Namespace ,
4343 * ,
44- s : Sequence [int ] = None ,
45- axes : Sequence [int ] = None ,
46- norm : Literal [ "backward" , "ortho" , "forward" ] = "backward" ,
44+ s : Sequence [int ] | None = None ,
45+ axes : Sequence [int ] | None = None ,
46+ norm : _Norm = "backward" ,
4747) -> Array :
4848 res = xp .fft .fftn (x , s = s , axes = axes , norm = norm )
4949 if x .dtype in [xp .float32 , xp .complex64 ]:
@@ -55,9 +55,9 @@ def ifftn(
5555 / ,
5656 xp : Namespace ,
5757 * ,
58- s : Sequence [int ] = None ,
59- axes : Sequence [int ] = None ,
60- norm : Literal [ "backward" , "ortho" , "forward" ] = "backward" ,
58+ s : Sequence [int ] | None = None ,
59+ axes : Sequence [int ] | None = None ,
60+ norm : _Norm = "backward" ,
6161) -> Array :
6262 res = xp .fft .ifftn (x , s = s , axes = axes , norm = norm )
6363 if x .dtype in [xp .float32 , xp .complex64 ]:
@@ -69,9 +69,9 @@ def rfft(
6969 / ,
7070 xp : Namespace ,
7171 * ,
72- n : Optional [ int ] = None ,
72+ n : int | None = None ,
7373 axis : int = - 1 ,
74- norm : Literal [ "backward" , "ortho" , "forward" ] = "backward" ,
74+ norm : _Norm = "backward" ,
7575) -> Array :
7676 res = xp .fft .rfft (x , n = n , axis = axis , norm = norm )
7777 if x .dtype == xp .float32 :
@@ -83,9 +83,9 @@ def irfft(
8383 / ,
8484 xp : Namespace ,
8585 * ,
86- n : Optional [ int ] = None ,
86+ n : int | None = None ,
8787 axis : int = - 1 ,
88- norm : Literal [ "backward" , "ortho" , "forward" ] = "backward" ,
88+ norm : _Norm = "backward" ,
8989) -> Array :
9090 res = xp .fft .irfft (x , n = n , axis = axis , norm = norm )
9191 if x .dtype == xp .complex64 :
@@ -97,9 +97,9 @@ def rfftn(
9797 / ,
9898 xp : Namespace ,
9999 * ,
100- s : Sequence [int ] = None ,
101- axes : Sequence [int ] = None ,
102- norm : Literal [ "backward" , "ortho" , "forward" ] = "backward" ,
100+ s : Sequence [int ] | None = None ,
101+ axes : Sequence [int ] | None = None ,
102+ norm : _Norm = "backward" ,
103103) -> Array :
104104 res = xp .fft .rfftn (x , s = s , axes = axes , norm = norm )
105105 if x .dtype == xp .float32 :
@@ -111,9 +111,9 @@ def irfftn(
111111 / ,
112112 xp : Namespace ,
113113 * ,
114- s : Sequence [int ] = None ,
115- axes : Sequence [int ] = None ,
116- norm : Literal [ "backward" , "ortho" , "forward" ] = "backward" ,
114+ s : Sequence [int ] | None = None ,
115+ axes : Sequence [int ] | None = None ,
116+ norm : _Norm = "backward" ,
117117) -> Array :
118118 res = xp .fft .irfftn (x , s = s , axes = axes , norm = norm )
119119 if x .dtype == xp .complex64 :
@@ -125,9 +125,9 @@ def hfft(
125125 / ,
126126 xp : Namespace ,
127127 * ,
128- n : Optional [ int ] = None ,
128+ n : int | None = None ,
129129 axis : int = - 1 ,
130- norm : Literal [ "backward" , "ortho" , "forward" ] = "backward" ,
130+ norm : _Norm = "backward" ,
131131) -> Array :
132132 res = xp .fft .hfft (x , n = n , axis = axis , norm = norm )
133133 if x .dtype in [xp .float32 , xp .complex64 ]:
@@ -139,9 +139,9 @@ def ihfft(
139139 / ,
140140 xp : Namespace ,
141141 * ,
142- n : Optional [ int ] = None ,
142+ n : int | None = None ,
143143 axis : int = - 1 ,
144- norm : Literal [ "backward" , "ortho" , "forward" ] = "backward" ,
144+ norm : _Norm = "backward" ,
145145) -> Array :
146146 res = xp .fft .ihfft (x , n = n , axis = axis , norm = norm )
147147 if x .dtype in [xp .float32 , xp .complex64 ]:
@@ -154,8 +154,8 @@ def fftfreq(
154154 xp : Namespace ,
155155 * ,
156156 d : float = 1.0 ,
157- dtype : Optional [ DType ] = None ,
158- device : Optional [ Device ] = None ,
157+ dtype : DType | None = None ,
158+ device : Device | None = None ,
159159) -> Array :
160160 if device not in ["cpu" , None ]:
161161 raise ValueError (f"Unsupported device { device !r} " )
@@ -170,8 +170,8 @@ def rfftfreq(
170170 xp : Namespace ,
171171 * ,
172172 d : float = 1.0 ,
173- dtype : Optional [ DType ] = None ,
174- device : Optional [ Device ] = None ,
173+ dtype : DType | None = None ,
174+ device : Device | None = None ,
175175) -> Array :
176176 if device not in ["cpu" , None ]:
177177 raise ValueError (f"Unsupported device { device !r} " )
@@ -181,12 +181,12 @@ def rfftfreq(
181181 return res
182182
183183def fftshift (
184- x : Array , / , xp : Namespace , * , axes : Union [ int , Sequence [int ]] = None
184+ x : Array , / , xp : Namespace , * , axes : int | Sequence [int ] | None = None
185185) -> Array :
186186 return xp .fft .fftshift (x , axes = axes )
187187
188188def ifftshift (
189- x : Array , / , xp : Namespace , * , axes : Union [ int , Sequence [int ]] = None
189+ x : Array , / , xp : Namespace , * , axes : int | Sequence [int ] | None = None
190190) -> Array :
191191 return xp .fft .ifftshift (x , axes = axes )
192192
0 commit comments