@@ -17,52 +17,82 @@ impl BitSet {
 
     #[inline(always)]
     pub fn insert(&mut self, index: usize) -> bool {
-        let word_index = index / 64;
-        let bit_index = index % 64;
+        let word_index = index >> 6;
+        let bit_index = index & 63;
         let mask = 1u64 << bit_index;
 
-        let was_set = (self.words[word_index] & mask) != 0;
-        self.words[word_index] |= mask;
+        debug_assert!(word_index < self.words.len(), "BitSet index out of bounds");
+
+        // SAFETY: word_index is derived from a memory address that is bounds-checked
+        // during memory access. The bitset is sized to accommodate all valid
+        // memory addresses, so word_index is always within bounds.
+        let word = unsafe { self.words.get_unchecked_mut(word_index) };
+        let was_set = (*word & mask) != 0;
+        *word |= mask;
         !was_set
     }
 
     /// Set all bits within [start, end) to 1, return the number of flipped bits.
+    /// Assumes start < end and end <= self.words.len() * 64.
     #[inline(always)]
     pub fn insert_range(&mut self, start: usize, end: usize) -> usize {
         debug_assert!(start < end);
+        debug_assert!(end <= self.words.len() * 64, "BitSet range out of bounds");
+
         let mut ret = 0;
-        let start_word_index = start / u64::BITS as usize;
-        let end_word_index = (end - 1) / u64::BITS as usize;
-        let start_bit = start as u32 % u64::BITS;
+        let start_word_index = start >> 6;
+        let end_word_index = (end - 1) >> 6;
+        let start_bit = (start & 63) as u32;
+
         if start_word_index == end_word_index {
-            let end_bit = (end - 1) as u32 % u64::BITS + 1;
+            let end_bit = ((end - 1) & 63) as u32 + 1;
             let mask_bits = end_bit - start_bit;
-            let mask = (u64::MAX >> (u64::BITS - mask_bits)) << start_bit;
-            ret += mask_bits - (self.words[start_word_index] & mask).count_ones();
-            self.words[start_word_index] |= mask;
+            let mask = (u64::MAX >> (64 - mask_bits)) << start_bit;
+            // SAFETY: Caller ensures start < end and end <= self.words.len() * 64,
+            // so start_word_index < self.words.len()
+            let word = unsafe { self.words.get_unchecked_mut(start_word_index) };
+            ret += mask_bits - (*word & mask).count_ones();
+            *word |= mask;
         } else {
-            let end_bit = end as u32 % u64::BITS;
-            let mask_bits = u64::BITS - start_bit;
+            // end_bit is in 1..=64: end_word_index is the word holding bit end - 1,
+            // so the end word always contains at least one bit of the range. This
+            // also covers end being a multiple of 64, where end & 63 would be 0.
+            let end_bit = ((end - 1) & 63) as u32 + 1;
+            let mask_bits = 64 - start_bit;
             let mask = u64::MAX << start_bit;
-            ret += mask_bits - (self.words[start_word_index] & mask).count_ones();
-            self.words[start_word_index] |= mask;
+            // SAFETY: Caller ensures start < end and end <= self.words.len() * 64,
+            // so start_word_index < self.words.len()
+            let start_word = unsafe { self.words.get_unchecked_mut(start_word_index) };
+            ret += mask_bits - (*start_word & mask).count_ones();
+            *start_word |= mask;
+
             let mask_bits = end_bit;
-            let (mask, _) = u64::MAX.overflowing_shr(u64::BITS - end_bit);
+            // 64 - end_bit is in 0..=63, so this shift never overflows.
+            let mask = u64::MAX >> (64 - end_bit);
+            // SAFETY: Caller ensures end <= self.words.len() * 64, so
+            // end_word_index < self.words.len()
+            let end_word = unsafe { self.words.get_unchecked_mut(end_word_index) };
+            ret += mask_bits - (*end_word & mask).count_ones();
+            *end_word |= mask;
         }
+
         if start_word_index + 1 < end_word_index {
             for i in (start_word_index + 1)..end_word_index {
-                ret += self.words[i].count_zeros();
-                self.words[i] = u64::MAX;
+                // SAFETY: Caller ensures proper start and end, so i is within bounds
+                // of self.words.len()
+                let word = unsafe { self.words.get_unchecked_mut(i) };
+                ret += word.count_zeros();
+                *word = u64::MAX;
             }
         }
         ret as usize
     }
 
+    #[inline(always)]
     pub fn clear(&mut self) {
-        for item in self.words.iter_mut() {
-            *item = 0;
+        // SAFETY: words is valid for self.words.len() elements
+        unsafe {
+            std::ptr::write_bytes(self.words.as_mut_ptr(), 0, self.words.len());
         }
     }
 }
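
The trickiest parts of this hunk are the mask expressions, since a shift by 64 is undefined for `u64`. Here is a standalone sanity check (not part of the commit; `naive_mask` is an illustrative helper) that the single-word mask and the end-word mask each match a naive per-bit construction, including the case where `end` is a multiple of 64:

```rust
// Standalone check of the mask arithmetic used in insert_range.
fn naive_mask(start_bit: u32, end_bit: u32) -> u64 {
    (start_bit..end_bit).fold(0u64, |m, b| m | (1u64 << b))
}

fn main() {
    // Single-word branch: mask covers bits [start_bit, end_bit) of one word.
    for start_bit in 0..64u32 {
        for end_bit in (start_bit + 1)..=64 {
            let mask_bits = end_bit - start_bit;
            // Never shifts by 64: 1 <= mask_bits <= 64, so 0 <= 64 - mask_bits <= 63.
            let mask = (u64::MAX >> (64 - mask_bits)) << start_bit;
            assert_eq!(mask, naive_mask(start_bit, end_bit));
        }
    }
    // End-word mask in the multi-word branch: with end_bit = ((end - 1) & 63) + 1,
    // end_bit is in 1..=64 and the mask covers bits [0, end_bit), including the
    // full word when end is a multiple of 64.
    for end_bit in 1..=64u32 {
        let mask = u64::MAX >> (64 - end_bit);
        assert_eq!(mask, naive_mask(0, end_bit));
    }
    println!("mask arithmetic checks pass");
}
```
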
@@ -132,6 +162,7 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
             addr_space_access_count: vec![0; (1 << memory_dimensions.addr_space_height) + 1],
         }
     }
+
     #[inline(always)]
     pub fn clear(&mut self) {
         self.page_indices.clear();
@@ -147,6 +178,8 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
         ptr: u32,
         size: u32,
     ) {
+        debug_assert!((address_space as usize) < self.addr_space_access_count.len());
+
         let num_blocks = (size + self.chunk - 1) >> self.chunk_bits;
         let start_chunk_id = ptr >> self.chunk_bits;
         let start_block_id = if self.chunk == 1 {
@@ -159,10 +192,17 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
         let end_block_id = start_block_id + num_blocks;
         let start_page_id = start_block_id >> PAGE_BITS;
         let end_page_id = ((end_block_id - 1) >> PAGE_BITS) + 1;
+
         for page_id in start_page_id..end_page_id {
             if self.page_indices.insert(page_id as usize) {
                 self.page_access_count += 1;
-                self.addr_space_access_count[address_space as usize] += 1;
+                // SAFETY: address_space passed is usually a hardcoded constant or derived from an
+                // Instruction where it is bounds checked before passing
+                unsafe {
+                    *self
+                        .addr_space_access_count
+                        .get_unchecked_mut(address_space as usize) += 1;
+                }
             }
         }
     }
@@ -185,38 +225,68 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
         size_bits: u32,
         num: u32,
     ) {
-        let align_bits = self.as_byte_alignment_bits[address_space as usize];
+        debug_assert!((address_space as usize) < self.as_byte_alignment_bits.len());
+
+        // SAFETY: address_space passed is usually a hardcoded constant or derived from an
+        // Instruction where it is bounds checked before passing
+        let align_bits = unsafe {
+            *self
+                .as_byte_alignment_bits
+                .get_unchecked(address_space as usize)
+        };
         debug_assert!(
             align_bits as u32 <= size_bits,
             "align_bits ({}) must be <= size_bits ({})",
             align_bits,
             size_bits
         );
+
         for adapter_bits in (align_bits as u32 + 1..=size_bits).rev() {
             let adapter_idx = self.adapter_offset + adapter_bits as usize - 1;
-            trace_heights[adapter_idx] += num << (size_bits - adapter_bits + 1);
+            debug_assert!(adapter_idx < trace_heights.len());
+            // SAFETY: trace_heights is initialized taking access adapters into account
+            unsafe {
+                *trace_heights.get_unchecked_mut(adapter_idx) +=
+                    num << (size_bits - adapter_bits + 1);
+            }
         }
     }
 
     /// Resolve all lazy updates of each memory access for memory adapters/poseidon2/merkle chip.
     #[inline(always)]
     pub(crate) fn lazy_update_boundary_heights(&mut self, trace_heights: &mut [u32]) {
+        debug_assert!(self.boundary_idx < trace_heights.len());
+
         // On page fault, assume we add all leaves in a page
         let leaves = (self.page_access_count << PAGE_BITS) as u32;
-        trace_heights[self.boundary_idx] += leaves;
+        // SAFETY: boundary_idx is a compile-time constant within bounds
+        unsafe {
+            *trace_heights.get_unchecked_mut(self.boundary_idx) += leaves;
+        }
 
         if let Some(merkle_tree_idx) = self.merkle_tree_index {
+            debug_assert!(merkle_tree_idx < trace_heights.len());
+            debug_assert!(trace_heights.len() >= 2);
+
             let poseidon2_idx = trace_heights.len() - 2;
-            trace_heights[poseidon2_idx] += leaves * 2;
+            // SAFETY: poseidon2_idx is trace_heights.len() - 2, guaranteed to be in bounds
+            unsafe {
+                *trace_heights.get_unchecked_mut(poseidon2_idx) += leaves * 2;
+            }
 
             let merkle_height = self.memory_dimensions.overall_height();
             let nodes = (((1 << PAGE_BITS) - 1) + (merkle_height - PAGE_BITS)) as u32;
-            trace_heights[poseidon2_idx] += nodes * 2;
-            trace_heights[merkle_tree_idx] += nodes * 2;
+            // SAFETY: merkle_tree_idx is guaranteed to be in bounds
+            unsafe {
+                *trace_heights.get_unchecked_mut(poseidon2_idx) += nodes * 2;
+                *trace_heights.get_unchecked_mut(merkle_tree_idx) += nodes * 2;
+            }
         }
         self.page_access_count = 0;
+
         for address_space in 0..self.addr_space_access_count.len() {
-            let x = self.addr_space_access_count[address_space];
+            // SAFETY: address_space is from 0 to len(), guaranteed to be in bounds
+            let x = unsafe { *self.addr_space_access_count.get_unchecked(address_space) };
             if x > 0 {
                 // After finalize, we'll need to read it in chunk-sized units for the merkle chip
                 self.update_adapter_heights_batch(
@@ -225,7 +295,12 @@ impl<const PAGE_BITS: usize> MemoryCtx<PAGE_BITS> {
                     self.chunk_bits,
                     (x << PAGE_BITS) as u32,
                 );
-                self.addr_space_access_count[address_space] = 0;
+                // SAFETY: address_space is from 0 to len(), guaranteed to be in bounds
+                unsafe {
+                    *self
+                        .addr_space_access_count
+                        .get_unchecked_mut(address_space) = 0;
+                }
             }
         }
     }
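
Every replacement in this commit follows the same shape: keep the bound as a `debug_assert!` so debug builds still catch violations, then skip the release-mode bounds check with `get_unchecked`. A minimal self-contained sketch of the pattern (the function and names are illustrative, not from the codebase):

```rust
/// Increment counts[idx] without a release-mode bounds check.
#[inline(always)]
fn bump(counts: &mut [u32], idx: usize) {
    debug_assert!(idx < counts.len(), "index out of bounds");
    // SAFETY: the caller guarantees idx < counts.len(); the debug_assert
    // above enforces this invariant in debug builds.
    unsafe {
        *counts.get_unchecked_mut(idx) += 1;
    }
}

fn main() {
    let mut counts = vec![0u32; 4];
    bump(&mut counts, 2);
    assert_eq!(counts, [0, 0, 1, 0]);
}
```

The trade-off is that a violated invariant becomes undefined behavior in release builds, which is why each call site in the diff documents the invariant in a SAFETY comment.
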