@@ -150,7 +150,7 @@ def matmulTT(lhs, rhs):
150150 MATPROP .TRANS .value , MATPROP .TRANS .value ))
151151 return out
152152
153- def dot (lhs , rhs , lhs_opts = MATPROP .NONE , rhs_opts = MATPROP .NONE ):
153+ def dot (lhs , rhs , lhs_opts = MATPROP .NONE , rhs_opts = MATPROP .NONE , return_scalar = False ):
154154 """
155155 Dot product of two input vectors.
156156
@@ -173,10 +173,13 @@ def dot(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE):
173173 - af.MATPROP.NONE - If no op should be done on `rhs`.
174174 - No other options are currently supported.
175175
176+ return_scalar: optional: bool. default: False.
177+ - When set to true, the input arrays are flattened and the output is a scalar
178+
176179 Returns
177180 -------
178181
179- out : af.Array
182+ out : af.Array or scalar
180183 Output of dot product of `lhs` and `rhs`.
181184
182185 Note
@@ -186,7 +189,16 @@ def dot(lhs, rhs, lhs_opts=MATPROP.NONE, rhs_opts=MATPROP.NONE):
186189 - Batches are not supported.
187190
188191 """
189- out = Array ()
190- safe_call (backend .get ().af_dot (c_pointer (out .arr ), lhs .arr , rhs .arr ,
191- lhs_opts .value , rhs_opts .value ))
192- return out
192+ if return_scalar :
193+ real = c_double_t (0 )
194+ imag = c_double_t (0 )
195+ safe_call (backend .get ().af_dot_all (c_pointer (real ), c_pointer (imag ),
196+ lhs .arr , rhs .arr , lhs_opts .value , rhs_opts .value ))
197+ real = real .value
198+ imag = imag .value
199+ return real if imag == 0 else real + imag * 1j
200+ else :
201+ out = Array ()
202+ safe_call (backend .get ().af_dot (c_pointer (out .arr ), lhs .arr , rhs .arr ,
203+ lhs_opts .value , rhs_opts .value ))
204+ return out
0 commit comments