@@ -202,3 +202,109 @@ def dot(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE, return_scalar =
202202 safe_call (backend .get ().af_dot (c_pointer (out .arr ), lhs .arr , rhs .arr ,
203203 lhs_opts .value , rhs_opts .value ))
204204 return out
205+
def _gemm_real_scalar_ptr(value, real_t):
    """Box a real scalar in the ctypes type `real_t` and return it as a void pointer."""
    return c_cast(c_pointer(real_t(value)), c_void_ptr_t)


def _gemm_complex_scalar_ptr(value, complex_t):
    """Box a complex scalar in the ctypes type `complex_t` and return it as a void pointer."""
    if isinstance(value, complex_t):
        # Already the backend representation; point at the caller's object directly.
        boxed = value
    elif isinstance(value, tuple):
        # (real, imag) pair.
        boxed = complex_t(value[0], value[1])
    elif isinstance(value, complex):
        # Native Python complex: ctypes cannot convert it to a float field,
        # so split it into its two components explicitly.
        boxed = complex_t(value.real, value.imag)
    else:
        # Plain real number: only the first field is supplied.
        boxed = complex_t(value)
    return c_cast(c_pointer(boxed), c_void_ptr_t)


def gemm(lhs, rhs, alpha=1.0, beta=0.0, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE, C=None):
    """
    BLAS general matrix multiply (GEMM) of two af_array objects.

    This provides a general interface to the BLAS level 3 general matrix multiply (GEMM),
    which is generally defined as:

        C = alpha * opA(A) * opB(B) + beta * C

    where alpha and beta are both scalars; A and B are the matrix multiply operands
    (`lhs` and `rhs`); and opA and opB are the operations selected by `lhs_opts` and
    `rhs_opts`, applied to A or B before the actual GEMM operation.
    Batched GEMM is supported if at least either A or B have more than two dimensions
    (see af::matmul for more details on broadcasting).
    However, only one alpha and one beta can be used for all of the batched matrix operands.

    Parameters
    ----------

    lhs : af.Array
          A real or complex arrayfire array (left operand, A). Its dtype selects how
          `alpha` and `beta` are converted before being passed to the backend.

    rhs : af.Array
          A real or complex arrayfire array (right operand, B).

    alpha : scalar. optional. default: 1.0.
          Multiplier of the matrix product.
          - If `lhs` is f32 or f64, a real number.
          - If `lhs` is c32 or c64, one of: an af_cfloat_t / af_cdouble_t instance
            (matching the dtype), a `(real, imag)` tuple, a Python `complex`,
            or a real number.

    beta : scalar. optional. default: 0.0.
          Multiplier of `C`. Accepts the same forms as `alpha`.

    lhs_opts: optional: af.MATPROP. default: af.MATPROP.NONE.
              Can be one of
               - af.MATPROP.NONE   - If no op should be done on `lhs`.
               - af.MATPROP.TRANS  - If `lhs` has to be transposed before multiplying.
               - af.MATPROP.CTRANS - If `lhs` has to be hermitian transposed before multiplying.

    rhs_opts: optional: af.MATPROP. default: af.MATPROP.NONE.
              Can be one of
               - af.MATPROP.NONE   - If no op should be done on `rhs`.
               - af.MATPROP.TRANS  - If `rhs` has to be transposed before multiplying.
               - af.MATPROP.CTRANS - If `rhs` has to be hermitian transposed before multiplying.

    C : optional: af.Array. default: None.
          The `C` operand of the formula above. If given, it is handed to the backend as
          the output array and is the object returned. If None, a new af.Array is
          created to receive the result.

    Returns
    -------

    out : af.Array
          Output of the matrix multiplication on `lhs` and `rhs`
          (the same object as `C` when `C` is given).

    Raises
    ------

    TypeError
          If `lhs` is of type f16, or of any type other than f32, c32, f64 or c64.

    Note
    -----

    - The data types of `lhs` and `rhs` should be the same.

    """
    out = Array() if C is None else C

    ltype = lhs.dtype()

    # The backend takes alpha and beta as untyped pointers, so each scalar must be
    # boxed in the C type that matches the dtype of the operands.
    if ltype == Dtype.f32:
        aptr = _gemm_real_scalar_ptr(alpha, c_float_t)
        bptr = _gemm_real_scalar_ptr(beta, c_float_t)
    elif ltype == Dtype.c32:
        aptr = _gemm_complex_scalar_ptr(alpha, af_cfloat_t)
        bptr = _gemm_complex_scalar_ptr(beta, af_cfloat_t)
    elif ltype == Dtype.f64:
        aptr = _gemm_real_scalar_ptr(alpha, c_double_t)
        bptr = _gemm_real_scalar_ptr(beta, c_double_t)
    elif ltype == Dtype.c64:
        aptr = _gemm_complex_scalar_ptr(alpha, af_cdouble_t)
        bptr = _gemm_complex_scalar_ptr(beta, af_cdouble_t)
    elif ltype == Dtype.f16:
        raise TypeError("fp16 currently unsupported gemm() input type")
    else:
        raise TypeError("unsupported input type")

    safe_call(backend.get().af_gemm(c_pointer(out.arr),
                                    lhs_opts.value, rhs_opts.value,
                                    aptr, lhs.arr, rhs.arr, bptr))
    return out
0 commit comments