Skip to content

Commit 5ac0265

Browse files
authored
Merge pull request #172 from SwayamInSync/171
2 parents e47738c + cd19212 commit 5ac0265

File tree

2 files changed

+137
-7
lines changed

2 files changed

+137
-7
lines changed

quaddtype/numpy_quaddtype/src/dtype.c

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,13 @@ common_instance(QuadPrecDTypeObject *dtype1, QuadPrecDTypeObject *dtype2)
9797
static PyArray_DTypeMeta *
9898
common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
9999
{
100+
// Handle Python abstract dtypes (PyLongDType, PyFloatDType)
101+
// These have type_num = -1
102+
if (other == &PyArray_PyLongDType || other == &PyArray_PyFloatDType) {
103+
Py_INCREF(cls);
104+
return cls;
105+
}
106+
100107
// Promote integer and floating-point types to QuadPrecDType
101108
if (other->type_num >= 0 &&
102109
(PyTypeNum_ISINTEGER(other->type_num) || PyTypeNum_ISFLOAT(other->type_num))) {
@@ -116,14 +123,21 @@ common_dtype(PyArray_DTypeMeta *cls, PyArray_DTypeMeta *other)
116123
static PyArray_Descr *
117124
quadprec_discover_descriptor_from_pyobject(PyArray_DTypeMeta *NPY_UNUSED(cls), PyObject *obj)
118125
{
119-
if (Py_TYPE(obj) != &QuadPrecision_Type) {
120-
PyErr_SetString(PyExc_TypeError, "Can only store QuadPrecision in a QuadPrecDType array.");
121-
return NULL;
126+
if (Py_TYPE(obj) == &QuadPrecision_Type) {
127+
/* QuadPrecision scalar: use its backend */
128+
QuadPrecisionObject *quad_obj = (QuadPrecisionObject *)obj;
129+
return (PyArray_Descr *)new_quaddtype_instance(quad_obj->backend);
122130
}
123-
124-
QuadPrecisionObject *quad_obj = (QuadPrecisionObject *)obj;
125-
126-
return (PyArray_Descr *)new_quaddtype_instance(quad_obj->backend);
131+
132+
/* For Python int/float/other numeric types: return default descriptor */
133+
/* The casting machinery will handle conversion to QuadPrecision */
134+
if (PyLong_Check(obj) || PyFloat_Check(obj)) {
135+
return (PyArray_Descr *)new_quaddtype_instance(BACKEND_SLEEF);
136+
}
137+
138+
/* Unknown type - ERROR */
139+
PyErr_SetString(PyExc_TypeError, "Can only store QuadPrecision, int, or float in a QuadPrecDType array.");
140+
return NULL;
127141
}
128142

129143
static int
@@ -261,6 +275,50 @@ quadprec_get_constant(PyArray_Descr *descr, int constant_id, void *ptr)
261275
return 1;
262276
}
263277

278+
/*
279+
* Fill function.
280+
* The buffer already has the first two elements set:
281+
* buffer[0] = start
282+
* buffer[1] = start + step
283+
* We need to fill buffer[2..length-1] with the arithmetic progression.
284+
*/
285+
static int
286+
quadprec_fill(void *buffer, npy_intp length, void *arr_)
287+
{
288+
PyArrayObject *arr = (PyArrayObject *)arr_;
289+
QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)PyArray_DESCR(arr);
290+
QuadBackendType backend = descr->backend;
291+
npy_intp i;
292+
293+
if (length < 2) {
294+
return 0; // Nothing to fill
295+
}
296+
297+
if (backend == BACKEND_SLEEF) {
298+
Sleef_quad *buf = (Sleef_quad *)buffer;
299+
Sleef_quad start = buf[0];
300+
Sleef_quad delta = Sleef_subq1_u05(buf[1], start); // delta = buf[1] - start
301+
302+
for (i = 2; i < length; ++i) {
303+
// buf[i] = start + i * delta
304+
Sleef_quad i_quad = Sleef_cast_from_doubleq1(i);
305+
Sleef_quad i_delta = Sleef_mulq1_u05(i_quad, delta);
306+
buf[i] = Sleef_addq1_u05(start, i_delta);
307+
}
308+
}
309+
else {
310+
long double *buf = (long double *)buffer;
311+
long double start = buf[0];
312+
long double delta = buf[1] - start;
313+
314+
for (i = 2; i < length; ++i) {
315+
buf[i] = start + i * delta;
316+
}
317+
}
318+
319+
return 0;
320+
}
321+
264322
static PyType_Slot QuadPrecDType_Slots[] = {
265323
{NPY_DT_ensure_canonical, &ensure_canonical},
266324
{NPY_DT_common_instance, &common_instance},
@@ -270,6 +328,7 @@ static PyType_Slot QuadPrecDType_Slots[] = {
270328
{NPY_DT_getitem, &quadprec_getitem},
271329
{NPY_DT_default_descr, &quadprec_default_descr},
272330
{NPY_DT_get_constant, &quadprec_get_constant},
331+
{NPY_DT_PyArray_ArrFuncs_fill, &quadprec_fill},
273332
{0, NULL}};
274333

275334
static PyObject *

quaddtype/tests/test_quaddtype.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,3 +692,74 @@ def test_hyperbolic_functions(op, val):
692692
if float_result == 0.0:
693693
assert np.signbit(float_result) == np.signbit(
694694
quad_result), f"Zero sign mismatch for {op}({val})"
695+
696+
697+
class TestTypePomotionWithPythonAbstractTypes:
698+
"""Tests for common_dtype handling of Python abstract dtypes (PyLongDType, PyFloatDType)"""
699+
700+
def test_promotion_with_python_int(self):
701+
"""Test that Python int promotes to QuadPrecDType"""
702+
# Create array from Python int
703+
arr = np.array([1, 2, 3], dtype=QuadPrecDType)
704+
assert arr.dtype.name == "QuadPrecDType128"
705+
assert len(arr) == 3
706+
assert float(arr[0]) == 1.0
707+
assert float(arr[1]) == 2.0
708+
assert float(arr[2]) == 3.0
709+
710+
def test_promotion_with_python_float(self):
711+
"""Test that Python float promotes to QuadPrecDType"""
712+
# Create array from Python float
713+
arr = np.array([1.5, 2.7, 3.14], dtype=QuadPrecDType)
714+
assert arr.dtype.name == "QuadPrecDType128"
715+
assert len(arr) == 3
716+
np.testing.assert_allclose(float(arr[0]), 1.5, rtol=1e-15)
717+
np.testing.assert_allclose(float(arr[1]), 2.7, rtol=1e-15)
718+
np.testing.assert_allclose(float(arr[2]), 3.14, rtol=1e-15)
719+
720+
def test_result_dtype_binary_ops_with_python_types(self):
721+
"""Test that binary operations between QuadPrecDType and Python scalars return QuadPrecDType"""
722+
quad_arr = np.array([QuadPrecision("1.0"), QuadPrecision("2.0")])
723+
724+
# Addition with Python int
725+
result = quad_arr + 5
726+
assert result.dtype.name == "QuadPrecDType128"
727+
assert float(result[0]) == 6.0
728+
assert float(result[1]) == 7.0
729+
730+
# Multiplication with Python float
731+
result = quad_arr * 2.5
732+
assert result.dtype.name == "QuadPrecDType128"
733+
np.testing.assert_allclose(float(result[0]), 2.5, rtol=1e-15)
734+
np.testing.assert_allclose(float(result[1]), 5.0, rtol=1e-15)
735+
736+
def test_concatenate_with_python_types(self):
737+
"""Test concatenation handles Python numeric types correctly"""
738+
quad_arr = np.array([QuadPrecision("1.0")])
739+
# This should work if promotion is correct
740+
int_arr = np.array([2], dtype=np.int64)
741+
742+
# The result dtype should be QuadPrecDType
743+
result = np.concatenate([quad_arr, int_arr.astype(QuadPrecDType)])
744+
assert result.dtype.name == "QuadPrecDType128"
745+
assert len(result) == 2
746+
747+
748+
@pytest.mark.parametrize("func,args,expected", [
749+
# arange tests
750+
(np.arange, (0, 10), list(range(10))),
751+
(np.arange, (0, 10, 2), [0, 2, 4, 6, 8]),
752+
(np.arange, (0.0, 5.0, 0.5), [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5]),
753+
(np.arange, (10, 0, -1), [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]),
754+
(np.arange, (-5, 5), list(range(-5, 5))),
755+
# linspace tests
756+
(np.linspace, (0, 10, 11), list(range(11))),
757+
(np.linspace, (0, 1, 5), [0.0, 0.25, 0.5, 0.75, 1.0]),
758+
])
759+
def test_fill_function(func, args, expected):
760+
"""Test quadprec_fill function with arange and linspace"""
761+
arr = func(*args, dtype=QuadPrecDType())
762+
assert arr.dtype.name == "QuadPrecDType128"
763+
assert len(arr) == len(expected)
764+
for i, exp_val in enumerate(expected):
765+
np.testing.assert_allclose(float(arr[i]), float(exp_val), rtol=1e-15, atol=1e-15)

0 commit comments

Comments
 (0)