@@ -54,24 +54,56 @@ def extract_sequences(x, sequence_length, sequence_stride):
5454 return x .reshape (* batch_shape , frames , sequence_length )
5555
5656
57+ def _get_complex_tensor_from_tuple (x ):
58+ if not isinstance (x , (tuple , list )) or len (x ) != 2 :
59+ raise ValueError (
60+ "Input `x` should be a tuple of two tensors - real and imaginary."
61+ f"Received: x={ x } "
62+ )
63+ real , imag = x
64+ real = convert_to_tensor (real )
65+ imag = convert_to_tensor (imag )
66+ # Check shapes.
67+ if real .shape != imag .shape :
68+ raise ValueError (
69+ "Input `x` should be a tuple of two tensors - real and imaginary."
70+ "Both the real and imaginary parts should have the same shape. "
71+ f"Received: x[0].shape = { real .shape } , x[1].shape = { imag .shape } "
72+ )
73+ # Ensure dtype is float.
74+ if not mx .issubdtype (real .dtype , mx .floating ) or not mx .issubdtype (
75+ imag .dtype , mx .floating
76+ ):
77+ raise ValueError (
78+ "At least one tensor in input `x` is not of type float."
79+ f"Received: x={ x } ."
80+ )
81+ complex_input = mx .add (real , 1j * imag )
82+ return complex_input
83+
84+
5785def fft (x ):
58- x = convert_to_tensor (x )
59- return mx .fft (x )
86+ x = _get_complex_tensor_from_tuple (x )
87+ complex_output = mx .fft .fft (x )
88+ return mx .real (complex_output ), mx .imag (complex_output )
6089
6190
6291def fft2 (x ):
63- # TODO: https://ml-explore.github.io/mlx/build/html/python/fft.html#fft
64- raise NotImplementedError ("fft not yet implemented in mlx" )
92+ x = _get_complex_tensor_from_tuple (x )
93+ complex_output = mx .fft .fft2 (x )
94+ return mx .real (complex_output ), mx .imag (complex_output )
6595
6696
6797def rfft (x , fft_length = None ):
68- # TODO: https://ml-explore.github.io/mlx/build/html/python/fft.html#fft
69- raise NotImplementedError ("fft not yet implemented in mlx" )
98+ x = convert_to_tensor (x )
99+ complex_output = mx .fft .rfft (x , n = fft_length )
100+ return mx .real (complex_output ), mx .imag (complex_output )
70101
71102
72103def irfft (x , fft_length = None ):
73- # TODO: https://ml-explore.github.io/mlx/build/html/python/fft.html#fft
74- raise NotImplementedError ("fft not yet implemented in mlx" )
104+ x = _get_complex_tensor_from_tuple (x )
105+ real_output = mx .fft .irfft (x , n = fft_length )
106+ return real_output
75107
76108
77109def stft (
0 commit comments