77# http://arrayfire.com/licenses/BSD-3-Clause
88########################################################
99
10+ """
11+ dense linear algebra functions for arrayfire.
12+ """
13+
1014from .library import *
1115from .array import *
1216
1317def lu (A ):
18+ """
19+ LU decomposition.
20+
21+ Parameters
22+ ----------
23+ A: af.Array
24+ A 2 dimensional arrayfire array.
25+
26+ Returns
27+ -------
28+ (L,U,P): tuple of af.Arrays
29+ - L - Lower triangular matrix.
30+ - U - Upper triangular matrix.
31+ - P - Permutation array.
32+
33+ Note
34+ ----
35+
36+ The original matrix `A` can be reconstructed using the outputs in the following manner.
37+
38+ >>> A[P, :] = af.matmul(L, U)
39+
40+ """
1441 L = Array ()
1542 U = Array ()
1643 P = Array ()
1744 safe_call (backend .get ().af_lu (ct .pointer (L .arr ), ct .pointer (U .arr ), ct .pointer (P .arr ), A .arr ))
1845 return L ,U ,P
1946
2047def lu_inplace (A , pivot = "lapack" ):
48+ """
49+ In place LU decomposition.
50+
51+ Parameters
52+ ----------
53+ A: af.Array
54+ - a 2 dimensional arrayfire array on entry.
55+ - Contains L in the lower triangle on exit.
56+ - Contains U in the upper triangle on exit.
57+
58+ Returns
59+ -------
60+ P: af.Array
61+ - Permutation array.
62+
63+ Note
64+ ----
65+
66+ This function is primarily used with `af.solve_lu` to reduce computations.
67+
68+ """
2169 P = Array ()
2270 is_pivot_lapack = False if (pivot == "full" ) else True
2371 safe_call (backend .get ().af_lu_inplace (ct .pointer (P .arr ), A .arr , is_pivot_lapack ))
2472 return P
2573
2674def qr (A ):
75+ """
76+ QR decomposition.
77+
78+ Parameters
79+ ----------
80+ A: af.Array
81+ A 2 dimensional arrayfire array.
82+
83+ Returns
84+ -------
85+ (Q,R,T): tuple of af.Arrays
86+ - Q - Orthogonal matrix.
87+ - R - Upper triangular matrix.
88+ - T - Vector containing additional information to solve a least squares problem.
89+
90+ Note
91+ ----
92+
93+ The outputs of this funciton have the following properties.
94+
95+ >>> A = af.matmul(Q, R)
96+ >>> I = af.matmulNT(Q, Q) # Identity matrix
97+ """
2798 Q = Array ()
2899 R = Array ()
29100 T = Array ()
30101 safe_call (backend .get ().af_lu (ct .pointer (Q .arr ), ct .pointer (R .arr ), ct .pointer (T .arr ), A .arr ))
31102 return Q ,R ,T
32103
33104def qr_inplace (A ):
105+ """
106+ In place QR decomposition.
107+
108+ Parameters
109+ ----------
110+ A: af.Array
111+ - a 2 dimensional arrayfire array on entry.
112+ - Packed Q and R matrices on exit.
113+
114+ Returns
115+ -------
116+ T: af.Array
117+ - Vector containing additional information to solve a least squares problem.
118+
119+ Note
120+ ----
121+
122+ This function is used to save space only when `R` is required.
123+ """
34124 T = Array ()
35125 safe_call (backend .get ().af_qr_inplace (ct .pointer (T .arr ), A .arr ))
36126 return T
37127
38128def cholesky (A , is_upper = True ):
129+ """
130+ Cholesky decomposition
131+
132+ Parameters
133+ ----------
134+ A: af.Array
135+ A 2 dimensional, symmetric, positive definite matrix.
136+
137+ is_upper: optional: bool. default: True
138+ Specifies if output `R` is upper triangular (if True) or lower triangular (if False).
139+
140+ Returns
141+ -------
142+ (R,info): tuple of af.Array, int.
143+ - R - triangular matrix.
144+ - info - 0 if decomposition sucessful.
145+ Note
146+ ----
147+
148+ The original matrix `A` can be reconstructed using the outputs in the following manner.
149+
150+ >>> A = af.matmulNT(R, R) #if R is upper triangular
151+
152+ """
39153 R = Array ()
40154 info = ct .c_int (0 )
41155 safe_call (backend .get ().af_cholesky (ct .pointer (R .arr ), ct .pointer (info ), A .arr , is_upper ))
42156 return R , info .value
43157
44158def cholesky_inplace (A , is_upper = True ):
159+ """
160+ In place Cholesky decomposition.
161+
162+ Parameters
163+ ----------
164+ A: af.Array
165+ - a 2 dimensional, symmetric, positive definite matrix.
166+ - Trinangular matrix on exit.
167+
168+ is_upper: optional: bool. default: True.
169+ Specifies if output `R` is upper triangular (if True) or lower triangular (if False).
170+
171+ Returns
172+ -------
173+ info : int.
174+ 0 if decomposition sucessful.
175+
176+ """
45177 info = ct .c_int (0 )
46178 safe_call (backend .get ().af_cholesky_inplace (ct .pointer (info ), A .arr , is_upper ))
47179 return info .value
48180
49181def solve (A , B , options = MATPROP .NONE ):
182+ """
183+ Solve a system of linear equations.
184+
185+ Parameters
186+ ----------
187+
188+ A: af.Array
189+ A 2 dimensional arrayfire array representing the coefficients of the system.
190+
191+ B: af.Array
192+ A 1 or 2 dimensional arrayfire array representing the constants of the system.
193+
194+ options: optional: af.MATPROP. default: af.MATPROP.NONE.
195+ - Additional options to speed up computations.
196+ - Currently needs to be one of `af.MATPROP.NONE`, `af.MATPROP.LOWER`, `af.MATPROP.UPPER`.
197+
198+ Returns
199+ -------
200+ X: af.Array
201+ A 1 or 2 dimensional arrayfire array representing the unknowns in the system.
202+
203+ """
50204 X = Array ()
51205 safe_call (backend .get ().af_solve (ct .pointer (X .arr ), A .arr , B .arr , options .value ))
52206 return X
53207
54208def solve_lu (A , P , B , options = MATPROP .NONE ):
209+ """
210+ Solve a system of linear equations, using LU decomposition.
211+
212+ Parameters
213+ ----------
214+
215+ A: af.Array
216+ - A 2 dimensional arrayfire array representing the coefficients of the system.
217+ - This matrix should be decomposed previously using `lu_inplace(A)`.
218+
219+ P: af.Array
220+ - Permutation array.
221+ - This array is the output of an earlier call to `lu_inplace(A)`
222+
223+ B: af.Array
224+ A 1 or 2 dimensional arrayfire array representing the constants of the system.
225+
226+ Returns
227+ -------
228+ X: af.Array
229+ A 1 or 2 dimensional arrayfire array representing the unknowns in the system.
230+
231+ """
55232 X = Array ()
56233 safe_call (backend .get ().af_solve_lu (ct .pointer (X .arr ), A .arr , P .arr , B .arr , options .value ))
57234 return X
58235
59236def inverse (A , options = MATPROP .NONE ):
60- I = Array ()
61- safe_call (backend .get ().af_inverse (ct .pointer (I .arr ), A .arr , options .value ))
62- return I
237+ """
238+ Invert a matrix.
239+
240+ Parameters
241+ ----------
242+
243+ A: af.Array
244+ - A 2 dimensional arrayfire array
245+
246+ options: optional: af.MATPROP. default: af.MATPROP.NONE.
247+ - Additional options to speed up computations.
248+ - Currently needs to be one of `af.MATPROP.NONE`.
249+
250+ Returns
251+ -------
252+
253+ AI: af.Array
254+ - A 2 dimensional array that is the inverse of `A`
255+
256+ Note
257+ ----
258+
259+ `A` needs to be a square matrix.
260+
261+ """
262+ AI = Array ()
263+ safe_call (backend .get ().af_inverse (ct .pointer (AI .arr ), A .arr , options .value ))
264+ return AI
63265
64266def rank (A , tol = 1E-5 ):
267+ """
268+ Rank of a matrix.
269+
270+ Parameters
271+ ----------
272+
273+ A: af.Array
274+ - A 2 dimensional arrayfire array
275+
276+ tol: optional: scalar. default: 1E-5.
277+ - Tolerance for calculating rank
278+
279+ Returns
280+ -------
281+
282+ r: int
283+ - Rank of `A` within the given tolerance
284+ """
65285 r = ct .c_uint (0 )
66286 safe_call (backend .get ().af_rank (ct .pointer (r ), A .arr , ct .c_double (tol )))
67287 return r .value
68288
69289def det (A ):
290+ """
291+ Determinant of a matrix.
292+
293+ Parameters
294+ ----------
295+
296+ A: af.Array
297+ - A 2 dimensional arrayfire array
298+
299+ Returns
300+ -------
301+
302+ res: scalar
303+ - Determinant of the matrix.
304+ """
70305 re = ct .c_double (0 )
71306 im = ct .c_double (0 )
72307 safe_call (backend .get ().af_det (ct .pointer (re ), ct .pointer (im ), A .arr ))
@@ -75,6 +310,31 @@ def det(A):
75310 return re if (im == 0 ) else re + im * 1j
76311
77312def norm (A , norm_type = NORM .EUCLID , p = 1.0 , q = 1.0 ):
313+ """
314+ Norm of an array or a matrix.
315+
316+ Parameters
317+ ----------
318+
319+ A: af.Array
320+ - A 1 or 2 dimensional arrayfire array
321+
322+ norm_type: optional: af.NORM. default: af.NORM.EUCLID.
323+ - Type of norm to be calculated.
324+
325+ p: scalar. default 1.0.
326+ - Used only if `norm_type` is one of `af.NORM.VECTOR_P`, `af.NORM_MATRIX_L_PQ`
327+
328+ q: scalar. default 1.0.
329+ - Used only if `norm_type` is `af.NORM_MATRIX_L_PQ`
330+
331+ Returns
332+ -------
333+
334+ res: scalar
335+ - norm of the input
336+
337+ """
78338 res = ct .c_double (0 )
79339 safe_call (backend .get ().af_norm (ct .pointer (res ), A .arr , norm_type .value ,
80340 ct .c_double (p ), ct .c_double (q )))
0 commit comments