@@ -38,6 +38,7 @@ __all__ += [
3838 " dpnp_dot" ,
3939 " dpnp_inner" ,
4040 " dpnp_kron" ,
41+ " dpnp_matmul" ,
4142 " dpnp_outer"
4243]
4344
@@ -196,6 +197,90 @@ cpdef dparray dpnp_kron(dparray in_array1, dparray in_array2):
196197 return result
197198
198199
200+ cpdef dparray dpnp_matmul(dparray in_array1, dparray in_array2, dparray out = None ):
201+
202+ cdef vector[Py_ssize_t] shape_result
203+
204+ cdef vector[Py_ssize_t] shape1 = in_array1.shape
205+ cdef vector[Py_ssize_t] shape2 = in_array2.shape
206+
207+ cdef size_t size_m = 0
208+ cdef size_t size_n = 0
209+ cdef size_t size_k = 0
210+
211+ # Calling this function on an empty container causes undefined behavior.
212+ if not shape1.empty():
213+ size_m = shape1.front()
214+ if not shape2.empty():
215+ size_n = shape2.back()
216+ if not shape1.empty():
217+ size_k = shape1.back()
218+
219+ cdef size_t ndim_max = max (in_array1.ndim, in_array2.ndim)
220+
221+ if in_array1.ndim < ndim_max or ndim_max == 1 :
222+ """
223+ shape1(2,), shape2(2,4)
224+ test: pytest tests/test_matmul.py::test_matmul[shape_pair4-types0] -v -s
225+ or
226+ shape1(2,), shape2(2,)
227+ test: pytest tests/test_matmul.py::test_matmul[shape_pair8-types0] -v -s
228+ """
229+ size_m = 1
230+
231+ if in_array2.ndim < ndim_max or ndim_max == 1 :
232+ """
233+ shape1(5,2), shape2(2,)
234+ test: pytest tests/test_matmul.py::test_matmul[shape_pair6-types0] -v -s
235+ or
236+ shape1(3,), shape2(3,)
237+ test: pytest tests/test_matmul.py::test_matmul[shape_pair8-types0] -v -s
238+ """
239+ size_n = 1
240+
241+ if ndim_max > 2 :
242+ """
243+ shape1(5, 3, 2) * shape2(5, 2, 4) -> result(5, 3, 4)
244+ test: pytest tests/test_matmul.py::test_matmul[shape_pair10-types0] -v -s
245+ """
246+ shape_result = shape1[:- 1 ] + [shape2.back()]
247+ else :
248+ """
249+ shape1(5,2) * shape2(2,3) -> result(5,3)
250+ test: pytest tests/test_matmul.py::test_matmul[shape_pair0-types0] -v -s
251+ """
252+ shape_result = shape1[:- 1 ] + shape2[1 :]
253+
254+ # convert string type names (dparray.dtype) to C enum DPNPFuncType
255+ cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(in_array1.dtype)
256+ cdef DPNPFuncType param2_type = dpnp_dtype_to_DPNPFuncType(in_array2.dtype)
257+
258+ # get the FPTR data structure
259+ cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_MATMUL, param1_type, param2_type)
260+
261+ result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
262+
263+ cdef dparray result
264+
265+ if out is not None :
266+ if out.dtype != result_type:
267+ utils.checker_throw_value_error(' matmul' , ' out.dtype' , out.dtype, result_type)
268+ if out.shape != shape_result:
269+ utils.checker_throw_value_error(' matmul' , ' out.shape' , out.shape, shape_result)
270+ result = out
271+ else :
272+ result = dparray(shape_result, dtype = result_type)
273+
274+ if result.size == 0 :
275+ return result
276+
277+ cdef fptr_blas_gemm_2in_1out_t func = < fptr_blas_gemm_2in_1out_t > kernel_data.ptr
278+ # call FPTR function
279+ func(in_array1.get_data(), in_array2.get_data(), result.get_data(), size_m, size_n, size_k)
280+
281+ return result
282+
283+
199284cpdef dparray dpnp_outer(dparray array1, dparray array2):
200285 cdef dparray_shape_type result_shape = (array1.size, array2.size)
201286 result_type = numpy.promote_types(array1.dtype, array1.dtype)
0 commit comments