@@ -783,33 +783,57 @@ def test_float_complex_int_are_equal_as_objects():
783783 tm .assert_numpy_array_equal (result , expected )
784784
785785
786- @pytest .mark .parametrize (
787- "throw1hash, throw2hash, throw1eq, throw2eq" ,
788- product ([True , False ], repeat = 4 ),
789- )
790- def test_exceptions_thrown_from_custom_hash_and_eq_methods (
791- throw1hash , throw2hash , throw1eq , throw2eq
792- ):
786+ class testkey :
793787 # GH 57052
794- class testkey :
795- def __init__ (self , value , throw_hash = False , throw_eq = False ):
796- self .value = value
797- self .throw_hash = throw_hash
798- self .throw_eq = throw_eq
788+ def __init__ (self , value , throw_hash = False , throw_eq = False ):
789+ self .value = value
790+ self .throw_hash = throw_hash
791+ self .throw_eq = throw_eq
792+
793+ def __hash__ (self ):
794+ if self .throw_hash :
795+ raise RuntimeError (f"exception in { self !r} .__hash__" )
796+ return hash (self .value )
797+
798+ def __eq__ (self , other ):
799+ if self .throw_eq :
800+ raise RuntimeError (f"exception in { self !r} .__eq__" )
801+ return self .value == other .value
799802
800- def __hash__ (self ):
801- if self .throw_hash :
802- raise RuntimeError (f"exception in { self !r} .__hash__" )
803- return hash (self .value )
803+ def __repr__ (self ):
804+ return f"testkey({ self .value } , { self .throw_hash } , { self .throw_eq } )"
805+
806+
807+ @pytest .mark .parametrize ("throw1, throw2" , product ([True , False ], repeat = 2 ))
808+ def test_error_raised_from_hash_method_in_set_item (throw1 , throw2 ):
809+ # GH 57052
810+ table = ht .PyObjectHashTable ()
811+
812+ key1 = testkey (value = "hello1" , throw_hash = throw1 )
813+ key2 = testkey (value = "hello2" , throw_hash = throw2 )
814+
815+ if throw1 :
816+ with pytest .raises (
817+ RuntimeError , match = re .escape (f"exception in { key1 !r} .__hash__" )
818+ ):
819+ table .set_item (key1 , 123 )
820+ else :
821+ table .set_item (key1 , 123 )
822+ assert table .get_item (key1 ) == 123
804823
805- def __eq__ (self , other ):
806- if self .throw_eq :
807- raise RuntimeError (f"exception in { self !r} .__eq__" )
808- return self .value == other .value
824+ if throw2 :
825+ with pytest .raises (
826+ RuntimeError , match = re .escape (f"exception in { key2 !r} .__hash__" )
827+ ):
828+ table .set_item (key2 , 456 )
829+ else :
830+ table .set_item (key2 , 456 )
831+ assert table .get_item (key2 ) == 456
809832
810- def __repr__ (self ):
811- return f"testkey({ self .value } , { self .throw_hash } , { self .throw_eq } )"
812833
834+ @pytest .mark .parametrize ("throw1, throw2" , product ([True , False ], repeat = 2 ))
835+ def test_error_raised_from_hash_method_in_get_item (throw1 , throw2 ):
836+ # GH 57052
813837 table = ht .PyObjectHashTable ()
814838
815839 key1 = testkey (value = "hello1" )
@@ -818,43 +842,63 @@ def __repr__(self):
818842 table .set_item (key1 , 123 )
819843 table .set_item (key2 , 456 )
820844
821- key1 .throw_hash = throw1hash
822- key2 .throw_hash = throw2hash
823- key1 .throw_eq = throw1eq
824- key2 .throw_eq = throw2eq
845+ key1 .throw_hash = throw1
846+ key2 .throw_hash = throw2
825847
826- if throw1hash and throw1eq :
827- with pytest .raises (
828- RuntimeError , match = re .escape (f"exception in { key1 !r} ." ) + "__(hash|eq)__"
829- ):
830- table .get_item (key1 )
831- elif throw1hash :
848+ if throw1 :
832849 with pytest .raises (
833850 RuntimeError , match = re .escape (f"exception in { key1 !r} .__hash__" )
834851 ):
835852 table .get_item (key1 )
836- elif throw1eq :
837- with pytest .raises (
838- RuntimeError , match = re .escape (f"exception in { key1 !r} .__eq__" )
839- ):
840- table .get_item (key1 )
841853 else :
842854 assert table .get_item (key1 ) == 123
843855
844- if throw2hash and throw2eq :
856+ if throw2 :
845857 with pytest .raises (
846- RuntimeError , match = re .escape (f"exception in { key2 !r} ." ) + "__(hash|eq)__"
858+ RuntimeError , match = re .escape (f"exception in { key2 !r} .__hash__" )
847859 ):
848860 table .get_item (key2 )
849- elif throw2hash :
861+ else :
862+ assert table .get_item (key2 ) == 456
863+
864+
865+ @pytest .mark .parametrize ("throw" , [True , False ])
866+ def test_error_raised_from_eq_method_in_set_item (throw ):
867+ # GH 57052
868+ table = ht .PyObjectHashTable ()
869+
870+ key1 = testkey (value = "hello" , throw_eq = throw )
871+ key2 = testkey (value = key1 .value )
872+
873+ if throw :
874+ table .set_item (key1 , 123 )
850875 with pytest .raises (
851- RuntimeError , match = re .escape (f"exception in { key2 !r} .__hash__ " )
876+ RuntimeError , match = re .escape (f"exception in { key1 !r} .__eq__ " )
852877 ):
853- table .get_item (key2 )
854- elif throw2eq :
878+ table .set_item (key2 , 456 )
879+ else :
880+ table .set_item (key2 , 456 )
881+ assert table .get_item (key2 ) == 456
882+
883+
884+ @pytest .mark .parametrize ("throw" , [True , False ])
885+ def test_error_raised_from_eq_method_in_get_item (throw ):
886+ # GH 57052
887+ table = ht .PyObjectHashTable ()
888+
889+ key1 = testkey (value = "hello" )
890+ key2 = testkey (value = key1 .value )
891+
892+ table .set_item (key1 , 123 )
893+ table .set_item (key2 , 456 )
894+
895+ if throw :
896+ key1 .throw_eq = True
855897 with pytest .raises (
856- RuntimeError , match = re .escape (f"exception in { key2 !r} .__eq__" )
898+ RuntimeError , match = re .escape (f"exception in { key1 !r} .__eq__" )
857899 ):
858900 table .get_item (key2 )
859901 else :
902+ # this looks odd but it is because key1.value == key2.value
903+ assert table .get_item (key1 ) == 456
860904 assert table .get_item (key2 ) == 456
0 commit comments