11import math
2- import operator
32
43import mlx .core as mx
5- import numpy as np
64
75from keras .src .backend import standardize_dtype
6+ from keras .src .backend .common .backend_utils import canonicalize_axis
87from keras .src .backend .mlx .core import convert_to_tensor
98from keras .src .backend .mlx .linalg import det
109from keras .src .utils .module_utils import scipy
@@ -23,26 +22,31 @@ def _segment_reduction_fn(
2322 if num_segments is None :
2423 num_segments = mx .max (segment_ids ) + 1
2524
26- valid_indices = segment_ids >= 0
27- valid_data = mx .array (
28- np .array (data )[valid_indices ] # MLX does not support boolean indices
29- )
30- valid_segment_ids = mx .array (np .array (segment_ids )[valid_indices ])
31-
32- data_shape = list (valid_data .shape )
33- data_shape [0 ] = num_segments
25+ mask = segment_ids >= 0
26+ # pack segment_ids < 0 into index 0 and then handle below
27+ safe_segment_ids = mx .where (mask , segment_ids , 0 )
3428
3529 if not sorted :
36- sort_indices = mx .argsort (valid_segment_ids )
37- valid_segment_ids = valid_segment_ids [sort_indices ]
38- valid_data = valid_data [sort_indices ]
30+ sort_indices = mx .argsort (safe_segment_ids )
31+ safe_segment_ids = mx .take (safe_segment_ids , sort_indices )
32+ data = mx .take (data , sort_indices , axis = 0 )
33+ mask = mx .take (mask , sort_indices )
34+
35+ # expand mask dimensions to match data dimensions
36+ for i in range (1 , len (data .shape )):
37+ mask = mx .expand_dims (mask , axis = i )
38+
39+ data_shape = list (data .shape )
40+ data_shape [0 ] = num_segments
3941
4042 if reduction_method == "max" :
41- result = mx .ones (data_shape , dtype = valid_data .dtype ) * - mx .inf
42- result = result .at [valid_segment_ids ].maximum (valid_data )
43+ masked_data = mx .where (mask , data , - mx .inf )
44+ result = mx .ones (data_shape , dtype = data .dtype ) * - mx .inf
45+ result = result .at [safe_segment_ids ].maximum (masked_data )
4346 else : # sum
44- result = mx .zeros (data_shape , dtype = valid_data .dtype )
45- result = result .at [valid_segment_ids ].add (valid_data )
47+ masked_data = mx .where (mask , data , 0 )
48+ result = mx .zeros (data_shape , dtype = data .dtype )
49+ result = result .at [safe_segment_ids ].add (masked_data )
4650
4751 return result
4852
@@ -154,19 +158,6 @@ def irfft(x, fft_length=None):
154158 return real_output
155159
156160
157- def _canonicalize_axis (axis , num_dims ):
158- # Ref: jax.scipy.signal.stft
159- """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
160- axis = operator .index (axis )
161- if not - num_dims <= axis < num_dims :
162- raise ValueError (
163- f"axis { axis } is out of bounds for array of dimension { num_dims } "
164- )
165- if axis < 0 :
166- axis = axis + num_dims
167- return axis
168-
169-
170161def _create_sliding_windows (x , window_size , step ):
171162 batch_size , signal_length , _ = x .shape
172163 num_windows = (signal_length - window_size ) // step + 1
@@ -187,7 +178,7 @@ def _create_sliding_windows(x, window_size, step):
187178
188179def _stft (x , window , nperseg , noverlap , nfft , axis = - 1 ):
189180 # Ref: jax.scipy.signal.stft
190- axis = _canonicalize_axis (axis , x .ndim )
181+ axis = canonicalize_axis (axis , x .ndim )
191182 result_dtype = mx .complex64
192183
193184 if x .size == 0 :
@@ -364,8 +355,8 @@ def _istft(
364355 # Ref: jax.scipy.signal.istft
365356 if Zxx .ndim < 2 :
366357 raise ValueError ("Input stft must be at least 2d!" )
367- freq_axis = _canonicalize_axis (freq_axis , Zxx .ndim )
368- time_axis = _canonicalize_axis (time_axis , Zxx .ndim )
358+ freq_axis = canonicalize_axis (freq_axis , Zxx .ndim )
359+ time_axis = canonicalize_axis (time_axis , Zxx .ndim )
369360
370361 if freq_axis == time_axis :
371362 raise ValueError ("Must specify differing time and frequency axes!" )
0 commit comments