@@ -29,11 +29,43 @@ def dot(
2929 lhs : Array ,
3030 rhs : Array ,
3131 / ,
32- lhs_opts : MatProp = MatProp .NONE ,
33- rhs_opts : MatProp = MatProp .NONE ,
3432 * ,
3533 return_scalar : bool = False ,
3634) -> int | float | complex | Array :
35+ """
36+ Calculates the dot product of two input arrays, with options to modify the operation
37+ on the input arrays and the possibility to return the result as a scalar.
38+
39+ Parameters
40+ ----------
41+ lhs : Array
42+ A 1-dimensional, int of float Array instance, representing an array.
43+
44+ rhs : Array
45+ A 1-dimensional, int of float Array instance, representing another array.
46+
47+ return_scalar : bool, optional
48+ When set to True, the input arrays are flattened, and the output is a scalar value.
49+ Default is False.
50+
51+ Returns
52+ -------
53+ out : int | float | complex | Array
54+ The result of the dot product. Returns an Array unless `return_scalar` is True,
55+ in which case a scalar value (int, float, or complex) is returned based on the
56+ data type of the inputs.
57+
58+ Note
59+ -----
60+ - The data types of `lhs` and `rhs` should be the same.
61+ - Batch operations are not supported.
62+ - Modification options for `lhs` and `rhs` are currently disabled as function supports only `MatProp.NONE`.
63+ """
64+ # TODO
65+ # Add support of lhs_opts and rhs_opts and return them as key arguments.
66+ lhs_opts : MatProp = MatProp .NONE
67+ rhs_opts : MatProp = MatProp .NONE
68+
3769 if return_scalar :
3870 return wrapper .dot_all (lhs .arr , rhs .arr , lhs_opts , rhs_opts )
3971
@@ -50,11 +82,105 @@ def gemm(
5082 alpha : int | float = 1.0 ,
5183 beta : int | float = 0.0 ,
5284) -> Array :
85+ """
86+ Performs BLAS general matrix multiplication (GEMM) on two Array instances.
87+
88+ The operation is defined as: C = alpha * op(lhs) * op(rhs) + beta * C, where op(X) is
89+ one of no operation, transpose, or Hermitian transpose, determined by lhs_opts and rhs_opts.
90+
91+ Parameters
92+ ----------
93+ lhs : Array
94+ A 2-dimensional, real or complex array representing the left-hand side matrix.
95+
96+ rhs : Array
97+ A 2-dimensional, real or complex array representing the right-hand side matrix.
98+
99+ lhs_opts : MatProp, optional
100+ Operation to perform on `lhs` before multiplication. Default is MatProp.NONE. Options include:
101+ - MatProp.NONE: No operation.
102+ - MatProp.TRANS: Transpose.
103+ - MatProp.CTRANS: Hermitian transpose.
104+
105+ rhs_opts : MatProp, optional
106+ Operation to perform on `rhs` before multiplication. Default is MatProp.NONE. Options include:
107+ - MatProp.NONE: No operation.
108+ - MatProp.TRANS: Transpose.
109+ - MatProp.CTRANS: Hermitian transpose.
110+
111+ alpha : int | float, optional
112+ Scalar multiplier for the product of `lhs` and `rhs`. Default is 1.0.
113+
114+ beta : int | float, optional
115+ Scalar multiplier for the existing matrix C in the accumulation. Default is 0.0.
116+
117+ Returns
118+ -------
119+ Array
120+ The result of the matrix multiplication operation.
121+
122+ Note
123+ -----
124+ - The data types of `lhs` and `rhs` must be compatible.
125+ - Batch operations are not supported in this version.
126+ """
53127 return cast (Array , wrapper .gemm (lhs .arr , rhs .arr , lhs_opts , rhs_opts , alpha , beta ))
54128
55129
56130@afarray_as_array
57131def matmul (lhs : Array , rhs : Array , / , lhs_opts : MatProp = MatProp .NONE , rhs_opts : MatProp = MatProp .NONE ) -> Array :
132+ """
133+ Performs generalized matrix multiplication between two arrays with optional
134+ transposition or hermitian transposition operations on the input matrices.
135+
136+ Parameters
137+ ----------
138+ lhs : af.Array
139+ A 2-dimensional, real or complex ArrayFire array representing the left-hand side matrix.
140+
141+ rhs : af.Array
142+ A 2-dimensional, real or complex ArrayFire array representing the right-hand side matrix.
143+
144+ lhs_opts : af.MATPROP, optional
145+ Operation to perform on the `lhs` matrix before multiplication. Defaults to af.MATPROP.NONE.
146+ Options include:
147+ - af.MATPROP.NONE: No operation.
148+ - af.MATPROP.TRANS: Transpose `lhs`.
149+ - af.MATPROP.CTRANS: Hermitian transpose (conjugate transpose) `lhs`.
150+
151+ rhs_opts : af.MATPROP, optional
152+ Operation to perform on the `rhs` matrix before multiplication. Defaults to af.MATPROP.NONE.
153+ Options include:
154+ - af.MATPROP.NONE: No operation.
155+ - af.MATPROP.TRANS: Transpose `rhs`.
156+ - af.MATPROP.CTRANS: Hermitian transpose (conjugate transpose) `rhs`.
157+
158+ Returns
159+ -------
160+ out : af.Array
161+ The result of the matrix multiplication. The output is a 2-dimensional ArrayFire array.
162+
163+ Notes
164+ -----
165+ - The data types of `lhs` and `rhs` must be the same.
166+ - Batch operations (multiplying multiple pairs of matrices at once) are not supported in this implementation.
167+
168+ Examples
169+ --------
170+ Basic matrix multiplication:
171+
172+ A = af.randu(5, 4, dtype=af.Dtype.f32)
173+ B = af.randu(4, 6, dtype=af.Dtype.f32)
174+ C = matmul(A, B)
175+
176+ Matrix multiplication with the left-hand side transposed:
177+
178+ C = matmul(A, B, lhs_opts=af.MATPROP.TRANS)
179+
180+ Matrix multiplication with both matrices transposed:
181+
182+ C = matmul(A, B, lhs_opts=af.MATPROP.TRANS, rhs_opts=af.MATPROP.TRANS)
183+ """
58184 return cast (Array , wrapper .matmul (lhs .arr , rhs .arr , lhs_opts , rhs_opts ))
59185
60186
@@ -121,5 +247,5 @@ def solve(a: Array, b: Array, /, *, options: MatProp = MatProp.NONE, pivot: None
121247 return cast (Array , wrapper .solve (a .arr , b .arr , options ))
122248
123249
124- # TODO
125- # Add Sparse functions? #good_first_issue
250+ # TODO #good_first_issue
251+ # Add Sparse functions
0 commit comments