99
1010extern " C" {
1111#include < Python.h>
12+ #include < cstdio>
1213
1314#include " numpy/arrayobject.h"
1415#include " numpy/ndarraytypes.h"
1516#include " numpy/ufuncobject.h"
1617
1718#include " numpy/dtype_api.h"
1819}
19-
2020#include " dtype.h"
2121#include " umath.h"
2222#include " ops.hpp"
@@ -33,18 +33,22 @@ quad_generic_unary_op_strided_loop(PyArrayMethod_Context *context, char *const d
3333 npy_intp in_stride = strides[0 ];
3434 npy_intp out_stride = strides[1 ];
3535
36+ Sleef_quad in, out;
3637 while (N--) {
37- unary_op ((Sleef_quad *)in_ptr, (Sleef_quad *)out_ptr);
38+ memcpy (&in, in_ptr, sizeof (Sleef_quad));
39+ unary_op (&in, &out);
40+ memcpy (out_ptr, &out, sizeof (Sleef_quad));
41+
3842 in_ptr += in_stride;
3943 out_ptr += out_stride;
4044 }
4145 return 0 ;
4246}
4347
4448static NPY_CASTING
45- quad_unary_op_resolve_descriptors (PyObject *self, PyArray_DTypeMeta *dtypes[],
46- QuadPrecDTypeObject * given_descrs[],
47- QuadPrecDTypeObject *loop_descrs[], npy_intp *unused )
49+ quad_unary_op_resolve_descriptors (PyObject *self, PyArray_DTypeMeta *const dtypes[],
50+ PyArray_Descr * const given_descrs[], PyArray_Descr *loop_descrs [],
51+ npy_intp *NPY_UNUSED (view_offset) )
4852{
4953 Py_INCREF (given_descrs[0 ]);
5054 loop_descrs[0 ] = given_descrs[0 ];
@@ -57,7 +61,7 @@ quad_unary_op_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *dtypes[],
5761 Py_INCREF (given_descrs[1 ]);
5862 loop_descrs[1 ] = given_descrs[1 ];
5963
60- return NPY_NO_CASTING; // Quad precision is always the same precision
64+ return NPY_NO_CASTING;
6165}
6266
6367template <unary_op_def unary_op>
@@ -156,8 +160,12 @@ quad_generic_binop_strided_loop(PyArrayMethod_Context *context, char *const data
156160 npy_intp in2_stride = strides[1 ];
157161 npy_intp out_stride = strides[2 ];
158162
163+ Sleef_quad in1, in2, out;
159164 while (N--) {
160- binop ((Sleef_quad *)out_ptr, (Sleef_quad *)in1_ptr, (Sleef_quad *)in2_ptr);
165+ memcpy (&in1, in1_ptr, sizeof (Sleef_quad));
166+ memcpy (&in2, in2_ptr, sizeof (Sleef_quad));
167+ binop (&out, &in1, &in2);
168+ memcpy (out_ptr, &out, sizeof (Sleef_quad));
161169
162170 in1_ptr += in1_stride;
163171 in2_ptr += in2_stride;
@@ -167,35 +175,186 @@ quad_generic_binop_strided_loop(PyArrayMethod_Context *context, char *const data
167175}
168176
169177static NPY_CASTING
170- quad_binary_op_resolve_descriptors (PyObject *self, PyArray_DTypeMeta *dtypes[],
171- QuadPrecDTypeObject * given_descrs[],
172- QuadPrecDTypeObject *loop_descrs[], npy_intp *unused )
178+ quad_binary_op_resolve_descriptors (PyObject *self, PyArray_DTypeMeta *const dtypes[],
179+ PyArray_Descr * const given_descrs[],
180+ PyArray_Descr *loop_descrs[], npy_intp *NPY_UNUSED (view_offset) )
173181{
174182 Py_INCREF (given_descrs[0 ]);
175183 loop_descrs[0 ] = given_descrs[0 ];
176184 Py_INCREF (given_descrs[1 ]);
177185 loop_descrs[1 ] = given_descrs[1 ];
178186
179187 if (given_descrs[2 ] == NULL ) {
188+ PyArray_Descr *out_descr = (PyArray_Descr *)new_quaddtype_instance ();
189+ if (!out_descr) {
190+ return (NPY_CASTING)-1 ;
191+ }
180192 Py_INCREF (given_descrs[0 ]);
181- loop_descrs[2 ] = given_descrs[ 0 ] ;
193+ loop_descrs[2 ] = out_descr ;
182194 }
183195 else {
184196 Py_INCREF (given_descrs[2 ]);
185197 loop_descrs[2 ] = given_descrs[2 ];
186198 }
187199
188- return NPY_NO_CASTING; // Quad precision is always the same precision
200+ return NPY_NO_CASTING;
189201}
190202
191- // todo: skipping the promoter for now, since same type operation will be requried
203+ // helper debugging function
204+ static const char *
205+ get_dtype_name (PyArray_DTypeMeta *dtype)
206+ {
207+ if (dtype == &QuadPrecDType) {
208+ return " QuadPrecDType" ;
209+ }
210+ else if (dtype == &PyArray_BoolDType) {
211+ return " BoolDType" ;
212+ }
213+ else if (dtype == &PyArray_ByteDType) {
214+ return " ByteDType" ;
215+ }
216+ else if (dtype == &PyArray_UByteDType) {
217+ return " UByteDType" ;
218+ }
219+ else if (dtype == &PyArray_ShortDType) {
220+ return " ShortDType" ;
221+ }
222+ else if (dtype == &PyArray_UShortDType) {
223+ return " UShortDType" ;
224+ }
225+ else if (dtype == &PyArray_IntDType) {
226+ return " IntDType" ;
227+ }
228+ else if (dtype == &PyArray_UIntDType) {
229+ return " UIntDType" ;
230+ }
231+ else if (dtype == &PyArray_LongDType) {
232+ return " LongDType" ;
233+ }
234+ else if (dtype == &PyArray_ULongDType) {
235+ return " ULongDType" ;
236+ }
237+ else if (dtype == &PyArray_LongLongDType) {
238+ return " LongLongDType" ;
239+ }
240+ else if (dtype == &PyArray_ULongLongDType) {
241+ return " ULongLongDType" ;
242+ }
243+ else if (dtype == &PyArray_FloatDType) {
244+ return " FloatDType" ;
245+ }
246+ else if (dtype == &PyArray_DoubleDType) {
247+ return " DoubleDType" ;
248+ }
249+ else if (dtype == &PyArray_LongDoubleDType) {
250+ return " LongDoubleDType" ;
251+ }
252+ else {
253+ return " UnknownDType" ;
254+ }
255+ }
256+
257+ static int
258+ quad_ufunc_promoter (PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
259+ PyArray_DTypeMeta *signature[], PyArray_DTypeMeta *new_op_dtypes[])
260+ {
261+ // printf("quad_ufunc_promoter called for ufunc: %s\n", ufunc->name);
262+ // printf("Entering quad_ufunc_promoter\n");
263+ // printf("Ufunc name: %s\n", ufunc->name);
264+ // printf("nin: %d, nargs: %d\n", ufunc->nin, ufunc->nargs);
265+
266+ int nin = ufunc->nin ;
267+ int nargs = ufunc->nargs ;
268+ PyArray_DTypeMeta *common = NULL ;
269+ bool has_quad = false ;
270+
271+ // Handle the special case for reductions
272+ if (op_dtypes[0 ] == NULL ) {
273+ assert (nin == 2 && ufunc->nout == 1 ); /* must be reduction */
274+ for (int i = 0 ; i < 3 ; i++) {
275+ Py_INCREF (op_dtypes[1 ]);
276+ new_op_dtypes[i] = op_dtypes[1 ];
277+ // printf("new_op_dtypes[%d] set to %s\n", i, get_dtype_name(new_op_dtypes[i]));
278+ }
279+ return 0 ;
280+ }
281+
282+ // Check if any input or signature is QuadPrecision
283+ for (int i = 0 ; i < nargs; i++) {
284+ if ((i < nin && op_dtypes[i] == &QuadPrecDType) || (signature[i] == &QuadPrecDType)) {
285+ has_quad = true ;
286+ // printf("QuadPrecision detected in input %d or signature\n", i);
287+ break ;
288+ }
289+ }
290+
291+ if (has_quad) {
292+ // If QuadPrecision is involved, use it for all arguments
293+ common = &QuadPrecDType;
294+ // printf("Using QuadPrecDType as common type\n");
295+ }
296+ else {
297+ // Check if output signature is homogeneous
298+ for (int i = nin; i < nargs; i++) {
299+ if (signature[i] != NULL ) {
300+ if (common == NULL ) {
301+ Py_INCREF (signature[i]);
302+ common = signature[i];
303+ // printf("Common type set to %s from signature\n", get_dtype_name(common));
304+ }
305+ else if (common != signature[i]) {
306+ Py_CLEAR (common); // Not homogeneous, unset common
307+ // printf("Output signature not homogeneous, cleared common type\n");
308+ break ;
309+ }
310+ }
311+ }
312+
313+ // If no common output dtype, use standard promotion for inputs
314+ if (common == NULL ) {
315+ // printf("Using standard promotion for inputs\n");
316+ common = PyArray_PromoteDTypeSequence (nin, op_dtypes);
317+ if (common == NULL ) {
318+ if (PyErr_ExceptionMatches (PyExc_TypeError)) {
319+ PyErr_Clear (); // Do not propagate normal promotion errors
320+ }
321+ // printf("Exiting quad_ufunc_promoter (promotion failed)\n");
322+ return -1 ;
323+ }
324+ // printf("Common type after promotion: %s\n", get_dtype_name(common));
325+ }
326+ }
327+
328+ // Set all new_op_dtypes to the common dtype
329+ for (int i = 0 ; i < nargs; i++) {
330+ if (signature[i]) {
331+ // If signature is specified for this argument, use it
332+ Py_INCREF (signature[i]);
333+ new_op_dtypes[i] = signature[i];
334+ // printf("new_op_dtypes[%d] set to %s (from signature)\n", i,
335+ // get_dtype_name(new_op_dtypes[i]));
336+ }
337+ else {
338+ // Otherwise, use the common dtype
339+ Py_INCREF (common);
340+ new_op_dtypes[i] = common;
341+ // printf("new_op_dtypes[%d] set to %s (from common)\n", i,
342+ // get_dtype_name(new_op_dtypes[i]));
343+ }
344+ }
345+
346+ Py_XDECREF (common);
347+ // printf("Exiting quad_ufunc_promoter\n");
348+ return 0 ;
349+ }
192350
193351template <binop_def binop>
194352int
195353create_quad_binary_ufunc (PyObject *numpy, const char *ufunc_name)
196354{
197355 PyObject *ufunc = PyObject_GetAttrString (numpy, ufunc_name);
198356 if (ufunc == NULL ) {
357+ Py_DecRef (ufunc);
199358 return -1 ;
200359 }
201360
@@ -220,6 +379,25 @@ create_quad_binary_ufunc(PyObject *numpy, const char *ufunc_name)
220379 return -1 ;
221380 }
222381
382+ PyObject *promoter_capsule =
383+ PyCapsule_New ((void *)&quad_ufunc_promoter, " numpy._ufunc_promoter" , NULL );
384+ if (promoter_capsule == NULL ) {
385+ return -1 ;
386+ }
387+
388+ PyObject *DTypes = PyTuple_Pack (3 , &PyArrayDescr_Type, &PyArrayDescr_Type, &PyArrayDescr_Type);
389+ if (DTypes == 0 ) {
390+ Py_DECREF (promoter_capsule);
391+ return -1 ;
392+ }
393+
394+ if (PyUFunc_AddPromoter (ufunc, DTypes, promoter_capsule) < 0 ) {
395+ Py_DECREF (promoter_capsule);
396+ Py_DECREF (DTypes);
397+ return -1 ;
398+ }
399+ Py_DECREF (promoter_capsule);
400+ Py_DECREF (DTypes);
223401 return 0 ;
224402}
225403
@@ -272,6 +450,22 @@ quad_generic_comp_strided_loop(PyArrayMethod_Context *context, char *const data[
272450 return 0 ;
273451}
274452
453+ NPY_NO_EXPORT int
454+ comparison_ufunc_promoter (PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
455+ PyArray_DTypeMeta *signature[], PyArray_DTypeMeta *new_op_dtypes[])
456+ {
457+ PyArray_DTypeMeta *new_signature[NPY_MAXARGS];
458+
459+ memcpy (new_signature, signature, 3 * sizeof (PyArray_DTypeMeta *));
460+ new_signature[2 ] = NULL ;
461+ int res = quad_ufunc_promoter (ufunc, op_dtypes, new_signature, new_op_dtypes);
462+ if (res < 0 ) {
463+ return -1 ;
464+ }
465+ Py_XSETREF (new_op_dtypes[2 ], &PyArray_BoolDType);
466+ return 0 ;
467+ }
468+
275469template <cmp_def comp>
276470int
277471create_quad_comparison_ufunc (PyObject *numpy, const char *ufunc_name)
@@ -300,6 +494,26 @@ create_quad_comparison_ufunc(PyObject *numpy, const char *ufunc_name)
300494 return -1 ;
301495 }
302496
497+ PyObject *promoter_capsule =
498+ PyCapsule_New ((void *)&comparison_ufunc_promoter, " numpy._ufunc_promoter" , NULL );
499+ if (promoter_capsule == NULL ) {
500+ return -1 ;
501+ }
502+
503+ PyObject *DTypes = PyTuple_Pack (3 , &PyArrayDescr_Type, &PyArrayDescr_Type, &PyArray_BoolDType);
504+ if (DTypes == 0 ) {
505+ Py_DECREF (promoter_capsule);
506+ return -1 ;
507+ }
508+
509+ if (PyUFunc_AddPromoter (ufunc, DTypes, promoter_capsule) < 0 ) {
510+ Py_DECREF (promoter_capsule);
511+ Py_DECREF (DTypes);
512+ return -1 ;
513+ }
514+ Py_DECREF (promoter_capsule);
515+ Py_DECREF (DTypes);
516+
303517 return 0 ;
304518}
305519
0 commit comments