@@ -110,14 +110,13 @@ def _commit_descriptor(a, forward, in_place, c2c, a_strides, index, batch_fft):
110110 return dsc , out_strides
111111
112112
113- def _complex_nd_fft (
113+ def _c2c_nd_fft (
114114 a ,
115115 s ,
116116 norm ,
117117 out ,
118118 forward ,
119119 in_place ,
120- c2c ,
121120 axes ,
122121 batch_fft ,
123122 * ,
@@ -126,34 +125,38 @@ def _complex_nd_fft(
126125 """Computes complex-to-complex FFT of the input N-D array."""
127126
128127 len_axes = len (axes )
129- # OneMKL supports up to 3-dimensional FFT on GPU
130- # repeated axis in OneMKL FFT is not allowed
128+ # oneMKL supports up to 3-dimensional FFT on GPU
129+ # repeated axis in oneMKL FFT is not allowed
131130 if len_axes > 3 or len (set (axes )) < len_axes :
132131 axes_chunk , shape_chunk = _extract_axes_chunk (
133132 axes , s , chunk_size = 3 , reversed_axes = reversed_axes
134133 )
134+
135+ # We try to use in-place calculations where possible, which is
136+ # everywhere except when the size changes after the first iteration.
137+ size_changes = [axis for axis , n in zip (axes , s ) if a .shape [axis ] != n ]
138+
139+ # cannot use out in the intermediate steps if size changes
140+ res = None if size_changes else out
141+
135142 for i , (s_chunk , a_chunk ) in enumerate (zip (shape_chunk , axes_chunk )):
136143 a = _truncate_or_pad (a , shape = s_chunk , axes = a_chunk )
137- # if out is used in an intermediate step, it will have memory
138- # overlap with input and cannot be used in the final step (a new
139- # result array will be created for the final step), so there is no
140- # benefit in using out in an intermediate step
141- if i == len (axes_chunk ) - 1 :
142- tmp_out = out
143- else :
144- tmp_out = None
144+ # if size_changes, out cannot be used in intermediate steps
145+ if size_changes and i == len (axes_chunk ) - 1 :
146+ res = out
145147
146148 a = _fft (
147149 a ,
148150 norm = norm ,
149- out = tmp_out ,
151+ out = res ,
150152 forward = forward ,
151- # TODO: in-place FFT is only implemented for c2c, see SAT-7154
152- in_place = in_place and c2c ,
153- c2c = c2c ,
153+ in_place = in_place ,
154+ c2c = True ,
154155 axes = a_chunk ,
155156 )
156-
157+ if not size_changes :
158+ # Default output for next iteration.
159+ res = a
157160 return a
158161
159162 a = _truncate_or_pad (a , s , axes )
@@ -165,9 +168,8 @@ def _complex_nd_fft(
165168 norm = norm ,
166169 out = out ,
167170 forward = forward ,
168- # TODO: in-place FFT is only implemented for c2c, see SAT-7154
169- in_place = in_place and c2c ,
170- c2c = c2c ,
171+ in_place = in_place ,
172+ c2c = True ,
171173 axes = axes ,
172174 batch_fft = batch_fft ,
173175 )
@@ -198,7 +200,7 @@ def _compute_result(dsc, a, out, forward, c2c, out_strides):
198200 res_usm = dpnp .get_usm_ndarray (out )
199201 result = out
200202 else :
201- # Result array that is used in OneMKL must have the exact same
203+ # Result array that is used in oneMKL must have the exact same
202204 # stride as input array
203205
204206 if c2c : # c2c FFT
@@ -277,9 +279,9 @@ def _copy_array(x, complex_input):
277279 dtype = x .dtype
278280 copy_flag = False
279281 if numpy .min (x .strides ) < 0 :
280- # negative stride is not allowed in OneMKL FFT
282+ # negative stride is not allowed in oneMKL FFT
281283 # TODO: support for negative strides will be added in the future
282- # versions of OneMKL , see discussion in MKLD-17597
284+ # versions of oneMKL , see discussion in MKLD-17597
283285 copy_flag = True
284286
285287 if complex_input and not dpnp .issubdtype (dtype , dpnp .complexfloating ):
@@ -388,6 +390,9 @@ def _fft(a, norm, out, forward, in_place, c2c, axes, batch_fft=True):
388390
389391 index = 0
390392 fft_1d = isinstance (axes , int )
393+ if not in_place and out is not None :
394+ # if input and output are the same array, use in-place FFT
395+ in_place = dpnp .are_same_logical_tensors (a , out )
391396 if batch_fft :
392397 len_axes = 1 if fft_1d else len (axes )
393398 local_axes = numpy .arange (- len_axes , 0 )
@@ -627,9 +632,6 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):
627632 _validate_out_keyword (a , out , (n ,), (axis ,), c2c , c2r , r2c )
628633 # if input array is copied, in-place FFT can be used
629634 a , in_place = _copy_array (a , c2c or c2r )
630- if not in_place and out is not None :
631- # if input is also given for out, in-place FFT can be used
632- in_place = dpnp .are_same_logical_tensors (a , out )
633635
634636 if a .size == 0 :
635637 return dpnp .get_result_array (a , out = out , casting = "same_kind" )
@@ -695,63 +697,74 @@ def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None):
695697 )
696698
697699 if r2c :
698- # a 1D real-to-complext FFT is performed on the last axis and then
700+ size_changes = [axis for axis , n in zip (axes , s ) if a .shape [axis ] != n ]
701+ # cannot use out in the intermediate steps if size changes
702+ res = None if size_changes else out
703+
704+ # a 1D real-to-complex FFT is performed on the last axis and then
699705 # an N-D complex-to-complex FFT over the remaining axes
700706 a = _truncate_or_pad (a , (s [- 1 ],), (axes [- 1 ],))
701707 a = _fft (
702708 a ,
703709 norm = norm ,
704- # if out is used in an intermediate step, it will have memory
705- # overlap with input and cannot be used in the final step (a new
706- # result array will be created for the final step), so there is no
707- # benefit in using out in an intermediate step
708- out = None ,
710+ out = res ,
709711 forward = forward ,
710- in_place = in_place and c2c ,
711- c2c = c2c ,
712+ in_place = False ,
713+ c2c = False ,
712714 axes = axes [- 1 ],
713715 batch_fft = a .ndim != 1 ,
714716 )
715- return _complex_nd_fft (
717+ return _c2c_nd_fft (
716718 a ,
717- s = s ,
719+ s = s [: - 1 ] ,
718720 norm = norm ,
719721 out = out ,
720722 forward = forward ,
721723 in_place = in_place ,
722- c2c = True ,
723724 axes = axes [:- 1 ],
724725 batch_fft = a .ndim != len_axes - 1 ,
725726 )
726727
727728 if c2r :
728729 # an N-D complex-to-complex FFT is performed on all axes except the
729730 # last one then a 1D complex-to-real FFT is performed on the last axis
730- a = _complex_nd_fft (
731+ a = _c2c_nd_fft (
731732 a ,
732- s = s ,
733+ s = s [: - 1 ] ,
733734 norm = norm ,
734735 # out has real dtype and cannot be used in intermediate steps
735736 out = None ,
736737 forward = forward ,
737738 in_place = in_place ,
738- c2c = True ,
739739 axes = axes [:- 1 ],
740740 batch_fft = a .ndim != len_axes - 1 ,
741741 reversed_axes = False ,
742742 )
743743 a = _truncate_or_pad (a , (s [- 1 ],), (axes [- 1 ],))
744- if c2r :
745- a = _make_array_hermitian (
746- a , axes [- 1 ], dpnp .are_same_logical_tensors (a , a_orig )
747- )
744+ a = _make_array_hermitian (
745+ a , axes [- 1 ], dpnp .are_same_logical_tensors (a , a_orig )
746+ )
748747 return _fft (
749- a , norm , out , forward , in_place and c2c , c2c , axes [- 1 ], a .ndim != 1
748+ a ,
749+ norm = norm ,
750+ out = out ,
751+ forward = forward ,
752+ in_place = False ,
753+ c2c = False ,
754+ axes = axes [- 1 ],
755+ batch_fft = a .ndim != 1 ,
750756 )
751757
752758 # c2c
753- return _complex_nd_fft (
754- a , s , norm , out , forward , in_place , c2c , axes , a .ndim != len_axes
759+ return _c2c_nd_fft (
760+ a ,
761+ s = s ,
762+ norm = norm ,
763+ out = out ,
764+ forward = forward ,
765+ in_place = in_place ,
766+ axes = axes ,
767+ batch_fft = a .ndim != len_axes ,
755768 )
756769
757770
0 commit comments