Skip to content

Commit 318fe86

Browse files
rewrite hash fix, add tests
1 parent be9d047 commit 318fe86

File tree

4 files changed

+138
-56
lines changed

4 files changed

+138
-56
lines changed

pandas/_libs/hashtable_class_helper.pxi.in

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Template for each `dtype` helper function for hashtable
44
WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in
55
"""
66
from cpython.unicode cimport PyUnicode_AsUTF8
7+
from cpython.exc cimport PyErr_Occurred, PyErr_Fetch
8+
from cpython.ref cimport Py_XDECREF
79

810
{{py:
911

@@ -1309,6 +1311,22 @@ cdef class StringHashTable(HashTable):
13091311
return labels
13101312

13111313

1314+
cdef int raise_if_errors() except -1:
    """
    Propagate a Python exception left pending by a khash callback, if any.

    Intended to be called right after a ``kh_*_pymap`` operation: the hash
    and comparison callbacks in khash_python.h cannot report failures through
    their return values, so a failing ``__hash__``/``__eq__`` only leaves the
    interpreter's error indicator set.

    Returns
    -------
    int
        0 when no exception is pending.

    Raises
    ------
    BaseException
        Whatever exception is currently pending, re-raised unchanged.
    """
    # The indicator is deliberately left in place rather than fetched and
    # re-raised by hand: PyErr_Fetch may hand back a NULL or un-normalized
    # value (which cannot be raised directly) and a separate traceback that
    # would otherwise be dropped.
    if PyErr_Occurred() != NULL:
        # Returning the declared error value while the indicator is set makes
        # Cython propagate the pending exception to the caller as-is.
        return -1
    return 0
1328+
1329+
13121330
cdef class PyObjectHashTable(HashTable):
13131331

13141332
def __init__(self, int64_t size_hint=1):
@@ -1356,15 +1374,8 @@ cdef class PyObjectHashTable(HashTable):
13561374
cdef:
13571375
khiter_t k
13581376

1359-
# GH 57052
1360-
# in khash_python.h, kh_python_hash_equal and kh_python_hash_func will be called repeatedly by khash in a loop.
1361-
# if object implements custom __hash__ and __eq__ methods that can raise exceptions,
1362-
# kh_python_hash_{equal,func} will suppress the exceptions without warnings.
1363-
# as a workaround: try triggering exceptions here, before starting the khash loop
1364-
hash(val)
1365-
val == val
1366-
13671377
k = kh_get_pymap(self.table, <PyObject*>val)
1378+
raise_if_errors()
13681379
if k != self.table.n_buckets:
13691380
return self.table.vals[k]
13701381
else:
@@ -1377,10 +1388,9 @@ cdef class PyObjectHashTable(HashTable):
13771388
char* buf
13781389

13791390
hash(key)
1380-
# GH 57052
1381-
key == key
13821391

13831392
k = kh_put_pymap(self.table, <PyObject*>key, &ret)
1393+
raise_if_errors()
13841394
if kh_exist_pymap(self.table, k):
13851395
self.table.vals[k] = val
13861396
else:

pandas/_libs/include/pandas/vendored/klib/khash_python.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,9 @@ static inline int tupleobject_cmp(PyTupleObject *a, PyTupleObject *b) {
193193
}
194194

195195
static inline int pyobject_cmp(PyObject *a, PyObject *b) {
196+
if (PyErr_Occurred() != NULL) {
197+
return 0;
198+
}
196199
if (a == b) {
197200
return 1;
198201
}
@@ -215,7 +218,6 @@ static inline int pyobject_cmp(PyObject *a, PyObject *b) {
215218

216219
int result = PyObject_RichCompareBool(a, b, Py_EQ);
217220
if (result < 0) {
218-
PyErr_Clear();
219221
return 0;
220222
}
221223
return result;
@@ -292,6 +294,9 @@ static inline Py_hash_t tupleobject_hash(PyTupleObject *key) {
292294
}
293295

294296
static inline khuint32_t kh_python_hash_func(PyObject *key) {
297+
if (PyErr_Occurred() != NULL) {
298+
return 0;
299+
}
295300
Py_hash_t hash;
296301
// For PyObject_Hash holds:
297302
// hash(0.0) == 0 == hash(-0.0)
@@ -315,7 +320,6 @@ static inline khuint32_t kh_python_hash_func(PyObject *key) {
315320
}
316321

317322
if (hash == -1) {
318-
PyErr_Clear();
319323
return 0;
320324
}
321325
#if SIZEOF_PY_HASH_T == 4

pandas/tests/frame/indexing/test_indexing.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1943,3 +1943,27 @@ def test_setitem_validation_scalar_int(self, invalid, any_int_numpy_dtype, index
19431943
def test_setitem_validation_scalar_float(self, invalid, float_numpy_dtype, indexer):
19441944
df = DataFrame({"a": [1, 2, None]}, dtype=float_numpy_dtype)
19451945
self._check_setitem_invalid(df, invalid, indexer)
1946+
1947+
1948+
def test_error_raised_from_custom_hash_method():
    # GH 57052
    class testkey:
        """Index label whose ``__hash__`` always raises RuntimeError."""

        def __init__(self, value):
            self.value = value

        def __repr__(self):
            return f"testkey({self.value})"

        def __eq__(self, other):
            return self.value == other.value

        def __hash__(self):
            raise RuntimeError(f"exception in {self!r}.__hash__")

    labels = map(testkey, range(10))
    df = DataFrame({"i": labels}).set_index("i")

    # Each .loc lookup must surface the RuntimeError raised by the lookup
    # key's own __hash__ instead of swallowing it.
    for pos in range(len(df.index)):
        key = testkey(pos)
        expected_msg = re.escape(f"exception in {key!r}.__hash__")
        with pytest.raises(RuntimeError, match=expected_msg):
            df.loc[key]

pandas/tests/libs/test_hashtable.py

Lines changed: 88 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -783,33 +783,57 @@ def test_float_complex_int_are_equal_as_objects():
783783
tm.assert_numpy_array_equal(result, expected)
784784

785785

786-
@pytest.mark.parametrize(
787-
"throw1hash, throw2hash, throw1eq, throw2eq",
788-
product([True, False], repeat=4),
789-
)
790-
def test_exceptions_thrown_from_custom_hash_and_eq_methods(
791-
throw1hash, throw2hash, throw1eq, throw2eq
792-
):
786+
class testkey:
793787
# GH 57052
794-
class testkey:
795-
def __init__(self, value, throw_hash=False, throw_eq=False):
796-
self.value = value
797-
self.throw_hash = throw_hash
798-
self.throw_eq = throw_eq
788+
def __init__(self, value, throw_hash=False, throw_eq=False):
789+
self.value = value
790+
self.throw_hash = throw_hash
791+
self.throw_eq = throw_eq
792+
793+
def __hash__(self):
794+
if self.throw_hash:
795+
raise RuntimeError(f"exception in {self!r}.__hash__")
796+
return hash(self.value)
797+
798+
def __eq__(self, other):
799+
if self.throw_eq:
800+
raise RuntimeError(f"exception in {self!r}.__eq__")
801+
return self.value == other.value
799802

800-
def __hash__(self):
801-
if self.throw_hash:
802-
raise RuntimeError(f"exception in {self!r}.__hash__")
803-
return hash(self.value)
803+
def __repr__(self):
804+
return f"testkey({self.value}, {self.throw_hash}, {self.throw_eq})"
805+
806+
807+
@pytest.mark.parametrize("throw1, throw2", product([True, False], repeat=2))
808+
def test_error_raised_from_hash_method_in_set_item(throw1, throw2):
809+
# GH 57052
810+
table = ht.PyObjectHashTable()
811+
812+
key1 = testkey(value="hello1", throw_hash=throw1)
813+
key2 = testkey(value="hello2", throw_hash=throw2)
814+
815+
if throw1:
816+
with pytest.raises(
817+
RuntimeError, match=re.escape(f"exception in {key1!r}.__hash__")
818+
):
819+
table.set_item(key1, 123)
820+
else:
821+
table.set_item(key1, 123)
822+
assert table.get_item(key1) == 123
804823

805-
def __eq__(self, other):
806-
if self.throw_eq:
807-
raise RuntimeError(f"exception in {self!r}.__eq__")
808-
return self.value == other.value
824+
if throw2:
825+
with pytest.raises(
826+
RuntimeError, match=re.escape(f"exception in {key2!r}.__hash__")
827+
):
828+
table.set_item(key2, 456)
829+
else:
830+
table.set_item(key2, 456)
831+
assert table.get_item(key2) == 456
809832

810-
def __repr__(self):
811-
return f"testkey({self.value}, {self.throw_hash}, {self.throw_eq})"
812833

834+
@pytest.mark.parametrize("throw1, throw2", product([True, False], repeat=2))
835+
def test_error_raised_from_hash_method_in_get_item(throw1, throw2):
836+
# GH 57052
813837
table = ht.PyObjectHashTable()
814838

815839
key1 = testkey(value="hello1")
@@ -818,43 +842,63 @@ def __repr__(self):
818842
table.set_item(key1, 123)
819843
table.set_item(key2, 456)
820844

821-
key1.throw_hash = throw1hash
822-
key2.throw_hash = throw2hash
823-
key1.throw_eq = throw1eq
824-
key2.throw_eq = throw2eq
845+
key1.throw_hash = throw1
846+
key2.throw_hash = throw2
825847

826-
if throw1hash and throw1eq:
827-
with pytest.raises(
828-
RuntimeError, match=re.escape(f"exception in {key1!r}.") + "__(hash|eq)__"
829-
):
830-
table.get_item(key1)
831-
elif throw1hash:
848+
if throw1:
832849
with pytest.raises(
833850
RuntimeError, match=re.escape(f"exception in {key1!r}.__hash__")
834851
):
835852
table.get_item(key1)
836-
elif throw1eq:
837-
with pytest.raises(
838-
RuntimeError, match=re.escape(f"exception in {key1!r}.__eq__")
839-
):
840-
table.get_item(key1)
841853
else:
842854
assert table.get_item(key1) == 123
843855

844-
if throw2hash and throw2eq:
856+
if throw2:
845857
with pytest.raises(
846-
RuntimeError, match=re.escape(f"exception in {key2!r}.") + "__(hash|eq)__"
858+
RuntimeError, match=re.escape(f"exception in {key2!r}.__hash__")
847859
):
848860
table.get_item(key2)
849-
elif throw2hash:
861+
else:
862+
assert table.get_item(key2) == 456
863+
864+
865+
@pytest.mark.parametrize("throw", [True, False])
866+
def test_error_raised_from_eq_method_in_set_item(throw):
867+
# GH 57052
868+
table = ht.PyObjectHashTable()
869+
870+
key1 = testkey(value="hello", throw_eq=throw)
871+
key2 = testkey(value=key1.value)
872+
873+
if throw:
874+
table.set_item(key1, 123)
850875
with pytest.raises(
851-
RuntimeError, match=re.escape(f"exception in {key2!r}.__hash__")
876+
RuntimeError, match=re.escape(f"exception in {key1!r}.__eq__")
852877
):
853-
table.get_item(key2)
854-
elif throw2eq:
878+
table.set_item(key2, 456)
879+
else:
880+
table.set_item(key2, 456)
881+
assert table.get_item(key2) == 456
882+
883+
884+
@pytest.mark.parametrize("throw", [True, False])
885+
def test_error_raised_from_eq_method_in_get_item(throw):
886+
# GH 57052
887+
table = ht.PyObjectHashTable()
888+
889+
key1 = testkey(value="hello")
890+
key2 = testkey(value=key1.value)
891+
892+
table.set_item(key1, 123)
893+
table.set_item(key2, 456)
894+
895+
if throw:
896+
key1.throw_eq = True
855897
with pytest.raises(
856-
RuntimeError, match=re.escape(f"exception in {key2!r}.__eq__")
898+
RuntimeError, match=re.escape(f"exception in {key1!r}.__eq__")
857899
):
858900
table.get_item(key2)
859901
else:
902+
# key1 and key2 hash alike and compare equal (same value), so the second set_item overwrote key1's entry with 456
903+
assert table.get_item(key1) == 456
860904
assert table.get_item(key2) == 456

0 commit comments

Comments
 (0)