@@ -93,8 +93,9 @@ determine_operation_type(npy_intp m, npy_intp n, npy_intp p)
9393}
9494
9595static int
96- quad_matmul_strided_loop (PyArrayMethod_Context *context, char *const data[],
97- npy_intp const dimensions[], npy_intp const strides[], NpyAuxData *auxdata)
96+ quad_matmul_strided_loop_aligned (PyArrayMethod_Context *context, char *const data[],
97+ npy_intp const dimensions[], npy_intp const strides[],
98+ NpyAuxData *auxdata)
9899{
99100 // Extract dimensions
100101 npy_intp N = dimensions[0 ]; // Batch size, this remains always 1 for matmul afaik
@@ -149,6 +150,8 @@ quad_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
149150 size_t incx = B_row_stride / sizeof (Sleef_quad);
150151 size_t incy = C_row_stride / sizeof (Sleef_quad);
151152
153+ memset (C_ptr, 0 , m * p * sizeof (Sleef_quad));
154+
152155 result =
153156 qblas_gemv (' R' , ' N' , m, n, &alpha, A_ptr, lda, B_ptr, incx, &beta, C_ptr, incy);
154157 break ;
@@ -159,32 +162,132 @@ quad_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
159162 size_t ldb = B_row_stride / sizeof (Sleef_quad);
160163 size_t ldc_numpy = C_row_stride / sizeof (Sleef_quad);
161164
165+ memset (C_ptr, 0 , m * p * sizeof (Sleef_quad));
166+
167+ size_t ldc_temp = p;
168+
169+ result = qblas_gemm (' R' , ' N' , ' N' , m, p, n, &alpha, A_ptr, lda, B_ptr, ldb, &beta,
170+ C_ptr, ldc_numpy);
171+ break ;
172+ }
173+ }
174+
175+ if (result != 0 ) {
176+ PyErr_SetString (PyExc_RuntimeError, " QBLAS operation failed" );
177+ return -1 ;
178+ }
179+
180+ return 0 ;
181+ }
182+
183+ static int
184+ quad_matmul_strided_loop_unaligned (PyArrayMethod_Context *context, char *const data[],
185+ npy_intp const dimensions[], npy_intp const strides[],
186+ NpyAuxData *auxdata)
187+ {
188+ // Extract dimensions
189+ npy_intp N = dimensions[0 ]; // Batch size, this remains always 1 for matmul afaik
190+ npy_intp m = dimensions[1 ]; // Rows of first matrix
191+ npy_intp n = dimensions[2 ]; // Cols of first matrix / rows of second matrix
192+ npy_intp p = dimensions[3 ]; // Cols of second matrix
193+
194+ // batch strides
195+ npy_intp A_stride = strides[0 ];
196+ npy_intp B_stride = strides[1 ];
197+ npy_intp C_stride = strides[2 ];
198+
199+ // core strides for matrix dimensions
200+ npy_intp A_row_stride = strides[3 ];
201+ npy_intp A_col_stride = strides[4 ];
202+ npy_intp B_row_stride = strides[5 ];
203+ npy_intp B_col_stride = strides[6 ];
204+ npy_intp C_row_stride = strides[7 ];
205+ npy_intp C_col_stride = strides[8 ];
206+
207+ QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors [0 ];
208+ if (descr->backend != BACKEND_SLEEF) {
209+ PyErr_SetString (PyExc_RuntimeError, " Internal error: non-SLEEF backend in QBLAS matmul" );
210+ return -1 ;
211+ }
212+
213+ MatmulOperationType op_type = determine_operation_type (m, n, p);
214+ Sleef_quad alpha = Sleef_cast_from_doubleq1 (1.0 );
215+ Sleef_quad beta = Sleef_cast_from_doubleq1 (0.0 );
216+
217+ char *A = data[0 ];
218+ char *B = data[1 ];
219+ char *C = data[2 ];
220+
221+ Sleef_quad *A_ptr = (Sleef_quad *)A;
222+ Sleef_quad *B_ptr = (Sleef_quad *)B;
223+ Sleef_quad *C_ptr = (Sleef_quad *)C;
224+
225+ int result = -1 ;
226+
227+ switch (op_type) {
228+ case MATMUL_DOT: {
229+ Sleef_quad *temp_A_buffer = new Sleef_quad[n];
230+ Sleef_quad *temp_B_buffer = new Sleef_quad[n];
231+
232+ memcpy (temp_A_buffer, A_ptr, n * sizeof (Sleef_quad));
233+ memcpy (temp_B_buffer, B_ptr, n * sizeof (Sleef_quad));
234+
235+ size_t incx = 1 ;
236+ size_t incy = 1 ;
237+
238+ result = qblas_dot (n, temp_A_buffer, incx, temp_B_buffer, incy, C_ptr);
239+
240+ delete[] temp_A_buffer;
241+ delete[] temp_B_buffer;
242+ break ;
243+ }
244+
245+ case MATMUL_GEMV: {
246+ size_t lda = A_row_stride / sizeof (Sleef_quad);
247+ size_t incx = B_row_stride / sizeof (Sleef_quad);
248+ size_t incy = C_row_stride / sizeof (Sleef_quad);
249+
162250 Sleef_quad *temp_A_buffer = new Sleef_quad[m * n];
163- if (!temp_A_buffer) {
164- PyErr_SetString (PyExc_MemoryError, " Failed to allocate temporary buffer for GEMM" );
165- delete[] temp_A_buffer;
166- return -1 ;
167- }
168251 Sleef_quad *temp_B_buffer = new Sleef_quad[n * p];
169- if (!temp_B_buffer) {
170- PyErr_SetString (PyExc_MemoryError, " Failed to allocate temporary buffer for GEMM" );
171- delete[] temp_A_buffer;
172- return -1 ;
173- }
174252 memcpy (temp_A_buffer, A_ptr, m * n * sizeof (Sleef_quad));
175253 memcpy (temp_B_buffer, B_ptr, n * p * sizeof (Sleef_quad));
176254 A_ptr = temp_A_buffer;
177255 B_ptr = temp_B_buffer;
178256
257+ // Use temp_C_buffer to avoid unaligned writes
179258 Sleef_quad *temp_C_buffer = new Sleef_quad[m * p];
180- if (!temp_C_buffer) {
181- PyErr_SetString (PyExc_MemoryError,
182- " Failed to allocate temporary buffer for GEMM result" );
183- return -1 ;
184- }
185259
260+ lda = n;
261+ incx = 1 ;
262+ incy = 1 ;
263+
264+ memset (temp_C_buffer, 0 , m * p * sizeof (Sleef_quad));
265+
266+ result = qblas_gemv (' R' , ' N' , m, n, &alpha, A_ptr, lda, B_ptr, incx, &beta,
267+ temp_C_buffer, incy);
268+ break ;
269+ }
270+
271+ case MATMUL_GEMM: {
272+ size_t lda = A_row_stride / sizeof (Sleef_quad);
273+ size_t ldb = B_row_stride / sizeof (Sleef_quad);
274+ size_t ldc_numpy = C_row_stride / sizeof (Sleef_quad);
275+
276+ Sleef_quad *temp_A_buffer = new Sleef_quad[m * n];
277+ Sleef_quad *temp_B_buffer = new Sleef_quad[n * p];
278+ memcpy (temp_A_buffer, A_ptr, m * n * sizeof (Sleef_quad));
279+ memcpy (temp_B_buffer, B_ptr, n * p * sizeof (Sleef_quad));
280+ A_ptr = temp_A_buffer;
281+ B_ptr = temp_B_buffer;
282+
283+ // since these are now contiguous so,
284+ lda = n;
285+ ldb = p;
186286 size_t ldc_temp = p;
187287
288+ Sleef_quad *temp_C_buffer = new Sleef_quad[m * p];
289+ memset (temp_C_buffer, 0 , m * p * sizeof (Sleef_quad));
290+
188291 result = qblas_gemm (' R' , ' N' , ' N' , m, p, n, &alpha, A_ptr, lda, B_ptr, ldb, &beta,
189292 temp_C_buffer, ldc_temp);
190293
@@ -218,8 +321,8 @@ naive_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
218321 npy_intp p = dimensions[3 ];
219322
220323 npy_intp A_batch_stride = strides[0 ];
221- npy_intp B_batch_stride = strides[1 ];
222- npy_intp C_batch_stride = strides[2 ];
324+ npy_intp B_stride = strides[1 ];
325+ npy_intp C_stride = strides[2 ];
223326
224327 npy_intp A_row_stride = strides[3 ];
225328 npy_intp A_col_stride = strides[4 ];
@@ -232,46 +335,44 @@ naive_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[],
232335 QuadBackendType backend = descr->backend ;
233336 size_t elem_size = (backend == BACKEND_SLEEF) ? sizeof (Sleef_quad) : sizeof (long double );
234337
235- for (npy_intp batch = 0 ; batch < N; batch++) {
236- char *A_batch = data[0 ] + batch * A_batch_stride;
237- char *B_batch = data[1 ] + batch * B_batch_stride;
238- char *C_batch = data[2 ] + batch * C_batch_stride;
239-
240- for (npy_intp i = 0 ; i < m; i++) {
241- for (npy_intp j = 0 ; j < p; j++) {
242- char *C_ij = C_batch + i * C_row_stride + j * C_col_stride;
338+ char *A = data[0 ];
339+ char *B = data[1 ];
340+ char *C = data[2 ];
243341
244- if (backend == BACKEND_SLEEF) {
245- Sleef_quad sum = Sleef_cast_from_doubleq1 (0.0 );
342+ for (npy_intp i = 0 ; i < m; i++) {
343+ for (npy_intp j = 0 ; j < p; j++) {
344+ char *C_ij = C + i * C_row_stride + j * C_col_stride;
246345
247- for (npy_intp k = 0 ; k < n; k++) {
248- char *A_ik = A_batch + i * A_row_stride + k * A_col_stride;
249- char *B_kj = B_batch + k * B_row_stride + j * B_col_stride;
346+ if (backend == BACKEND_SLEEF) {
347+ Sleef_quad sum = Sleef_cast_from_doubleq1 (0.0 );
250348
251- Sleef_quad a_val, b_val;
252- memcpy (&a_val, A_ik, sizeof (Sleef_quad));
253- memcpy (&b_val, B_kj, sizeof (Sleef_quad));
254- sum = Sleef_fmaq1_u05 (a_val, b_val, sum);
255- }
349+ for (npy_intp k = 0 ; k < n; k++) {
350+ char *A_ik = A + i * A_row_stride + k * A_col_stride;
351+ char *B_kj = B + k * B_row_stride + j * B_col_stride;
256352
257- memcpy (C_ij, &sum, sizeof (Sleef_quad));
353+ Sleef_quad a_val, b_val;
354+ memcpy (&a_val, A_ik, sizeof (Sleef_quad));
355+ memcpy (&b_val, B_kj, sizeof (Sleef_quad));
356+ sum = Sleef_fmaq1_u05 (a_val, b_val, sum);
258357 }
259- else {
260- long double sum = 0 .0L ;
261358
262- for (npy_intp k = 0 ; k < n; k++) {
263- char *A_ik = A_batch + i * A_row_stride + k * A_col_stride;
264- char *B_kj = B_batch + k * B_row_stride + j * B_col_stride;
359+ memcpy (C_ij, &sum, sizeof (Sleef_quad));
360+ }
361+ else {
362+ long double sum = 0 .0L ;
265363
266- long double a_val, b_val;
267- memcpy (&a_val, A_ik, sizeof ( long double )) ;
268- memcpy (&b_val, B_kj, sizeof ( long double )) ;
364+ for (npy_intp k = 0 ; k < n; k++) {
365+ char *A_ik = A + i * A_row_stride + k * A_col_stride ;
366+ char *B_kj = B + k * B_row_stride + j * B_col_stride ;
269367
270- sum += a_val * b_val;
271- }
368+ long double a_val, b_val;
369+ memcpy (&a_val, A_ik, sizeof (long double ));
370+ memcpy (&b_val, B_kj, sizeof (long double ));
272371
273- memcpy (C_ij, & sum, sizeof ( long double )) ;
372+ sum += a_val * b_val ;
274373 }
374+
375+ memcpy (C_ij, &sum, sizeof (long double ));
275376 }
276377 }
277378 }
@@ -289,21 +390,22 @@ init_matmul_ops(PyObject *numpy)
289390
290391 PyArray_DTypeMeta *dtypes[3 ] = {&QuadPrecDType, &QuadPrecDType, &QuadPrecDType};
291392
292- #ifndef DISABLE_QUADBLAS
393+ #ifndef DISABLE_QUADBLAS
293394 // set threading to max
294395 int num_threads = _quadblas_get_num_threads ();
295396 _quadblas_set_num_threads (num_threads);
296397
297- PyType_Slot slots[] = {{NPY_METH_resolve_descriptors, (void *)&quad_matmul_resolve_descriptors},
298- {NPY_METH_strided_loop, (void *)&quad_matmul_strided_loop},
299- {NPY_METH_unaligned_strided_loop, (void *)&naive_matmul_strided_loop},
300- {0 , NULL }};
301- #else
398+ PyType_Slot slots[] = {
399+ {NPY_METH_resolve_descriptors, (void *)&quad_matmul_resolve_descriptors},
400+ {NPY_METH_strided_loop, (void *)&quad_matmul_strided_loop_aligned},
401+ {NPY_METH_unaligned_strided_loop, (void *)&quad_matmul_strided_loop_unaligned},
402+ {0 , NULL }};
403+ #else
302404 PyType_Slot slots[] = {{NPY_METH_resolve_descriptors, (void *)&quad_matmul_resolve_descriptors},
303405 {NPY_METH_strided_loop, (void *)&naive_matmul_strided_loop},
304406 {NPY_METH_unaligned_strided_loop, (void *)&naive_matmul_strided_loop},
305407 {0 , NULL }};
306- #endif // DISABLE_QUADBLAS
408+ #endif // DISABLE_QUADBLAS
307409
308410 PyArrayMethod_Spec Spec = {
309411 .name = " quad_matmul_qblas" ,
@@ -335,7 +437,7 @@ init_matmul_ops(PyObject *numpy)
335437 }
336438
337439 if (PyUFunc_AddPromoter (ufunc, DTypes, promoter_capsule) < 0 ) {
338- PyErr_Clear (); // Don't fail if promoter fails
440+ PyErr_Clear ();
339441 }
340442 else {
341443 }
0 commit comments