Skip to content

Commit 16cbd57

Browse files
committed
tests + logic
1 parent f52ffdf commit 16cbd57

File tree

2 files changed

+106
-4
lines changed

2 files changed

+106
-4
lines changed

quaddtype/numpy_quaddtype/src/scalar.c

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "scalar.h"
1515
#include "scalar_ops.h"
1616
#include "dragon4.h"
17+
#include "dtype.h"
1718

1819
// For IEEE 754 binary128 (quad precision), we need 36 decimal digits
1920
// to guarantee round-trip conversion (string -> parse -> equals original value)
@@ -42,7 +43,29 @@ QuadPrecision_raw_new(QuadBackendType backend)
4243

4344
QuadPrecisionObject *
4445
QuadPrecision_from_object(PyObject *value, QuadBackendType backend)
45-
{
46+
{
47+
48+
if (PyArray_Check(value) || (PySequence_Check(value) && !PyUnicode_Check(value) && !PyBytes_Check(value)))
49+
{
50+
QuadPrecDTypeObject *dtype_descr = new_quaddtype_instance(backend);
51+
if (dtype_descr == NULL) {
52+
return NULL;
53+
}
54+
55+
56+
PyObject *result = PyArray_FromAny(
57+
value,
58+
(PyArray_Descr *)dtype_descr,
59+
0,
60+
0,
61+
NPY_ARRAY_ENSUREARRAY, // this should handle the casting if possible
62+
NULL
63+
);
64+
65+
// PyArray_FromAny steals the reference to dtype_descr, so no need to DECREF
66+
return (QuadPrecisionObject *)result;
67+
}
68+
4669
QuadPrecisionObject *self = QuadPrecision_raw_new(backend);
4770
if (!self)
4871
return NULL;
@@ -105,21 +128,21 @@ QuadPrecision_from_object(PyObject *value, QuadBackendType backend)
105128
const char *type_cstr = PyUnicode_AsUTF8(type_str);
106129
if (type_cstr != NULL) {
107130
PyErr_Format(PyExc_TypeError,
108-
"QuadPrecision value must be a quad, float, int or string, but got %s "
131+
"QuadPrecision value must be a quad, float, int, string, array or sequence, but got %s "
109132
"instead",
110133
type_cstr);
111134
}
112135
else {
113136
PyErr_SetString(
114137
PyExc_TypeError,
115-
"QuadPrecision value must be a quad, float, int or string, but got an "
138+
"QuadPrecision value must be a quad, float, int, string, array or sequence, but got an "
116139
"unknown type instead");
117140
}
118141
Py_DECREF(type_str);
119142
}
120143
else {
121144
PyErr_SetString(PyExc_TypeError,
122-
"QuadPrecision value must be a quad, float, int or string, but got an "
145+
"QuadPrecision value must be a quad, float, int, string, array or sequence, but got an "
123146
"unknown type instead");
124147
}
125148
Py_DECREF(self);

quaddtype/tests/test_quaddtype.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,85 @@ def test_create_scalar_simple():
1313
assert isinstance(QuadPrecision(1), QuadPrecision)
1414

1515

16+
class TestQuadPrecisionArrayCreation:
17+
"""Test suite for QuadPrecision array creation from sequences and arrays."""
18+
19+
def test_create_array_from_list(self):
20+
"""Test that QuadPrecision can create arrays from lists."""
21+
# Test with simple list
22+
result = QuadPrecision([3, 4, 5])
23+
assert isinstance(result, np.ndarray)
24+
assert result.dtype.name == "QuadPrecDType128"
25+
assert result.shape == (3,)
26+
assert np.array_equal(result, np.array([3, 4, 5], dtype=QuadPrecDType))
27+
28+
# Test with float list
29+
result = QuadPrecision([1.5, 2.5, 3.5])
30+
assert isinstance(result, np.ndarray)
31+
assert result.dtype.name == "QuadPrecDType128"
32+
assert result.shape == (3,)
33+
34+
def test_create_array_from_tuple(self):
35+
"""Test that QuadPrecision can create arrays from tuples."""
36+
result = QuadPrecision((10, 20, 30))
37+
assert isinstance(result, np.ndarray)
38+
assert result.dtype.name == "QuadPrecDType128"
39+
assert result.shape == (3,)
40+
assert np.array_equal(result, np.array([10, 20, 30], dtype=QuadPrecDType))
41+
42+
def test_create_array_from_ndarray(self):
43+
"""Test that QuadPrecision can create arrays from numpy arrays."""
44+
arr = np.array([1, 2, 3, 4])
45+
result = QuadPrecision(arr)
46+
assert isinstance(result, np.ndarray)
47+
assert result.dtype.name == "QuadPrecDType128"
48+
assert result.shape == (4,)
49+
assert np.array_equal(result, arr.astype(QuadPrecDType))
50+
51+
def test_create_2d_array_from_nested_list(self):
52+
"""Test that QuadPrecision can create 2D arrays from nested lists."""
53+
result = QuadPrecision([[1, 2], [3, 4]])
54+
assert isinstance(result, np.ndarray)
55+
assert result.dtype.name == "QuadPrecDType128"
56+
assert result.shape == (2, 2)
57+
expected = np.array([[1, 2], [3, 4]], dtype=QuadPrecDType)
58+
assert np.array_equal(result, expected)
59+
60+
def test_create_array_with_backend(self):
61+
"""Test that QuadPrecision respects backend parameter for arrays."""
62+
# Test with sleef backend (default)
63+
result_sleef = QuadPrecision([1, 2, 3], backend='sleef')
64+
assert isinstance(result_sleef, np.ndarray)
65+
assert result_sleef.dtype == QuadPrecDType(backend='sleef')
66+
67+
# Test with longdouble backend
68+
result_ld = QuadPrecision([1, 2, 3], backend='longdouble')
69+
assert isinstance(result_ld, np.ndarray)
70+
assert result_ld.dtype == QuadPrecDType(backend='longdouble')
71+
72+
def test_quad_precision_array_vs_astype_equivalence(self):
73+
"""Test that QuadPrecision(array) is equivalent to array.astype(QuadPrecDType)."""
74+
test_arrays = [
75+
[1, 2, 3],
76+
[1.5, 2.5, 3.5],
77+
[[1, 2], [3, 4]],
78+
np.array([10, 20, 30]),
79+
]
80+
81+
for arr in test_arrays:
82+
result_quad = QuadPrecision(arr)
83+
result_astype = np.array(arr).astype(QuadPrecDType)
84+
assert np.array_equal(result_quad, result_astype)
85+
assert result_quad.dtype == result_astype.dtype
86+
87+
def test_create_empty_array(self):
88+
"""Test that QuadPrecision can create arrays from empty sequences."""
89+
result = QuadPrecision([])
90+
assert isinstance(result, np.ndarray)
91+
assert result.dtype.name == "QuadPrecDType128"
92+
assert result.shape == (0,)
93+
94+
1695
def test_string_roundtrip():
1796
# Test with various values that require full quad precision
1897
test_values = [

0 commit comments

Comments
 (0)