@@ -7,7 +7,7 @@ use core::hash::{BuildHasher, Hash};
77use core:: iter:: { Chain , FusedIterator } ;
88use core:: ops:: { BitAnd , BitOr , BitXor , Sub } ;
99
10- use super :: map:: { self , DefaultHashBuilder , HashMap , Keys } ;
10+ use super :: map:: { self , make_hash , DefaultHashBuilder , HashMap , Keys , RawEntryMut } ;
1111use crate :: raw:: { Allocator , Global , RawExtractIf } ;
1212
1313// Future Optimization (FIXME!)
@@ -955,6 +955,11 @@ where
955955 /// Inserts a value computed from `f` into the set if the given `value` is
956956 /// not present, then returns a reference to the value in the set.
957957 ///
958+ /// # Panics
959+ ///
960+ /// Panics if the value from the function and the provided lookup value
961+ /// are not equivalent. See [`Equivalent`] and [`Hash`] for more information.
962+ ///
958963 /// # Examples
959964 ///
960965 /// ```
@@ -969,20 +974,37 @@ where
969974 /// assert_eq!(value, pet);
970975 /// }
971976 /// assert_eq!(set.len(), 4); // a new "fish" was inserted
977+ /// assert!(set.contains("fish"));
972978 /// ```
973979 #[ cfg_attr( feature = "inline-more" , inline) ]
974980 pub fn get_or_insert_with < Q : ?Sized , F > ( & mut self , value : & Q , f : F ) -> & T
975981 where
976982 Q : Hash + Equivalent < T > ,
977983 F : FnOnce ( & Q ) -> T ,
978984 {
985+ #[ cold]
986+ #[ inline( never) ]
987+ fn assert_failed ( ) {
988+ panic ! (
989+ "the value from the function and the lookup value \
990+ must be equivalent and have the same hash"
991+ ) ;
992+ }
993+
979994 // Although the raw entry gives us `&mut T`, we only return `&T` to be consistent with
980995 // `get`. Key mutation is "raw" because you're not supposed to affect `Eq` or `Hash`.
981- self . map
982- . raw_entry_mut ( )
983- . from_key ( value)
984- . or_insert_with ( || ( f ( value) , ( ) ) )
985- . 0
996+ let hash = make_hash :: < Q , S > ( & self . map . hash_builder , value) ;
997+ let raw_entry_builder = self . map . raw_entry_mut ( ) ;
998+ match raw_entry_builder. from_key_hashed_nocheck ( hash, value) {
999+ RawEntryMut :: Occupied ( entry) => entry. into_key ( ) ,
1000+ RawEntryMut :: Vacant ( entry) => {
1001+ let insert_value = f ( value) ;
1002+ if !value. equivalent ( & insert_value) {
1003+ assert_failed ( ) ;
1004+ }
1005+ entry. insert_hashed_nocheck ( hash, insert_value, ( ) ) . 0
1006+ }
1007+ }
9861008 }
9871009
9881010 /// Gets the given value's corresponding entry in the set for in-place manipulation.
@@ -2492,7 +2514,7 @@ fn assert_covariance() {
24922514#[ cfg( test) ]
24932515mod test_set {
24942516 use super :: super :: map:: DefaultHashBuilder ;
2495- use super :: HashSet ;
2517+ use super :: { make_hash , Equivalent , HashSet } ;
24962518 use std:: vec:: Vec ;
24972519
24982520 #[ test]
@@ -2958,4 +2980,57 @@ mod test_set {
29582980 // (and without the `map`, it does not).
29592981 let mut _set: HashSet < _ > = ( 0 ..3 ) . map ( |_| ( ) ) . collect ( ) ;
29602982 }
2983+
2984+ #[ test]
2985+ fn duplicate_insert ( ) {
2986+ let mut set = HashSet :: new ( ) ;
2987+ set. insert ( 1 ) ;
2988+ set. get_or_insert_with ( & 1 , |_| 1 ) ;
2989+ set. get_or_insert_with ( & 1 , |_| 1 ) ;
2990+ assert ! ( [ 1 ] . iter( ) . eq( set. iter( ) ) ) ;
2991+ }
2992+
2993+ #[ test]
2994+ #[ should_panic]
2995+ fn some_invalid_equivalent ( ) {
2996+ use core:: hash:: { Hash , Hasher } ;
2997+ struct Invalid {
2998+ count : u32 ,
2999+ other : u32 ,
3000+ }
3001+
3002+ struct InvalidRef {
3003+ count : u32 ,
3004+ other : u32 ,
3005+ }
3006+
3007+ impl PartialEq for Invalid {
3008+ fn eq ( & self , other : & Self ) -> bool {
3009+ self . count == other. count && self . other == other. other
3010+ }
3011+ }
3012+ impl Eq for Invalid { }
3013+
3014+ impl Equivalent < Invalid > for InvalidRef {
3015+ fn equivalent ( & self , key : & Invalid ) -> bool {
3016+ self . count == key. count && self . other == key. other
3017+ }
3018+ }
3019+ impl Hash for Invalid {
3020+ fn hash < H : Hasher > ( & self , state : & mut H ) {
3021+ self . count . hash ( state) ;
3022+ }
3023+ }
3024+ impl Hash for InvalidRef {
3025+ fn hash < H : Hasher > ( & self , state : & mut H ) {
3026+ self . count . hash ( state) ;
3027+ }
3028+ }
3029+ let mut set: HashSet < Invalid > = HashSet :: new ( ) ;
3030+ let key = InvalidRef { count : 1 , other : 1 } ;
3031+ let value = Invalid { count : 1 , other : 2 } ;
3032+ if make_hash ( set. hasher ( ) , & key) == make_hash ( set. hasher ( ) , & value) {
3033+ set. get_or_insert_with ( & key, |_| value) ;
3034+ }
3035+ }
29613036}
0 commit comments