11use std:: iter:: Step ;
22use std:: marker:: PhantomData ;
3- use std:: ops:: Bound ;
43use std:: ops:: RangeBounds ;
4+ use std:: ops:: { Bound , Range } ;
55
66use crate :: vec:: Idx ;
77use crate :: vec:: IndexVec ;
@@ -11,6 +11,10 @@ use smallvec::SmallVec;
1111mod tests;
1212
1313/// Stores a set of intervals on the indices.
14+ ///
15+ /// The elements in `map` are sorted and non-adjacent, which means
16+ /// the second value of the previous element is *greater* than the
17+ /// first value of the following element.
1418#[ derive( Debug , Clone ) ]
1519pub struct IntervalSet < I > {
1620 // Start, end
@@ -84,7 +88,7 @@ impl<I: Idx> IntervalSet<I> {
8488 // continue to the next range. We're looking here for the first
8589 // range which starts *non-adjacently* to our end.
8690 let next = self . map . partition_point ( |r| r. 0 <= end + 1 ) ;
87- if let Some ( right) = next. checked_sub ( 1 ) {
91+ let result = if let Some ( right) = next. checked_sub ( 1 ) {
8892 let ( prev_start, prev_end) = self . map [ right] ;
8993 if prev_end + 1 >= start {
9094 // If the start for the inserted range is adjacent to the
@@ -99,25 +103,25 @@ impl<I: Idx> IntervalSet<I> {
99103 if left != right {
100104 self . map . drain ( left..right) ;
101105 }
102- return true ;
106+ true
103107 } else {
104108 // We overlap with the previous range, increase it to
105109 // include us.
106110 //
107111 // Make sure we're actually going to *increase* it though --
108112 // it may be that end is just inside the previously existing
109113 // set.
110- return if end > prev_end {
114+ if end > prev_end {
111115 self . map [ right] . 1 = end;
112116 true
113117 } else {
114118 false
115- } ;
119+ }
116120 }
117121 } else {
118122 // Otherwise, we don't overlap, so just insert
119123 self . map . insert ( right + 1 , ( start, end) ) ;
120- return true ;
124+ true
121125 }
122126 } else {
123127 if self . map . is_empty ( ) {
@@ -127,8 +131,16 @@ impl<I: Idx> IntervalSet<I> {
127131 } else {
128132 self . map . insert ( next, ( start, end) ) ;
129133 }
130- return true ;
131- }
134+ true
135+ } ;
136+ debug_assert ! (
137+ self . check_invariants( ) ,
138+ "wrong intervals after insert {:?}..={:?} to {:?}" ,
139+ start,
140+ end,
141+ self
142+ ) ;
143+ result
132144 }
133145
134146 pub fn contains ( & self , needle : I ) -> bool {
@@ -145,9 +157,26 @@ impl<I: Idx> IntervalSet<I> {
145157 where
146158 I : Step ,
147159 {
148- // FIXME: Performance here is probably not great. We will be doing a lot
149- // of pointless tree traversals.
150- other. iter ( ) . all ( |elem| self . contains ( elem) )
160+ let mut sup_iter = self . iter_intervals ( ) ;
161+ let mut current = None ;
162+ let contains = |sup : Range < I > , sub : Range < I > , current : & mut Option < Range < I > > | {
163+ if sup. end < sub. start {
164+ // if `sup.end == sub.start`, the next sup doesn't contain `sub.start`
165+ None // continue to the next sup
166+ } else if sup. end >= sub. end && sup. start <= sub. start {
167+ * current = Some ( sup) ; // save the current sup
168+ Some ( true )
169+ } else {
170+ Some ( false )
171+ }
172+ } ;
173+ other. iter_intervals ( ) . all ( |sub| {
174+ current
175+ . take ( )
176+ . and_then ( |sup| contains ( sup, sub. clone ( ) , & mut current) )
177+ . or_else ( || sup_iter. find_map ( |sup| contains ( sup, sub. clone ( ) , & mut current) ) )
178+ . unwrap_or ( false )
179+ } )
151180 }
152181
153182 pub fn is_empty ( & self ) -> bool {
@@ -174,7 +203,10 @@ impl<I: Idx> IntervalSet<I> {
174203
175204 pub fn insert_all ( & mut self ) {
176205 self . clear ( ) ;
177- self . map . push ( ( 0 , self . domain . try_into ( ) . unwrap ( ) ) ) ;
206+ if let Some ( end) = self . domain . checked_sub ( 1 ) {
207+ self . map . push ( ( 0 , end. try_into ( ) . unwrap ( ) ) ) ;
208+ }
209+ debug_assert ! ( self . check_invariants( ) ) ;
178210 }
179211
180212 pub fn union ( & mut self , other : & IntervalSet < I > ) -> bool
@@ -186,8 +218,21 @@ impl<I: Idx> IntervalSet<I> {
186218 for range in other. iter_intervals ( ) {
187219 did_insert |= self . insert_range ( range) ;
188220 }
221+ debug_assert ! ( self . check_invariants( ) ) ;
189222 did_insert
190223 }
224+
225+ // Check the intervals are valid, sorted and non-adjacent
226+ fn check_invariants ( & self ) -> bool {
227+ let mut current: Option < u32 > = None ;
228+ for ( start, end) in & self . map {
229+ if start > end || current. map_or ( false , |x| x + 1 >= * start) {
230+ return false ;
231+ }
232+ current = Some ( * end) ;
233+ }
234+ current. map_or ( true , |x| x < self . domain as u32 )
235+ }
191236}
192237
193238/// This data structure optimizes for cases where the stored bits in each row
0 commit comments