Skip to content

Commit 3d79e9c

Browse files
committed
ldexp impl
1 parent 77882ca commit 3d79e9c

File tree

4 files changed

+368
-2
lines changed

4 files changed

+368
-2
lines changed

quaddtype/numpy_quaddtype/src/ops.hpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,6 +1047,60 @@ quad_spacing(const Sleef_quad *x)
10471047
return result;
10481048
}
10491049

1050+
// Mixed-type operations (quad, int) -> quad
1051+
typedef Sleef_quad (*ldexp_op_quad_def)(const Sleef_quad *, const int *);
1052+
typedef long double (*ldexp_op_longdouble_def)(const long double *, const int *);
1053+
1054+
static inline Sleef_quad
1055+
quad_ldexp(const Sleef_quad *x, const int *exp)
1056+
{
1057+
// ldexp(x, exp) returns x * 2^exp
1058+
1059+
// NaN input -> NaN output (with sign preserved)
1060+
if (Sleef_iunordq1(*x, *x)) {
1061+
return *x;
1062+
}
1063+
1064+
// ±0 * 2^exp = ±0 (preserves sign of zero)
1065+
if (Sleef_icmpeqq1(*x, QUAD_ZERO)) {
1066+
return *x;
1067+
}
1068+
1069+
// ±inf * 2^exp = ±inf (preserves sign of infinity)
1070+
if (quad_isinf(x)) {
1071+
return *x;
1072+
}
1073+
1074+
Sleef_quad result = Sleef_ldexpq1(*x, *exp);
1075+
1076+
return result;
1077+
}
1078+
1079+
static inline long double
1080+
ld_ldexp(const long double *x, const int *exp)
1081+
{
1082+
// ldexp(x, exp) returns x * 2^exp
1083+
1084+
// NaN input -> NaN output
1085+
if (isnan(*x)) {
1086+
return *x;
1087+
}
1088+
1089+
// ±0 * 2^exp = ±0 (preserves sign of zero)
1090+
if (*x == 0.0L) {
1091+
return *x;
1092+
}
1093+
1094+
// ±inf * 2^exp = ±inf (preserves sign of infinity)
1095+
if (isinf(*x)) {
1096+
return *x;
1097+
}
1098+
1099+
long double result = ldexpl(*x, *exp);
1100+
1101+
return result;
1102+
}
1103+
10501104
// Binary long double operations
10511105
typedef long double (*binary_op_longdouble_def)(const long double *, const long double *);
10521106
// Binary long double operations with 2 outputs (for divmod, modf, frexp)

quaddtype/numpy_quaddtype/src/umath/binary_ops.cpp

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,175 @@ quad_generic_binop_2out_strided_loop_aligned(PyArrayMethod_Context *context, cha
282282
return 0;
283283
}
284284

285+
// todo: I'll preferrable get all this code duplication in templates later
286+
// Special resolve descriptors for ldexp (QuadPrecDType, int32) -> QuadPrecDType
287+
static NPY_CASTING
288+
quad_ldexp_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const dtypes[],
289+
PyArray_Descr *const given_descrs[],
290+
PyArray_Descr *loop_descrs[], npy_intp *NPY_UNUSED(view_offset))
291+
{
292+
QuadPrecDTypeObject *descr_in1 = (QuadPrecDTypeObject *)given_descrs[0];
293+
QuadBackendType target_backend = descr_in1->backend;
294+
295+
// Input 0: QuadPrecDType
296+
Py_INCREF(given_descrs[0]);
297+
loop_descrs[0] = given_descrs[0];
298+
299+
// Input 1: int (no need to incref, it's a builtin dtype)
300+
if (given_descrs[1] == NULL) {
301+
loop_descrs[1] = PyArray_DescrFromType(NPY_INT32);
302+
} else {
303+
Py_INCREF(given_descrs[1]);
304+
loop_descrs[1] = given_descrs[1];
305+
}
306+
307+
// Output: QuadPrecDType with same backend as input
308+
if (given_descrs[2] == NULL) {
309+
loop_descrs[2] = (PyArray_Descr *)new_quaddtype_instance(target_backend);
310+
if (!loop_descrs[2]) {
311+
return (NPY_CASTING)-1;
312+
}
313+
} else {
314+
QuadPrecDTypeObject *descr_out = (QuadPrecDTypeObject *)given_descrs[2];
315+
if (descr_out->backend != target_backend) {
316+
loop_descrs[2] = (PyArray_Descr *)new_quaddtype_instance(target_backend);
317+
if (!loop_descrs[2]) {
318+
return (NPY_CASTING)-1;
319+
}
320+
} else {
321+
Py_INCREF(given_descrs[2]);
322+
loop_descrs[2] = given_descrs[2];
323+
}
324+
}
325+
return NPY_NO_CASTING;
326+
}
327+
328+
// Strided loop for ldexp (unaligned)
329+
template <ldexp_op_quad_def sleef_op, ldexp_op_longdouble_def longdouble_op>
330+
int
331+
quad_ldexp_strided_loop_unaligned(PyArrayMethod_Context *context, char *const data[],
332+
npy_intp const dimensions[], npy_intp const strides[],
333+
NpyAuxData *auxdata)
334+
{
335+
npy_intp N = dimensions[0];
336+
char *in1_ptr = data[0]; // quad
337+
char *in2_ptr = data[1]; // int32
338+
char *out_ptr = data[2]; // quad
339+
npy_intp in1_stride = strides[0];
340+
npy_intp in2_stride = strides[1];
341+
npy_intp out_stride = strides[2];
342+
343+
QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0];
344+
QuadBackendType backend = descr->backend;
345+
size_t elem_size = (backend == BACKEND_SLEEF) ? sizeof(Sleef_quad) : sizeof(long double);
346+
347+
quad_value in1, out;
348+
int in2;
349+
while (N--) {
350+
memcpy(&in1, in1_ptr, elem_size);
351+
memcpy(&in2, in2_ptr, sizeof(int));
352+
if (backend == BACKEND_SLEEF) {
353+
out.sleef_value = sleef_op(&in1.sleef_value, &in2);
354+
} else {
355+
out.longdouble_value = longdouble_op(&in1.longdouble_value, &in2);
356+
}
357+
memcpy(out_ptr, &out, elem_size);
358+
359+
in1_ptr += in1_stride;
360+
in2_ptr += in2_stride;
361+
out_ptr += out_stride;
362+
}
363+
return 0;
364+
}
365+
366+
// Strided loop for ldexp (aligned)
367+
template <ldexp_op_quad_def sleef_op, ldexp_op_longdouble_def longdouble_op>
368+
int
369+
quad_ldexp_strided_loop_aligned(PyArrayMethod_Context *context, char *const data[],
370+
npy_intp const dimensions[], npy_intp const strides[],
371+
NpyAuxData *auxdata)
372+
{
373+
npy_intp N = dimensions[0];
374+
char *in1_ptr = data[0]; // quad
375+
char *in2_ptr = data[1]; // int32
376+
char *out_ptr = data[2]; // quad
377+
npy_intp in1_stride = strides[0];
378+
npy_intp in2_stride = strides[1];
379+
npy_intp out_stride = strides[2];
380+
381+
QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0];
382+
QuadBackendType backend = descr->backend;
383+
384+
while (N--) {
385+
if (backend == BACKEND_SLEEF) {
386+
*(Sleef_quad *)out_ptr = sleef_op((Sleef_quad *)in1_ptr, (int *)in2_ptr);
387+
} else {
388+
*(long double *)out_ptr = longdouble_op((long double *)in1_ptr, (int *)in2_ptr);
389+
}
390+
391+
in1_ptr += in1_stride;
392+
in2_ptr += in2_stride;
393+
out_ptr += out_stride;
394+
}
395+
return 0;
396+
}
397+
398+
// Create ldexp ufunc
399+
template <ldexp_op_quad_def sleef_op, ldexp_op_longdouble_def longdouble_op>
400+
int
401+
create_quad_ldexp_ufunc(PyObject *numpy, const char *ufunc_name)
402+
{
403+
PyObject *ufunc = PyObject_GetAttrString(numpy, ufunc_name);
404+
if (ufunc == NULL) {
405+
return -1;
406+
}
407+
408+
PyArray_DTypeMeta *dtypes[3] = {&QuadPrecDType, &PyArray_PyLongDType, &QuadPrecDType};
409+
410+
PyType_Slot slots[] = {
411+
{NPY_METH_resolve_descriptors, (void *)&quad_ldexp_resolve_descriptors},
412+
{NPY_METH_strided_loop,
413+
(void *)&quad_ldexp_strided_loop_aligned<sleef_op, longdouble_op>},
414+
{NPY_METH_unaligned_strided_loop,
415+
(void *)&quad_ldexp_strided_loop_unaligned<sleef_op, longdouble_op>},
416+
{0, NULL}};
417+
418+
PyArrayMethod_Spec Spec = {
419+
.name = "quad_ldexp",
420+
.nin = 2,
421+
.nout = 1,
422+
.casting = NPY_NO_CASTING,
423+
.flags = (NPY_ARRAYMETHOD_FLAGS)(NPY_METH_SUPPORTS_UNALIGNED | NPY_METH_IS_REORDERABLE),
424+
.dtypes = dtypes,
425+
.slots = slots,
426+
};
427+
428+
if (PyUFunc_AddLoopFromSpec(ufunc, &Spec) < 0) {
429+
return -1;
430+
}
431+
432+
PyObject *promoter_capsule =
433+
PyCapsule_New((void *)&quad_ufunc_promoter, "numpy._ufunc_promoter", NULL);
434+
if (promoter_capsule == NULL) {
435+
return -1;
436+
}
437+
438+
PyObject *DTypes = PyTuple_Pack(3, &PyArrayDescr_Type, &PyArray_PyLongDType, &PyArrayDescr_Type);
439+
if (DTypes == 0) {
440+
Py_DECREF(promoter_capsule);
441+
return -1;
442+
}
443+
444+
if (PyUFunc_AddPromoter(ufunc, DTypes, promoter_capsule) < 0) {
445+
Py_DECREF(promoter_capsule);
446+
Py_DECREF(DTypes);
447+
return -1;
448+
}
449+
Py_DECREF(promoter_capsule);
450+
Py_DECREF(DTypes);
451+
return 0;
452+
}
453+
285454
// Create binary ufunc with 2 outputs (generic for divmod, modf, frexp, etc.)
286455
template <binary_op_2out_quad_def sleef_op, binary_op_2out_longdouble_def longdouble_op>
287456
int
@@ -466,5 +635,8 @@ init_quad_binary_ops(PyObject *numpy)
466635
if (create_quad_binary_2out_ufunc<quad_divmod, ld_divmod>(numpy, "divmod") < 0) {
467636
return -1;
468637
}
638+
if (create_quad_ldexp_ufunc<quad_ldexp, ld_ldexp>(numpy, "ldexp") < 0) {
639+
return -1;
640+
}
469641
return 0;
470642
}

quaddtype/release_tracker.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080
| nextafter |||
8181
| spacing |||
8282
| modf |||
83-
| ldexp | | |
83+
| ldexp | | |
8484
| frexp | | |
8585
| floor |||
8686
| ceil |||

0 commit comments

Comments
 (0)