Skip to content

Commit f7df8af

Browse files
try triggering exceptions before entering khash
1 parent d597079 commit f7df8af

File tree

2 files changed

+88
-0
lines changed

2 files changed

+88
-0
lines changed

pandas/_libs/hashtable_class_helper.pxi.in

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,6 +1356,14 @@ cdef class PyObjectHashTable(HashTable):
13561356
cdef:
13571357
khiter_t k
13581358

1359+
# GH 57052
1360+
# in khash_python.h, kh_python_hash_equal and kh_python_hash_func will be called repeatedly by khash in a loop.
1361+
# if object implements custom __hash__ and __eq__ methods that can raise exceptions,
1362+
# kh_python_hash_{equal,func} will suppress the exceptions without warnings.
1363+
# as a workaround: try triggering exceptions here, before starting the khash loop
1364+
hash(val)
1365+
val == val
1366+
13591367
k = kh_get_pymap(self.table, <PyObject*>val)
13601368
if k != self.table.n_buckets:
13611369
return self.table.vals[k]
@@ -1369,6 +1377,8 @@ cdef class PyObjectHashTable(HashTable):
13691377
char* buf
13701378

13711379
hash(key)
1380+
# GH 57052
1381+
key == key
13721382

13731383
k = kh_put_pymap(self.table, <PyObject*>key, &ret)
13741384
if kh_exist_pymap(self.table, k):

pandas/tests/libs/test_hashtable.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections import namedtuple
22
from collections.abc import Generator
33
from contextlib import contextmanager
4+
from itertools import product
45
import re
56
import struct
67
import tracemalloc
@@ -780,3 +781,80 @@ def test_float_complex_int_are_equal_as_objects():
780781
result = isin(np.array(values, dtype=object), np.asarray(comps))
781782
expected = np.array([False, True, True, True], dtype=np.bool_)
782783
tm.assert_numpy_array_equal(result, expected)
784+
785+
786+
@pytest.mark.parametrize(
787+
"throw1hash, throw2hash, throw1eq, throw2eq",
788+
product([True, False], repeat=4),
789+
)
790+
def test_exceptions_thrown_from_custom_hash_and_eq_methods(
791+
throw1hash, throw2hash, throw1eq, throw2eq
792+
):
793+
# GH 57052
794+
class testkey:
795+
def __init__(self, value, throw_hash=False, throw_eq=False):
796+
self.value = value
797+
self.throw_hash = throw_hash
798+
self.throw_eq = throw_eq
799+
800+
def __hash__(self):
801+
if self.throw_hash:
802+
raise RuntimeError(f"exception in {self!r}.__hash__")
803+
return hash(self.value)
804+
805+
def __eq__(self, other):
806+
if self.throw_eq:
807+
raise RuntimeError(f"exception in {self!r}.__eq__")
808+
return self.value == other.value
809+
810+
def __repr__(self):
811+
return f"{self.__class__.__name__}({self.value}, {self.throw_hash}, {self.throw_eq})"
812+
813+
table = ht.PyObjectHashTable()
814+
815+
key1 = testkey(value="hello1")
816+
key2 = testkey(value="hello2")
817+
818+
table.set_item(key1, 123)
819+
table.set_item(key2, 456)
820+
821+
key1.throw_hash = throw1hash
822+
key2.throw_hash = throw2hash
823+
key1.throw_eq = throw1eq
824+
key2.throw_eq = throw2eq
825+
826+
if throw1hash and throw1eq:
827+
with pytest.raises(
828+
RuntimeError, match=re.escape(f"exception in {key1!r}.") + "__(hash|eq)__"
829+
):
830+
table.get_item(key1)
831+
elif throw1hash:
832+
with pytest.raises(
833+
RuntimeError, match=re.escape(f"exception in {key1!r}.__hash__")
834+
):
835+
table.get_item(key1)
836+
elif throw1eq:
837+
with pytest.raises(
838+
RuntimeError, match=re.escape(f"exception in {key1!r}.__eq__")
839+
):
840+
table.get_item(key1)
841+
else:
842+
assert table.get_item(key1) == 123
843+
844+
if throw2hash and throw2eq:
845+
with pytest.raises(
846+
RuntimeError, match=re.escape(f"exception in {key2!r}.") + "__(hash|eq)__"
847+
):
848+
table.get_item(key2)
849+
elif throw2hash:
850+
with pytest.raises(
851+
RuntimeError, match=re.escape(f"exception in {key2!r}.__hash__")
852+
):
853+
table.get_item(key2)
854+
elif throw2eq:
855+
with pytest.raises(
856+
RuntimeError, match=re.escape(f"exception in {key2!r}.__eq__")
857+
):
858+
table.get_item(key2)
859+
else:
860+
assert table.get_item(key2) == 456

0 commit comments

Comments
 (0)