Skip to content

Commit c0b77fa

Browse files
committed
handling np.numbers + tests
1 parent 16cbd57 commit c0b77fa

File tree

2 files changed

+146
-0
lines changed

2 files changed

+146
-0
lines changed

quaddtype/numpy_quaddtype/src/scalar.c

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,54 @@ QuadPrecision_raw_new(QuadBackendType backend)
4444
QuadPrecisionObject *
4545
QuadPrecision_from_object(PyObject *value, QuadBackendType backend)
4646
{
47+
// Handle numpy scalars (np.int32, np.float32, etc.) before arrays
48+
// We need to check this before PySequence_Check because some numpy scalars are sequences
49+
if (PyArray_CheckScalar(value)) {
50+
QuadPrecisionObject *self = QuadPrecision_raw_new(backend);
51+
if (!self)
52+
return NULL;
53+
54+
// Try as floating point first
55+
if (PyArray_IsScalar(value, Floating)) {
56+
PyObject *py_float = PyNumber_Float(value);
57+
if (py_float == NULL) {
58+
Py_DECREF(self);
59+
return NULL;
60+
}
61+
double dval = PyFloat_AsDouble(py_float);
62+
Py_DECREF(py_float);
63+
64+
if (backend == BACKEND_SLEEF) {
65+
self->value.sleef_value = Sleef_cast_from_doubleq1(dval);
66+
}
67+
else {
68+
self->value.longdouble_value = (long double)dval;
69+
}
70+
return self;
71+
}
72+
// Try as integer
73+
else if (PyArray_IsScalar(value, Integer)) {
74+
PyObject *py_int = PyNumber_Long(value);
75+
if (py_int == NULL) {
76+
Py_DECREF(self);
77+
return NULL;
78+
}
79+
long long lval = PyLong_AsLongLong(py_int);
80+
Py_DECREF(py_int);
81+
82+
if (backend == BACKEND_SLEEF) {
83+
self->value.sleef_value = Sleef_cast_from_int64q1(lval);
84+
}
85+
else {
86+
self->value.longdouble_value = (long double)lval;
87+
}
88+
return self;
89+
}
90+
// For other scalar types, fall through to error handling
91+
Py_DECREF(self);
92+
}
4793

94+
// this checks arrays and sequences (array, tuple)
4895
if (PyArray_Check(value) || (PySequence_Check(value) && !PyUnicode_Check(value) && !PyBytes_Check(value)))
4996
{
5097
QuadPrecDTypeObject *dtype_descr = new_quaddtype_instance(backend);

quaddtype/tests/test_quaddtype.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,105 @@ def test_create_empty_array(self):
9191
assert result.dtype.name == "QuadPrecDType128"
9292
assert result.shape == (0,)
9393

94+
def test_create_from_numpy_int_scalars(self):
95+
"""Test that QuadPrecision can create scalars from numpy integer types."""
96+
# Test np.int32
97+
result = QuadPrecision(np.int32(42))
98+
assert isinstance(result, QuadPrecision)
99+
assert float(result) == 42.0
100+
101+
# Test np.int64
102+
result = QuadPrecision(np.int64(100))
103+
assert isinstance(result, QuadPrecision)
104+
assert float(result) == 100.0
105+
106+
# Test np.uint32
107+
result = QuadPrecision(np.uint32(255))
108+
assert isinstance(result, QuadPrecision)
109+
assert float(result) == 255.0
110+
111+
# Test np.int8
112+
result = QuadPrecision(np.int8(-128))
113+
assert isinstance(result, QuadPrecision)
114+
assert float(result) == -128.0
115+
116+
def test_create_from_numpy_float_scalars(self):
117+
"""Test that QuadPrecision can create scalars from numpy floating types."""
118+
# Test np.float64
119+
result = QuadPrecision(np.float64(3.14))
120+
assert isinstance(result, QuadPrecision)
121+
assert abs(float(result) - 3.14) < 1e-10
122+
123+
# Test np.float32
124+
result = QuadPrecision(np.float32(2.71))
125+
assert isinstance(result, QuadPrecision)
126+
# Note: float32 has limited precision, so we use a looser tolerance
127+
assert abs(float(result) - 2.71) < 1e-5
128+
129+
# Test np.float16
130+
result = QuadPrecision(np.float16(1.5))
131+
assert isinstance(result, QuadPrecision)
132+
assert abs(float(result) - 1.5) < 1e-3
133+
134+
def test_create_from_zero_dimensional_array(self):
135+
"""Test that QuadPrecision can create from 0-d numpy arrays."""
136+
# 0-d array from scalar
137+
arr_0d = np.array(5.5)
138+
result = QuadPrecision(arr_0d)
139+
assert isinstance(result, np.ndarray)
140+
assert result.shape == () # 0-d array
141+
assert result.dtype.name == "QuadPrecDType128"
142+
143+
# Another test with integer
144+
arr_0d = np.array(42)
145+
result = QuadPrecision(arr_0d)
146+
assert isinstance(result, np.ndarray)
147+
assert result.shape == ()
148+
149+
def test_numpy_scalar_with_backend(self):
150+
"""Test that numpy scalars respect the backend parameter."""
151+
# Test with sleef backend
152+
result = QuadPrecision(np.int32(10), backend='sleef')
153+
assert isinstance(result, QuadPrecision)
154+
assert "backend='sleef'" in repr(result)
155+
156+
# Test with longdouble backend
157+
result = QuadPrecision(np.float64(3.14), backend='longdouble')
158+
assert isinstance(result, QuadPrecision)
159+
assert "backend='longdouble'" in repr(result)
160+
161+
def test_numpy_scalar_types_coverage(self):
162+
"""Test a comprehensive set of numpy scalar types."""
163+
# Integer types
164+
int_types = [
165+
(np.int8, 10),
166+
(np.int16, 1000),
167+
(np.int32, 100000),
168+
(np.int64, 10000000),
169+
(np.uint8, 200),
170+
(np.uint16, 50000),
171+
(np.uint32, 4000000000),
172+
]
173+
174+
for dtype, value in int_types:
175+
result = QuadPrecision(dtype(value))
176+
assert isinstance(result, QuadPrecision), f"Failed for {dtype.__name__}"
177+
assert float(result) == float(value), f"Value mismatch for {dtype.__name__}"
178+
179+
# Float types
180+
float_types = [
181+
(np.float16, 1.5),
182+
(np.float32, 2.5),
183+
(np.float64, 3.5),
184+
]
185+
186+
for dtype, value in float_types:
187+
result = QuadPrecision(dtype(value))
188+
assert isinstance(result, QuadPrecision), f"Failed for {dtype.__name__}"
189+
# Use appropriate tolerance based on dtype precision
190+
expected = float(dtype(value))
191+
assert abs(float(result) - expected) < 1e-5, f"Value mismatch for {dtype.__name__}"
192+
94193

95194
def test_string_roundtrip():
96195
# Test with various values that require full quad precision

0 commit comments

Comments
 (0)