|
1 | 1 | from collections import namedtuple |
2 | 2 | from collections.abc import Generator |
3 | 3 | from contextlib import contextmanager |
| 4 | +from itertools import product |
4 | 5 | import re |
5 | 6 | import struct |
6 | 7 | import tracemalloc |
@@ -780,3 +781,80 @@ def test_float_complex_int_are_equal_as_objects(): |
780 | 781 | result = isin(np.array(values, dtype=object), np.asarray(comps)) |
781 | 782 | expected = np.array([False, True, True, True], dtype=np.bool_) |
782 | 783 | tm.assert_numpy_array_equal(result, expected) |
| 784 | + |
| 785 | + |
| 786 | +@pytest.mark.parametrize( |
| 787 | + "throw1hash, throw2hash, throw1eq, throw2eq", |
| 788 | + product([True, False], repeat=4), |
| 789 | +) |
| 790 | +def test_exceptions_thrown_from_custom_hash_and_eq_methods( |
| 791 | + throw1hash, throw2hash, throw1eq, throw2eq |
| 792 | +): |
| 793 | + # GH 57052 |
| 794 | + class testkey: |
| 795 | + def __init__(self, value, throw_hash=False, throw_eq=False): |
| 796 | + self.value = value |
| 797 | + self.throw_hash = throw_hash |
| 798 | + self.throw_eq = throw_eq |
| 799 | + |
| 800 | + def __hash__(self): |
| 801 | + if self.throw_hash: |
| 802 | + raise RuntimeError(f"exception in {self!r}.__hash__") |
| 803 | + return hash(self.value) |
| 804 | + |
| 805 | + def __eq__(self, other): |
| 806 | + if self.throw_eq: |
| 807 | + raise RuntimeError(f"exception in {self!r}.__eq__") |
| 808 | + return self.value == other.value |
| 809 | + |
| 810 | + def __repr__(self): |
| 811 | + return f"{self.__class__.__name__}({self.value}, {self.throw_hash}, {self.throw_eq})" |
| 812 | + |
| 813 | + table = ht.PyObjectHashTable() |
| 814 | + |
| 815 | + key1 = testkey(value="hello1") |
| 816 | + key2 = testkey(value="hello2") |
| 817 | + |
| 818 | + table.set_item(key1, 123) |
| 819 | + table.set_item(key2, 456) |
| 820 | + |
| 821 | + key1.throw_hash = throw1hash |
| 822 | + key2.throw_hash = throw2hash |
| 823 | + key1.throw_eq = throw1eq |
| 824 | + key2.throw_eq = throw2eq |
| 825 | + |
| 826 | + if throw1hash and throw1eq: |
| 827 | + with pytest.raises( |
| 828 | + RuntimeError, match=re.escape(f"exception in {key1!r}.") + "__(hash|eq)__" |
| 829 | + ): |
| 830 | + table.get_item(key1) |
| 831 | + elif throw1hash: |
| 832 | + with pytest.raises( |
| 833 | + RuntimeError, match=re.escape(f"exception in {key1!r}.__hash__") |
| 834 | + ): |
| 835 | + table.get_item(key1) |
| 836 | + elif throw1eq: |
| 837 | + with pytest.raises( |
| 838 | + RuntimeError, match=re.escape(f"exception in {key1!r}.__eq__") |
| 839 | + ): |
| 840 | + table.get_item(key1) |
| 841 | + else: |
| 842 | + assert table.get_item(key1) == 123 |
| 843 | + |
| 844 | + if throw2hash and throw2eq: |
| 845 | + with pytest.raises( |
| 846 | + RuntimeError, match=re.escape(f"exception in {key2!r}.") + "__(hash|eq)__" |
| 847 | + ): |
| 848 | + table.get_item(key2) |
| 849 | + elif throw2hash: |
| 850 | + with pytest.raises( |
| 851 | + RuntimeError, match=re.escape(f"exception in {key2!r}.__hash__") |
| 852 | + ): |
| 853 | + table.get_item(key2) |
| 854 | + elif throw2eq: |
| 855 | + with pytest.raises( |
| 856 | + RuntimeError, match=re.escape(f"exception in {key2!r}.__eq__") |
| 857 | + ): |
| 858 | + table.get_item(key2) |
| 859 | + else: |
| 860 | + assert table.get_item(key2) == 456 |
0 commit comments