 #define NO_IMPORT_ARRAY
 #define NO_IMPORT_UFUNC
 
+extern "C" {
 #include <Python.h>
 #include <cstdio>
+#include <string.h>
 
 #include "numpy/arrayobject.h"
+#include "numpy/ndarraytypes.h"
 #include "numpy/ufuncobject.h"
 #include "numpy/dtype_api.h"
-#include "numpy/ndarraytypes.h"
+}
 
 #include "../quad_common.h"
 #include "../scalar.h"
 #include "../dtype.h"
 #include "../ops.hpp"
-#include "binary_ops.h"
 #include "matmul.h"
+#include "promoters.hpp"
 
-#include <iostream>
-
+/**
+ * Resolve descriptors for matmul operation.
+ * Follows the same pattern as binary_ops.cpp
+ */
 static NPY_CASTING
 quad_matmul_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const dtypes[],
                                 PyArray_Descr *const given_descrs[], PyArray_Descr *loop_descrs[],
                                 npy_intp *NPY_UNUSED(view_offset))
 {
-    NPY_CASTING casting = NPY_NO_CASTING;
-    std::cout << "exiting the descriptor";
-    return casting;
-}
+    // Follow the exact same pattern as quad_binary_op_resolve_descriptors
+    QuadPrecDTypeObject *descr_in1 = (QuadPrecDTypeObject *)given_descrs[0];
+    QuadPrecDTypeObject *descr_in2 = (QuadPrecDTypeObject *)given_descrs[1];
+    QuadBackendType target_backend;
 
-template <binary_op_quad_def sleef_op, binary_op_longdouble_def longdouble_op>
-int
-quad_generic_matmul_strided_loop_unaligned(PyArrayMethod_Context *context, char *const data[],
-                                           npy_intp const dimensions[], npy_intp const strides[],
-                                           NpyAuxData *auxdata)
-{
-    npy_intp N = dimensions[0];
-    char *in1_ptr = data[0], *in2_ptr = data[1];
-    char *out_ptr = data[2];
-    npy_intp in1_stride = strides[0];
-    npy_intp in2_stride = strides[1];
-    npy_intp out_stride = strides[2];
-
-    QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0];
-    QuadBackendType backend = descr->backend;
-    size_t elem_size = (backend == BACKEND_SLEEF) ? sizeof(Sleef_quad) : sizeof(long double);
+    // Determine target backend and if casting is needed
+    NPY_CASTING casting = NPY_NO_CASTING;
+    if (descr_in1->backend != descr_in2->backend) {
+        target_backend = BACKEND_LONGDOUBLE;
+        casting = NPY_SAFE_CASTING;
+    }
+    else {
+        target_backend = descr_in1->backend;
+    }
 
-    quad_value in1, in2, out;
-    while (N--) {
-        memcpy(&in1, in1_ptr, elem_size);
-        memcpy(&in2, in2_ptr, elem_size);
-        if (backend == BACKEND_SLEEF) {
-            out.sleef_value = sleef_op(&in1.sleef_value, &in2.sleef_value);
+    // Set up input descriptors, casting if necessary
+    for (int i = 0; i < 2; i++) {
+        if (((QuadPrecDTypeObject *)given_descrs[i])->backend != target_backend) {
+            loop_descrs[i] = (PyArray_Descr *)new_quaddtype_instance(target_backend);
+            if (!loop_descrs[i]) {
+                return (NPY_CASTING)-1;
+            }
         }
         else {
-            out.longdouble_value = longdouble_op(&in1.longdouble_value, &in2.longdouble_value);
+            Py_INCREF(given_descrs[i]);
+            loop_descrs[i] = given_descrs[i];
         }
-        memcpy(out_ptr, &out, elem_size);
+    }
 
-        in1_ptr += in1_stride;
-        in2_ptr += in2_stride;
-        out_ptr += out_stride;
+    // Set up output descriptor
+    if (given_descrs[2] == NULL) {
+        loop_descrs[2] = (PyArray_Descr *)new_quaddtype_instance(target_backend);
+        if (!loop_descrs[2]) {
+            return (NPY_CASTING)-1;
+        }
     }
-    return 0;
+    else {
+        QuadPrecDTypeObject *descr_out = (QuadPrecDTypeObject *)given_descrs[2];
+        if (descr_out->backend != target_backend) {
+            loop_descrs[2] = (PyArray_Descr *)new_quaddtype_instance(target_backend);
+            if (!loop_descrs[2]) {
+                return (NPY_CASTING)-1;
+            }
+        }
+        else {
+            Py_INCREF(given_descrs[2]);
+            loop_descrs[2] = given_descrs[2];
+        }
+    }
+    return casting;
 }
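/*
 * Summary of the resolution above: when both operands share a backend, the
 * given descriptors are reused and NPY_NO_CASTING is reported; when the
 * backends differ, both inputs and the output are promoted to
 * BACKEND_LONGDOUBLE and NPY_SAFE_CASTING is reported.
 */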
 
-template <binary_op_quad_def sleef_op, binary_op_longdouble_def longdouble_op>
-int
-quad_generic_matmul_strided_loop_aligned(PyArrayMethod_Context *context, char *const data[],
-                                         npy_intp const dimensions[], npy_intp const strides[],
-                                         NpyAuxData *auxdata)
+/**
+ * Matrix multiplication strided loop using NumPy 2.0 API.
+ * Implements general matrix multiplication for arbitrary dimensions.
+ *
+ * For matmul with signature (m?,n),(n,p?)->(m?,p?):
+ * - dimensions[0] = N (loop dimension, number of batch operations)
+ * - dimensions[1] = m (rows of first matrix)
+ * - dimensions[2] = n (cols of first matrix / rows of second matrix)
+ * - dimensions[3] = p (cols of second matrix)
+ *
+ * - strides[0], strides[1], strides[2] = batch strides for A, B, C
+ * - strides[3], strides[4] = row stride, col stride for A (m, n)
+ * - strides[5], strides[6] = row stride, col stride for B (n, p)
+ * - strides[7], strides[8] = row stride, col stride for C (m, p)
+ */
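/*
 * An illustrative sketch (assuming one C-contiguous, SLEEF-backed call and
 * sizeof(Sleef_quad) == 16; neither is guaranteed in general): for
 * A(2,3) @ B(3,4) -> C(2,4) this loop would typically receive
 *     dimensions = {1, 2, 3, 4}
 *     strides[3..4] = {48, 16}   (A row/col strides in bytes)
 *     strides[5..6] = {64, 16}   (B row/col strides in bytes)
 *     strides[7..8] = {64, 16}   (C row/col strides in bytes)
 * with the batch strides in strides[0..2] irrelevant when N == 1.
 */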
+static int
+quad_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
+                         npy_intp const dimensions[], npy_intp const strides[], NpyAuxData *auxdata)
 {
-    npy_intp N = dimensions[0];
-    char *in1_ptr = data[0], *in2_ptr = data[1];
-    char *out_ptr = data[2];
-    npy_intp in1_stride = strides[0];
-    npy_intp in2_stride = strides[1];
-    npy_intp out_stride = strides[2];
-
+    // Extract dimensions
+    npy_intp N = dimensions[0];  // Number of batch operations
+    npy_intp m = dimensions[1];  // Rows of first matrix
+    npy_intp n = dimensions[2];  // Cols of first matrix / rows of second matrix
+    npy_intp p = dimensions[3];  // Cols of second matrix
+
+    // Extract batch strides
+    npy_intp A_batch_stride = strides[0];
+    npy_intp B_batch_stride = strides[1];
+    npy_intp C_batch_stride = strides[2];
+
+    // Extract core strides for matrix dimensions
+    npy_intp A_row_stride = strides[3];  // Stride along m dimension of A
+    npy_intp A_col_stride = strides[4];  // Stride along n dimension of A
+    npy_intp B_row_stride = strides[5];  // Stride along n dimension of B
+    npy_intp B_col_stride = strides[6];  // Stride along p dimension of B
+    npy_intp C_row_stride = strides[7];  // Stride along m dimension of C
+    npy_intp C_col_stride = strides[8];  // Stride along p dimension of C
+
+    // Get backend from descriptor
     QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0];
     QuadBackendType backend = descr->backend;
+    size_t elem_size = (backend == BACKEND_SLEEF) ? sizeof(Sleef_quad) : sizeof(long double);
 
-    while (N--) {
-        if (backend == BACKEND_SLEEF) {
-            *(Sleef_quad *)out_ptr = sleef_op((Sleef_quad *)in1_ptr, (Sleef_quad *)in2_ptr);
-        }
-        else {
-            *(long double *)out_ptr = longdouble_op((long double *)in1_ptr, (long double *)in2_ptr);
+    // Process each batch
+    for (npy_intp batch = 0; batch < N; batch++) {
+        char *A_batch = data[0] + batch * A_batch_stride;
+        char *B_batch = data[1] + batch * B_batch_stride;
+        char *C_batch = data[2] + batch * C_batch_stride;
+
+        // Perform matrix multiplication: C = A @ B
+        // C[i,j] = sum_k(A[i,k] * B[k,j])
+        for (npy_intp i = 0; i < m; i++) {
+            for (npy_intp j = 0; j < p; j++) {
+                char *C_ij = C_batch + i * C_row_stride + j * C_col_stride;
+
+                if (backend == BACKEND_SLEEF) {
+                    Sleef_quad sum = Sleef_cast_from_doubleq1(0.0);  // Initialize to 0
+
+                    for (npy_intp k = 0; k < n; k++) {
+                        char *A_ik = A_batch + i * A_row_stride + k * A_col_stride;
+                        char *B_kj = B_batch + k * B_row_stride + j * B_col_stride;
+
+                        Sleef_quad a_val, b_val;
+                        memcpy(&a_val, A_ik, sizeof(Sleef_quad));
+                        memcpy(&b_val, B_kj, sizeof(Sleef_quad));
+
+                        // sum += A[i,k] * B[k,j]
+                        sum = Sleef_addq1_u05(sum, Sleef_mulq1_u05(a_val, b_val));
+                    }
+
+                    memcpy(C_ij, &sum, sizeof(Sleef_quad));
+                }
+                else {
+                    // Long double backend
+                    long double sum = 0.0L;
+
+                    for (npy_intp k = 0; k < n; k++) {
+                        char *A_ik = A_batch + i * A_row_stride + k * A_col_stride;
+                        char *B_kj = B_batch + k * B_row_stride + j * B_col_stride;
+
+                        long double a_val, b_val;
+                        memcpy(&a_val, A_ik, sizeof(long double));
+                        memcpy(&b_val, B_kj, sizeof(long double));
+
+                        sum += a_val * b_val;
+                    }
+
+                    memcpy(C_ij, &sum, sizeof(long double));
+                }
+            }
         }
-
-        in1_ptr += in1_stride;
-        in2_ptr += in2_stride;
-        out_ptr += out_stride;
     }
+
     return 0;
 }
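/*
 * Note on the loop above: element accesses go through memcpy rather than typed
 * dereferences, which is why the same function can back both the aligned and
 * the unaligned strided-loop slots registered below. The triple loop is a
 * straightforward O(m*n*p) implementation with no blocking or pairwise
 * accumulation.
 */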
 
-template <binary_op_quad_def sleef_op, binary_op_longdouble_def longdouble_op>
+/**
+ * Register matmul support following the exact same pattern as binary_ops.cpp
+ */
 int
-create_matmul_ufunc(PyObject *numpy, const char *ufunc_name)
+init_matmul_ops(PyObject *numpy)
 {
-    PyObject *ufunc = PyObject_GetAttrString(numpy, ufunc_name);
+    printf("DEBUG: init_matmul_ops - registering matmul using NumPy 2.0 API\n");
+
+    // Get the existing matmul ufunc - same pattern as binary_ops
+    PyObject *ufunc = PyObject_GetAttrString(numpy, "matmul");
     if (ufunc == NULL) {
+        printf("DEBUG: Failed to get numpy.matmul\n");
         return -1;
     }
 
+    // Use the same pattern as binary_ops.cpp
     PyArray_DTypeMeta *dtypes[3] = {&QuadPrecDType, &QuadPrecDType, &QuadPrecDType};
 
-    PyType_Slot slots[] = {
-        {NPY_METH_resolve_descriptors, (void *)&quad_matmul_resolve_descriptors},
-        {NPY_METH_strided_loop,
-         (void *)&quad_generic_matmul_strided_loop_aligned<sleef_op, longdouble_op>},
-        {NPY_METH_unaligned_strided_loop,
-         (void *)&quad_generic_matmul_strided_loop_unaligned<sleef_op, longdouble_op>},
-        {0, NULL}};
+    PyType_Slot slots[] = {{NPY_METH_resolve_descriptors, (void *)&quad_matmul_resolve_descriptors},
+                           {NPY_METH_strided_loop, (void *)&quad_matmul_strided_loop},
+                           {NPY_METH_unaligned_strided_loop, (void *)&quad_matmul_strided_loop},
+                           {0, NULL}};
 
     PyArrayMethod_Spec Spec = {
         .name = "quad_matmul",
         .nin = 2,
         .nout = 1,
         .casting = NPY_NO_CASTING,
-        .flags = (NPY_ARRAYMETHOD_FLAGS)(NPY_METH_SUPPORTS_UNALIGNED | NPY_METH_IS_REORDERABLE),
+        .flags = NPY_METH_SUPPORTS_UNALIGNED,
         .dtypes = dtypes,
         .slots = slots,
     };
 
+    printf("DEBUG: About to add loop to matmul ufunc...\n");
+
     if (PyUFunc_AddLoopFromSpec(ufunc, &Spec) < 0) {
+        printf("DEBUG: Failed to add loop to matmul ufunc\n");
+        Py_DECREF(ufunc);
         return -1;
     }
-    // my guess we don't need any promoter here as of now, since matmul is quad specific
-    return 0;
-}
 
-int
-init_matmul_ops(PyObject *numpy)
-{
-    if (create_matmul_ufunc<quad_add, ld_add>(numpy, "matmul") < 0) {
+    printf("DEBUG: Successfully added matmul loop!\n");
+
+    // Add promoter following binary_ops pattern
+    PyObject *promoter_capsule =
+        PyCapsule_New((void *)&quad_ufunc_promoter, "numpy._ufunc_promoter", NULL);
+    if (promoter_capsule == NULL) {
+        Py_DECREF(ufunc);
+        return -1;
+    }
+
+    PyObject *DTypes = PyTuple_Pack(3, &PyArrayDescr_Type, &PyArrayDescr_Type, &PyArrayDescr_Type);
+    if (DTypes == NULL) {
+        Py_DECREF(promoter_capsule);
+        Py_DECREF(ufunc);
         return -1;
     }
+
+    if (PyUFunc_AddPromoter(ufunc, DTypes, promoter_capsule) < 0) {
+        printf("DEBUG: Failed to add promoter (continuing anyway)\n");
+        PyErr_Clear();  // Don't fail if promoter fails
+    }
+    else {
+        printf("DEBUG: Successfully added promoter\n");
+    }
+
+    Py_DECREF(DTypes);
+    Py_DECREF(promoter_capsule);
+    Py_DECREF(ufunc);
+
+    printf("DEBUG: init_matmul_ops completed successfully\n");
     return 0;
-}
+}
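/*
 * Minimal call-site sketch (an assumption, not part of this change): the
 * extension module's init code is expected to import numpy and invoke
 * init_matmul_ops() once during registration, roughly:
 *
 *     PyObject *numpy = PyImport_ImportModule("numpy");
 *     if (numpy == NULL) {
 *         return -1;
 *     }
 *     if (init_matmul_ops(numpy) < 0) {
 *         Py_DECREF(numpy);
 *         return -1;
 *     }
 *     Py_DECREF(numpy);
 */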