1- import jax .numpy as jnp
21import mlx .core as mx
2+ import numpy as np
33
44from keras .src .backend .common import dtypes
55from keras .src .backend .common import standardize_dtype
@@ -29,8 +29,8 @@ def det(a):
2929 return _det_3x3 (a )
3030 # elif len(a_shape) >= 2 and a_shape[-1] == a_shape[-2]:
3131 # TODO: Swap to mlx.linalg.det when supported
32- a = jnp .array (a )
33- output = jnp .linalg .det (a )
32+ a = np .array (a )
33+ output = np .linalg .det (a )
3434 return mx .array (output )
3535
3636
@@ -56,15 +56,26 @@ def solve_triangular(a, b, lower=False):
5656
5757
5858def qr (x , mode = "reduced" ):
59- # TODO: Swap to mlx.linalg.qr when it supports non-square matrices
60- x = jnp .array (x )
61- output = jnp .linalg .qr (x , mode = mode )
62- return mx .array (output [0 ]), mx .array (output [1 ])
59+ if mode != "reduced" :
60+ raise ValueError (
61+ "`mode` argument value not supported. "
62+ "Only 'reduced' is supported by the mlx backend. "
63+ f"Received: mode={ mode } "
64+ )
65+ with mx .stream (mx .cpu ):
66+ return mx .linalg .qr (x )
6367
6468
6569def svd (x , full_matrices = True , compute_uv = True ):
6670 with mx .stream (mx .cpu ):
67- return mx .linalg .svd (x )
71+ u , s , vt = mx .linalg .svd (x )
72+ if not compute_uv :
73+ return s
74+ if not full_matrices :
75+ n = min (x .shape [- 2 :])
76+ return u [..., :n ], s , vt [:n , ...]
77+ # mlx returns full matrices by default
78+ return u , s , vt
6879
6980
7081def cholesky (a ):
@@ -78,11 +89,15 @@ def norm(x, ord=None, axis=None, keepdims=False):
7889 dtype = dtypes .result_type (x .dtype , "float32" )
7990 x = convert_to_tensor (x , dtype = dtype )
8091 # TODO: swap to mlx.linalg.norm when it support singular value norms
81- x = jnp .array (x )
82- output = jnp .linalg .norm (x , ord = ord , axis = axis , keepdims = keepdims )
92+ x = np .array (x )
93+ output = np .linalg .norm (x , ord = ord , axis = axis , keepdims = keepdims )
8394 return mx .array (output )
8495
8596
8697def inv (a ):
8798 with mx .stream (mx .cpu ):
8899 return mx .linalg .inv (a )
100+
101+
102+ def lstsq (a , b , rcond = None ):
103+ raise NotImplementedError ("lstsq not yet implemented in mlx." )
0 commit comments