@@ -3,6 +3,12 @@
 
 from keras.src.backend import standardize_data_format
 from keras.src.backend import standardize_dtype
+from keras.src.backend.common.backend_utils import (
+    compute_conv_transpose_padding_args_for_mlx,
+)
+from keras.src.backend.common.backend_utils import (
+    compute_transpose_padding_args_for_mlx,
+)
 from keras.src.backend.config import epsilon
 from keras.src.backend.mlx.core import convert_to_tensor
 from keras.src.backend.mlx.core import to_mlx_dtype
@@ -148,25 +154,15 @@ def conv(
     # mlx expects kernel with (out_channels, spatial..., in_channels)
     kernel = kernel.transpose(-1, *range(kernel.ndim - 2), -2)
 
-    if padding == "valid":
-        mlx_padding = 0
-    elif padding == "same":
-        kernel_spatial_shape = kernel.shape[1:-1]
-        start_paddings = []
-        end_paddings = []
-        for dim_size, k_size, d_rate, s in zip(
-            inputs.shape[1:-1], kernel_spatial_shape, dilation_rate, strides
-        ):
-            out_size = (dim_size + s - 1) // s
-            effective_k_size = (k_size - 1) * d_rate + 1
-            total_pad = max(0, (out_size - 1) * s + effective_k_size - dim_size)
-            pad_start = total_pad // 2
-            pad_end = total_pad - pad_start
-            start_paddings.append(pad_start)
-            end_paddings.append(pad_end)
-        mlx_padding = (start_paddings, end_paddings)
-    else:
-        raise ValueError(f"Invalid padding value: {padding}")
+    kernel_spatial_shape = kernel.shape[1:-1]
+    input_spatial_shape = inputs.shape[1:-1]
+    mlx_padding = compute_transpose_padding_args_for_mlx(
+        padding,
+        input_spatial_shape,
+        kernel_spatial_shape,
+        dilation_rate,
+        strides,
+    )
 
     channels = inputs.shape[-1]
     kernel_in_channels = kernel.shape[-1]
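For reference, here is a minimal standalone sketch of the `"same"` arithmetic the removed inline branch performed. The shared `compute_transpose_padding_args_for_mlx` helper is presumably expected to return equivalent `(start_paddings, end_paddings)` for `"same"` and `0` for `"valid"`:

```python
# Sketch of the "same"-padding arithmetic from the removed inline branch.
def same_padding_amounts(input_spatial_shape, kernel_spatial_shape,
                         dilation_rate, strides):
    start_paddings, end_paddings = [], []
    for dim_size, k_size, d_rate, s in zip(
        input_spatial_shape, kernel_spatial_shape, dilation_rate, strides
    ):
        out_size = (dim_size + s - 1) // s  # ceil(dim_size / s)
        effective_k_size = (k_size - 1) * d_rate + 1  # dilated kernel extent
        total_pad = max(0, (out_size - 1) * s + effective_k_size - dim_size)
        start_paddings.append(total_pad // 2)
        end_paddings.append(total_pad - total_pad // 2)
    return start_paddings, end_paddings

# E.g. a 10-wide axis, 3-wide kernel, stride 2, no dilation -> (0, 1)
assert same_padding_amounts((10,), (3,), (1,), (2,)) == ([0], [1])
```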
@@ -202,7 +198,53 @@ def depthwise_conv(
     data_format=None,
     dilation_rate=1,
 ):
-    raise NotImplementedError("MLX backend doesn't support depthwise conv yet")
+    inputs = convert_to_tensor(inputs)
+    kernel = convert_to_tensor(kernel)
+    data_format = standardize_data_format(data_format)
+    num_spatial_dims = inputs.ndim - 2
+
+    strides = standardize_tuple(strides, num_spatial_dims, "strides")
+    dilation_rate = standardize_tuple(
+        dilation_rate, num_spatial_dims, "dilation_rate"
+    )
+
+    if data_format == "channels_first":
+        # mlx expects channels_last
+        inputs = inputs.transpose(0, *range(2, inputs.ndim), 1)
+
+    feature_group_count = inputs.shape[-1]
+
+    # reshape first for depthwise conv, then transpose to expected mlx format
+    kernel = kernel.reshape(
+        *iter(kernel.shape[:-2]), 1, feature_group_count * kernel.shape[-1]
+    )
+    # mlx expects kernel with (out_channels, spatial..., in_channels)
+    kernel = kernel.transpose(-1, *range(kernel.ndim - 2), -2)
+
+    kernel_spatial_shape = kernel.shape[1:-1]
+    input_spatial_shape = inputs.shape[1:-1]
+    mlx_padding = compute_transpose_padding_args_for_mlx(
+        padding,
+        input_spatial_shape,
+        kernel_spatial_shape,
+        dilation_rate,
+        strides,
+    )
+
+    result = mx.conv_general(
+        inputs,
+        kernel,
+        stride=strides,
+        padding=mlx_padding,
+        kernel_dilation=dilation_rate,
+        input_dilation=1,
+        groups=feature_group_count,
+        flip=False,
+    )
+    if data_format == "channels_first":
+        result = result.transpose(0, -1, *range(1, result.ndim - 1))
+
+    return result
 
 
 def separable_conv(
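To see what the reshape/transpose does to a Keras depthwise kernel, here is a minimal shape walk-through with made-up sizes (it assumes an mlx build whose `mx.conv_general` accepts `groups`, which the code above already requires):

```python
import mlx.core as mx

# Hypothetical sizes: NHWC input with 3 channels, 3x3 depthwise kernel,
# channel multiplier 2. Keras stores a depthwise kernel as
# (spatial..., in_channels, channel_multiplier).
inputs = mx.zeros((2, 8, 8, 3))
kernel = mx.zeros((3, 3, 3, 2))

kernel = kernel.reshape(3, 3, 1, 3 * 2)  # one filter stack per input channel
kernel = kernel.transpose(3, 0, 1, 2)    # (out_channels, spatial..., in_channels)

# groups == in_channels makes each filter see exactly one input channel
out = mx.conv_general(
    inputs, kernel, stride=(1, 1), padding=0, groups=3, flip=False
)
print(out.shape)  # (2, 6, 6, 6) -- "valid" padding, out_channels = 3 * 2
```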
@@ -214,7 +256,23 @@ def separable_conv(
     data_format=None,
     dilation_rate=1,
 ):
-    raise NotImplementedError("MLX backend doesn't support separable conv yet")
+    data_format = standardize_data_format(data_format)
+    depthwise_conv_output = depthwise_conv(
+        inputs,
+        depthwise_kernel,
+        strides,
+        padding,
+        data_format,
+        dilation_rate,
+    )
+    return conv(
+        depthwise_conv_output,
+        pointwise_kernel,
+        strides=1,
+        padding="valid",
+        data_format=data_format,
+        dilation_rate=dilation_rate,
+    )
 
 
 def conv_transpose(
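A quick, hypothetical shape trace of this composition (sizes made up, run from within this module): the depthwise stage expands channels by the multiplier, then the 1x1 `"valid"` pointwise conv mixes them down:

```python
inputs = mx.zeros((2, 8, 8, 3))
depthwise_kernel = mx.zeros((3, 3, 3, 2))  # in=3, multiplier=2 -> 6 channels
pointwise_kernel = mx.zeros((1, 1, 6, 4))  # 1x1 conv: 6 -> 4 channels

out = separable_conv(
    inputs, depthwise_kernel, pointwise_kernel, strides=1, padding="same"
)
print(out.shape)  # (2, 8, 8, 4): "same" keeps spatial dims, pointwise sets 4
```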
@@ -226,7 +284,62 @@ def conv_transpose(
     data_format=None,
     dilation_rate=1,
 ):
-    raise NotImplementedError("MLX backend doesn't support conv transpose yet")
+    inputs = convert_to_tensor(inputs)
+    kernel = convert_to_tensor(kernel)
+    data_format = standardize_data_format(data_format)
+    num_spatial_dims = inputs.ndim - 2
+
+    strides = standardize_tuple(strides, num_spatial_dims, "strides")
+    dilation_rate = standardize_tuple(
+        dilation_rate, num_spatial_dims, "dilation_rate"
+    )
+    if output_padding is not None:
+        output_padding = standardize_tuple(
+            output_padding, num_spatial_dims, "output_padding"
+        )
+
+    if data_format == "channels_first":
+        # mlx expects channels_last
+        inputs = inputs.transpose(0, *range(2, inputs.ndim), 1)
+
+    # mlx expects kernel with (out_channels, spatial..., in_channels)
+    kernel = kernel.transpose(-2, *range(kernel.ndim - 2), -1)
+    kernel_spatial_shape = kernel.shape[1:-1]
+
+    mlx_padding = compute_conv_transpose_padding_args_for_mlx(
+        padding,
+        num_spatial_dims,
+        kernel_spatial_shape,
+        dilation_rate,
+        strides,
+        output_padding,
+    )
+
+    channels = inputs.shape[-1]
+    kernel_in_channels = kernel.shape[-1]
+    if channels % kernel_in_channels > 0:
+        raise ValueError(
+            "The number of input channels must be evenly divisible by "
+            f"kernel's in_channels. Received input channels {channels} and "
+            f"kernel in_channels {kernel_in_channels}."
+        )
+    groups = channels // kernel_in_channels
+
+    result = mx.conv_general(
+        inputs,
+        kernel,
+        stride=1,  # stride is handled by input_dilation
+        padding=mlx_padding,
+        kernel_dilation=dilation_rate,
+        input_dilation=strides,
+        groups=groups,
+        flip=True,
+    )
+
+    if data_format == "channels_first":
+        result = result.transpose(0, -1, *range(1, result.ndim - 1))
+
+    return result
 
 
 def one_hot(x, num_classes, axis=-1, dtype="float32", sparse=False):
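The transposed conv is implemented here as a regular correlation over a zero-dilated input: `input_dilation=strides` inserts `stride - 1` zeros between input samples, `flip=True` flips the kernel, and `stride=1` then recovers the usual transposed-conv result. For a sanity check on sizes, this is the standard Keras transposed-conv output-length rule per spatial axis, which `compute_conv_transpose_padding_args_for_mlx` is presumably choosing padding to satisfy (a sketch, not this helper's actual code):

```python
# Standard Keras deconv output-length rule (per spatial axis).
def deconv_out_length(in_len, kernel_size, stride,
                      padding="valid", output_padding=None, dilation=1):
    kernel_size = (kernel_size - 1) * dilation + 1  # dilated kernel extent
    if output_padding is None:
        if padding == "valid":
            return in_len * stride + max(kernel_size - stride, 0)
        return in_len * stride  # "same"
    pad = kernel_size // 2 if padding == "same" else 0
    return (in_len - 1) * stride + kernel_size - 2 * pad + output_padding

# E.g. length 4, kernel 3, stride 2: "valid" -> 9, "same" -> 8
assert deconv_out_length(4, 3, 2, "valid") == 9
assert deconv_out_length(4, 3, 2, "same") == 8
```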