@@ -155,6 +155,22 @@ impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
155155 }
156156}
157157
158+ /// Whether the sizes of two sets are roughly the same order of magnitude.
159+ ///
160+ /// If they are, or if either set is empty, then their intersection
161+ /// is efficiently calculated by iterating both sets jointly.
162+ /// If they aren't, then it is more scalable to iterate over the small set
163+ /// and find matches in the large set (except if the largest element in
164+ /// the small set hardly surpasses the smallest element in the large set).
165+ fn are_proportionate_for_intersection ( len1 : usize , len2 : usize ) -> bool {
166+ let ( small, large) = if len1 <= len2 {
167+ ( len1, len2)
168+ } else {
169+ ( len2, len1)
170+ } ;
171+ ( large >> 7 ) <= small
172+ }
173+
158174/// A lazy iterator producing elements in the intersection of `BTreeSet`s.
159175///
160176/// This `struct` is created by the [`intersection`] method on [`BTreeSet`].
@@ -165,7 +181,13 @@ impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
165181#[ stable( feature = "rust1" , since = "1.0.0" ) ]
166182pub struct Intersection < ' a , T : ' a > {
167183 a : Peekable < Iter < ' a , T > > ,
168- b : Peekable < Iter < ' a , T > > ,
184+ b : IntersectionOther < ' a , T > ,
185+ }
186+
187+ #[ derive( Debug ) ]
188+ enum IntersectionOther < ' a , T > {
189+ Stitch ( Peekable < Iter < ' a , T > > ) ,
190+ Search ( & ' a BTreeSet < T > ) ,
169191}
170192
171193#[ stable( feature = "collection_debug" , since = "1.17.0" ) ]
@@ -326,9 +348,21 @@ impl<T: Ord> BTreeSet<T> {
326348 /// ```
327349 #[ stable( feature = "rust1" , since = "1.0.0" ) ]
328350 pub fn intersection < ' a > ( & ' a self , other : & ' a BTreeSet < T > ) -> Intersection < ' a , T > {
329- Intersection {
330- a : self . iter ( ) . peekable ( ) ,
331- b : other. iter ( ) . peekable ( ) ,
351+ if are_proportionate_for_intersection ( self . len ( ) , other. len ( ) ) {
352+ Intersection {
353+ a : self . iter ( ) . peekable ( ) ,
354+ b : IntersectionOther :: Stitch ( other. iter ( ) . peekable ( ) ) ,
355+ }
356+ } else if self . len ( ) <= other. len ( ) {
357+ Intersection {
358+ a : self . iter ( ) . peekable ( ) ,
359+ b : IntersectionOther :: Search ( & other) ,
360+ }
361+ } else {
362+ Intersection {
363+ a : other. iter ( ) . peekable ( ) ,
364+ b : IntersectionOther :: Search ( & self ) ,
365+ }
332366 }
333367 }
334368
@@ -1069,6 +1103,14 @@ impl<'a, T: Ord> Iterator for SymmetricDifference<'a, T> {
10691103#[ stable( feature = "fused" , since = "1.26.0" ) ]
10701104impl < T : Ord > FusedIterator for SymmetricDifference < ' _ , T > { }
10711105
1106+ impl < ' a , T > Clone for IntersectionOther < ' a , T > {
1107+ fn clone ( & self ) -> IntersectionOther < ' a , T > {
1108+ match self {
1109+ IntersectionOther :: Stitch ( ref iter) => IntersectionOther :: Stitch ( iter. clone ( ) ) ,
1110+ IntersectionOther :: Search ( set) => IntersectionOther :: Search ( set) ,
1111+ }
1112+ }
1113+ }
10721114#[ stable( feature = "rust1" , since = "1.0.0" ) ]
10731115impl < T > Clone for Intersection < ' _ , T > {
10741116 fn clone ( & self ) -> Self {
@@ -1083,24 +1125,36 @@ impl<'a, T: Ord> Iterator for Intersection<'a, T> {
10831125 type Item = & ' a T ;
10841126
10851127 fn next ( & mut self ) -> Option < & ' a T > {
1086- loop {
1087- match Ord :: cmp ( self . a . peek ( ) ?, self . b . peek ( ) ?) {
1088- Less => {
1089- self . a . next ( ) ;
1090- }
1091- Equal => {
1092- self . b . next ( ) ;
1093- return self . a . next ( ) ;
1128+ match self . b {
1129+ IntersectionOther :: Stitch ( ref mut self_b) => loop {
1130+ match Ord :: cmp ( self . a . peek ( ) ?, self_b. peek ( ) ?) {
1131+ Less => {
1132+ self . a . next ( ) ;
1133+ }
1134+ Equal => {
1135+ self_b. next ( ) ;
1136+ return self . a . next ( ) ;
1137+ }
1138+ Greater => {
1139+ self_b. next ( ) ;
1140+ }
10941141 }
1095- Greater => {
1096- self . b . next ( ) ;
1142+ }
1143+ IntersectionOther :: Search ( set) => loop {
1144+ let e = self . a . next ( ) ?;
1145+ if set. contains ( & e) {
1146+ return Some ( e) ;
10971147 }
10981148 }
10991149 }
11001150 }
11011151
11021152 fn size_hint ( & self ) -> ( usize , Option < usize > ) {
1103- ( 0 , Some ( min ( self . a . len ( ) , self . b . len ( ) ) ) )
1153+ let b_len = match self . b {
1154+ IntersectionOther :: Stitch ( ref iter) => iter. len ( ) ,
1155+ IntersectionOther :: Search ( set) => set. len ( ) ,
1156+ } ;
1157+ ( 0 , Some ( min ( self . a . len ( ) , b_len) ) )
11041158 }
11051159}
11061160
@@ -1140,3 +1194,21 @@ impl<'a, T: Ord> Iterator for Union<'a, T> {
11401194
11411195#[ stable( feature = "fused" , since = "1.26.0" ) ]
11421196impl < T : Ord > FusedIterator for Union < ' _ , T > { }
1197+
1198+ #[ cfg( test) ]
1199+ mod tests {
1200+ use super :: * ;
1201+
1202+ #[ test]
1203+ fn test_are_proportionate_for_intersection ( ) {
1204+ assert ! ( are_proportionate_for_intersection( 0 , 0 ) ) ;
1205+ assert ! ( are_proportionate_for_intersection( 0 , 127 ) ) ;
1206+ assert ! ( !are_proportionate_for_intersection( 0 , 128 ) ) ;
1207+ assert ! ( are_proportionate_for_intersection( 1 , 255 ) ) ;
1208+ assert ! ( !are_proportionate_for_intersection( 1 , 256 ) ) ;
1209+ assert ! ( are_proportionate_for_intersection( 127 , 0 ) ) ;
1210+ assert ! ( !are_proportionate_for_intersection( 128 , 0 ) ) ;
1211+ assert ! ( are_proportionate_for_intersection( 255 , 1 ) ) ;
1212+ assert ! ( !are_proportionate_for_intersection( 256 , 1 ) ) ;
1213+ }
1214+ }
0 commit comments