@@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
55
66pub const DEFAULT_SEGMENT_CHECK_INSNS : u64 = 1000 ;
77
8- pub const DEFAULT_MAX_TRACE_HEIGHT : u32 = ( 1 << 23 ) - 10000 ;
8+ pub const DEFAULT_MAX_TRACE_HEIGHT : u32 = 1 << 23 ;
99pub const DEFAULT_MAX_CELLS : usize = 2_000_000_000 ; // 2B
1010const DEFAULT_MAX_INTERACTIONS : usize = BabyBear :: ORDER_U32 as usize ;
1111
@@ -46,6 +46,10 @@ pub struct SegmentationCtx {
4646 pub instret_last_segment_check : u64 ,
4747 #[ getset( set_with = "pub" ) ]
4848 pub segment_check_insns : u64 ,
49+ /// Checkpoint of trace heights at last known state where all thresholds satisfied
50+ pub ( crate ) checkpoint_trace_heights : Vec < u32 > ,
51+ /// Instruction count at the checkpoint
52+ checkpoint_instret : u64 ,
4953}
5054
5155impl SegmentationCtx {
@@ -58,6 +62,7 @@ impl SegmentationCtx {
5862 assert_eq ! ( air_names. len( ) , widths. len( ) ) ;
5963 assert_eq ! ( air_names. len( ) , interactions. len( ) ) ;
6064
65+ let num_airs = air_names. len ( ) ;
6166 Self {
6267 segments : Vec :: new ( ) ,
6368 air_names,
@@ -66,6 +71,8 @@ impl SegmentationCtx {
6671 segmentation_limits,
6772 segment_check_insns : DEFAULT_SEGMENT_CHECK_INSNS ,
6873 instret_last_segment_check : 0 ,
74+ checkpoint_trace_heights : vec ! [ 0 ; num_airs] ,
75+ checkpoint_instret : 0 ,
6976 }
7077 }
7178
@@ -77,6 +84,7 @@ impl SegmentationCtx {
7784 assert_eq ! ( air_names. len( ) , widths. len( ) ) ;
7885 assert_eq ! ( air_names. len( ) , interactions. len( ) ) ;
7986
87+ let num_airs = air_names. len ( ) ;
8088 Self {
8189 segments : Vec :: new ( ) ,
8290 air_names,
@@ -85,6 +93,8 @@ impl SegmentationCtx {
8593 segmentation_limits : SegmentationLimits :: default ( ) ,
8694 segment_check_insns : DEFAULT_SEGMENT_CHECK_INSNS ,
8795 instret_last_segment_check : 0 ,
96+ checkpoint_trace_heights : vec ! [ 0 ; num_airs] ,
97+ checkpoint_instret : 0 ,
8898 }
8999 }
90100
@@ -100,37 +110,6 @@ impl SegmentationCtx {
100110 self . segmentation_limits . max_interactions = max_interactions;
101111 }
102112
103- /// Calculate the total cells used based on trace heights and widths
104- #[ inline( always) ]
105- fn calculate_total_cells ( & self , trace_heights : & [ u32 ] ) -> usize {
106- debug_assert_eq ! ( trace_heights. len( ) , self . widths. len( ) ) ;
107-
108- // SAFETY: Length equality is asserted during initialization
109- let widths_slice = unsafe { self . widths . get_unchecked ( ..trace_heights. len ( ) ) } ;
110-
111- trace_heights
112- . iter ( )
113- . zip ( widths_slice)
114- . map ( |( & height, & width) | height as usize * width)
115- . sum ( )
116- }
117-
118- /// Calculate the total interactions based on trace heights and interaction counts
119- #[ inline( always) ]
120- fn calculate_total_interactions ( & self , trace_heights : & [ u32 ] ) -> usize {
121- debug_assert_eq ! ( trace_heights. len( ) , self . interactions. len( ) ) ;
122-
123- // SAFETY: Length equality is asserted during initialization
124- let interactions_slice = unsafe { self . interactions . get_unchecked ( ..trace_heights. len ( ) ) } ;
125-
126- trace_heights
127- . iter ( )
128- . zip ( interactions_slice)
129- // We add 1 for the zero messages from the padding rows
130- . map ( |( & height, & interactions) | ( height + 1 ) as usize * interactions)
131- . sum ( )
132- }
133-
134113 #[ inline( always) ]
135114 fn should_segment (
136115 & self ,
@@ -140,6 +119,8 @@ impl SegmentationCtx {
140119 ) -> bool {
141120 debug_assert_eq ! ( trace_heights. len( ) , is_trace_height_constant. len( ) ) ;
142121 debug_assert_eq ! ( trace_heights. len( ) , self . air_names. len( ) ) ;
122+ debug_assert_eq ! ( trace_heights. len( ) , self . widths. len( ) ) ;
123+ debug_assert_eq ! ( trace_heights. len( ) , self . interactions. len( ) ) ;
143124
144125 let instret_start = self
145126 . segments
@@ -152,44 +133,51 @@ impl SegmentationCtx {
152133 return false ;
153134 }
154135
155- for ( i, ( & height, is_constant) ) in trace_heights
136+ let mut total_cells = 0 ;
137+ for ( i, ( ( padded_height, width) , is_constant) ) in trace_heights
156138 . iter ( )
139+ . map ( |& height| height. next_power_of_two ( ) )
140+ . zip ( self . widths . iter ( ) )
157141 . zip ( is_trace_height_constant. iter ( ) )
158142 . enumerate ( )
159143 {
160- // Only segment if the height is not constant and exceeds the maximum height
161- if !is_constant && height > self . segmentation_limits . max_trace_height {
162- let air_name = & self . air_names [ i] ;
144+ // Only segment if the height is not constant and exceeds the maximum height after
145+ // padding
146+ if !is_constant && padded_height > self . segmentation_limits . max_trace_height {
147+ let air_name = unsafe { self . air_names . get_unchecked ( i) } ;
163148 tracing:: info!(
164- "Segment {:2} | instret {:9} | chip {} ({}) height ({:8}) > max ({:8})" ,
165- self . segments. len( ) ,
149+ "instret {:9} | chip {} ({}) height ({:8}) > max ({:8})" ,
166150 instret,
167151 i,
168152 air_name,
169- height ,
153+ padded_height ,
170154 self . segmentation_limits. max_trace_height
171155 ) ;
172156 return true ;
173157 }
158+ total_cells += padded_height as usize * width;
174159 }
175160
176- let total_cells = self . calculate_total_cells ( trace_heights) ;
177161 if total_cells > self . segmentation_limits . max_cells {
178162 tracing:: info!(
179- "Segment {:2} | instret {:9} | total cells ({:10}) > max ({:10})" ,
180- self . segments. len( ) ,
163+ "instret {:9} | total cells ({:10}) > max ({:10})" ,
181164 instret,
182165 total_cells,
183166 self . segmentation_limits. max_cells
184167 ) ;
185168 return true ;
186169 }
187170
188- let total_interactions = self . calculate_total_interactions ( trace_heights) ;
171+ // All padding rows contribute a single message to the interactions (+1) since
172+ // we assume chips don't send/receive with nonzero multiplicity on padding rows.
173+ let total_interactions: usize = trace_heights
174+ . iter ( )
175+ . zip ( self . interactions . iter ( ) )
176+ . map ( |( & height, & interactions) | ( height + 1 ) as usize * interactions)
177+ . sum ( ) ;
189178 if total_interactions > self . segmentation_limits . max_interactions {
190179 tracing:: info!(
191- "Segment {:2} | instret {:9} | total interactions ({:11}) > max ({:11})" ,
192- self . segments. len( ) ,
180+ "instret {:9} | total interactions ({:11}) > max ({:11})" ,
193181 instret,
194182 total_interactions,
195183 self . segmentation_limits. max_interactions
@@ -204,16 +192,84 @@ impl SegmentationCtx {
204192 pub fn check_and_segment (
205193 & mut self ,
206194 instret : u64 ,
207- trace_heights : & [ u32 ] ,
195+ trace_heights : & mut [ u32 ] ,
208196 is_trace_height_constant : & [ bool ] ,
209197 ) -> bool {
210- let ret = self . should_segment ( instret, trace_heights, is_trace_height_constant) ;
211- if ret {
212- self . segment ( instret, trace_heights) ;
198+ let should_seg = self . should_segment ( instret, trace_heights, is_trace_height_constant) ;
199+
200+ if should_seg {
201+ self . create_segment_from_checkpoint ( instret, trace_heights, is_trace_height_constant) ;
202+ } else {
203+ self . update_checkpoint ( instret, trace_heights) ;
213204 }
205+
214206 self . instret_last_segment_check = instret;
207+ should_seg
208+ }
215209
216- ret
210+ #[ inline( always) ]
211+ fn create_segment_from_checkpoint (
212+ & mut self ,
213+ instret : u64 ,
214+ trace_heights : & mut [ u32 ] ,
215+ is_trace_height_constant : & [ bool ] ,
216+ ) {
217+ let instret_start = self
218+ . segments
219+ . last ( )
220+ . map_or ( 0 , |s| s. instret_start + s. num_insns ) ;
221+
222+ let ( segment_instret, segment_heights) = if self . checkpoint_instret > instret_start {
223+ (
224+ self . checkpoint_instret ,
225+ self . checkpoint_trace_heights . clone ( ) ,
226+ )
227+ } else {
228+ // No valid checkpoint, use current values
229+ ( instret, trace_heights. to_vec ( ) )
230+ } ;
231+
232+ // Reset current trace heights and checkpoint
233+ self . reset_trace_heights ( trace_heights, & segment_heights, is_trace_height_constant) ;
234+ self . checkpoint_instret = 0 ;
235+
236+ tracing:: info!(
237+ "Segment {:2} | instret {:9} | {} instructions" ,
238+ self . segments. len( ) ,
239+ instret_start,
240+ segment_instret - instret_start
241+ ) ;
242+ self . segments . push ( Segment {
243+ instret_start,
244+ num_insns : segment_instret - instret_start,
245+ trace_heights : segment_heights,
246+ } ) ;
247+ }
248+
249+ /// Resets trace heights by subtracting segment heights
250+ #[ inline( always) ]
251+ fn reset_trace_heights (
252+ & self ,
253+ trace_heights : & mut [ u32 ] ,
254+ segment_heights : & [ u32 ] ,
255+ is_trace_height_constant : & [ bool ] ,
256+ ) {
257+ for ( ( trace_height, & segment_height) , & is_trace_height_constant) in trace_heights
258+ . iter_mut ( )
259+ . zip ( segment_heights. iter ( ) )
260+ . zip ( is_trace_height_constant. iter ( ) )
261+ {
262+ if !is_trace_height_constant {
263+ * trace_height = trace_height. checked_sub ( segment_height) . unwrap ( ) ;
264+ }
265+ }
266+ }
267+
268+ /// Updates the checkpoint with current safe state
269+ #[ inline( always) ]
270+ fn update_checkpoint ( & mut self , instret : u64 , trace_heights : & [ u32 ] ) {
271+ self . checkpoint_trace_heights . copy_from_slice ( trace_heights) ;
272+ self . checkpoint_instret = instret;
217273 }
218274
219275 /// Try segment if there is at least one cycle
@@ -227,6 +283,12 @@ impl SegmentationCtx {
227283
228284 debug_assert ! ( num_insns > 0 , "Segment should contain at least one cycle" ) ;
229285
286+ tracing:: info!(
287+ "Segment {:2} | instret {:9} | {} instructions [FINAL]" ,
288+ self . segments. len( ) ,
289+ instret_start,
290+ num_insns
291+ ) ;
230292 self . segments . push ( Segment {
231293 instret_start,
232294 num_insns,
0 commit comments