@@ -35,7 +35,8 @@ def det(a):
3535
3636
3737def eig (a ):
38- raise NotImplementedError ("eig not yet implemented in mlx." )
38+ # Using numpy for now, as mlx does not support eig yet.
39+ return np .linalg .eig (a )
3940
4041
4142def eigh (a ):
@@ -44,7 +45,9 @@ def eigh(a):
4445
4546
4647def lu_factor (a ):
47- raise NotImplementedError ("lu_factor not yet implemented in mlx." )
48+ with mx .stream (mx .cpu ):
49+ # This op is not yet supported on the GPU.
50+ return mx .linalg .lu_factor (a )
4851
4952
5053def solve (a , b ):
@@ -55,7 +58,15 @@ def solve(a, b):
5558
5659
5760def solve_triangular (a , b , lower = False ):
58- raise NotImplementedError ("solve_triangular not yet implemented in mlx." )
61+ upper = not lower
62+ with mx .stream (mx .cpu ):
63+ # This op is not yet supported on the GPU.
64+ if b .ndim == a .ndim - 1 :
65+ b = mx .expand_dims (b , axis = - 1 )
66+ return mx .squeeze (
67+ mx .linalg .solve_triangular (a , b , upper = upper ), axis = - 1
68+ )
69+ return mx .linalg .solve_triangular (a , b , upper = upper )
5970
6071
6172def qr (x , mode = "reduced" ):
@@ -103,4 +114,43 @@ def inv(a):
103114
104115
105116def lstsq (a , b , rcond = None ):
106- raise NotImplementedError ("lstsq not yet implemented in mlx." )
117+ a = convert_to_tensor (a )
118+ b = convert_to_tensor (b )
119+ if a .shape [0 ] != b .shape [0 ]:
120+ raise ValueError (
121+ "Incompatible shapes: a and b must have the same number of rows."
122+ )
123+ b_orig_ndim = b .ndim
124+ if b .ndim == 1 :
125+ b = mx .expand_dims (b , axis = - 1 )
126+ elif b .ndim > 2 :
127+ raise ValueError ("b must be 1D or 2D." )
128+
129+ if b .ndim != 2 :
130+ raise ValueError ("b must be 1D or 2D." )
131+
132+ m , n = a .shape
133+ dtype = a .dtype
134+
135+ eps = np .finfo (np .array (a ).dtype ).eps
136+ if a .shape == ():
137+ s = mx .zeros ((0 ,), dtype = dtype )
138+ x = mx .zeros ((n , * b .shape [1 :]), dtype = dtype )
139+ else :
140+ if rcond is None :
141+ rcond = eps * max (m , n )
142+ else :
143+ rcond = mx .where (rcond < 0 , eps , rcond )
144+ u , s , vt = svd (a , full_matrices = False )
145+
146+ mask = s >= mx .array (rcond , dtype = s .dtype ) * s [0 ]
147+ safe_s = mx .array (mx .where (mask , s , 1 ), dtype = dtype )
148+ s_inv = mx .where (mask , 1 / safe_s , 0 )
149+ s_inv = mx .expand_dims (s_inv , axis = 1 )
150+ u_t_b = mx .matmul (mx .transpose (mx .conj (u )), b )
151+ x = mx .matmul (mx .transpose (mx .conj (vt )), s_inv * u_t_b )
152+
153+ if b_orig_ndim == 1 :
154+ x = mx .squeeze (x , axis = - 1 )
155+
156+ return x
0 commit comments