137 | 137 | //! }
138 | 138 | //! vec.truncate(write_idx);
139 | 139 | //! ```
| 140 | +use crate::alloc::{handle_alloc_error, Global}; |
| 141 | +use core::alloc::Allocator; |
| 142 | +use core::alloc::Layout; |
140 | 143 | use core::iter::{InPlaceIterable, SourceIter, TrustedRandomAccessNoCoerce}; |
141 | 144 | use core::mem::{self, ManuallyDrop, SizedTypeProperties}; |
142 | | -use core::ptr::{self}; |
| 145 | +use core::num::NonZeroUsize; |
| 146 | +use core::ptr::{self, NonNull}; |
143 | 147 |
144 | 148 | use super::{InPlaceDrop, InPlaceDstBufDrop, SpecFromIter, SpecFromIterNested, Vec}; |
145 | 149 |
146 | | -/// Specialization marker for collecting an iterator pipeline into a Vec while reusing the |
147 | | -/// source allocation, i.e. executing the pipeline in place. |
148 | | -#[rustc_unsafe_specialization_marker] |
149 | | -pub(super) trait InPlaceIterableMarker {} |
| 150 | +const fn in_place_collectible<DEST, SRC>( |
| 151 | + step_merge: Option<NonZeroUsize>, |
| 152 | + step_expand: Option<NonZeroUsize>, |
| 153 | +) -> bool { |
| 154 | + if DEST::IS_ZST || mem::align_of::<SRC>() != mem::align_of::<DEST>() { |
| 155 | + return false; |
| 156 | + } |
| 157 | + |
| 158 | + match (step_merge, step_expand) { |
| 159 | + (Some(step_merge), Some(step_expand)) => { |
| 160 | + // At least N merged source items -> at most M expanded destination items |
| 161 | + // e.g. |
| 162 | + // - 1 x [u8; 4] -> 4x u8, via flatten |
| 163 | + // - 4 x u8 -> 1x [u8; 4], via array_chunks |
| 164 | + mem::size_of::<SRC>() * step_merge.get() == mem::size_of::<DEST>() * step_expand.get() |
| 165 | + } |
| 166 | + // Fall back to other from_iter impls if an overflow occurred in the step merge/expansion
| 167 | + // tracking. |
| 168 | + _ => false, |
| 169 | + } |
| 170 | +} |
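As a rough illustration of the size check above: for the two adapters mentioned in the comment, the byte count consumed per step on the source side has to equal the byte count produced per step on the destination side. A minimal standalone sketch, not the stdlib function; the merge/expand step values are assumptions chosen for illustration:

```rust
use core::mem::size_of;

// Same arithmetic as the check above, written as a free function for illustration.
const fn steps_preserve_bytes<SRC, DEST>(merge: usize, expand: usize) -> bool {
    size_of::<SRC>() * merge == size_of::<DEST>() * expand
}

fn main() {
    // flatten: 1 x [u8; 4] -> 4 x u8
    assert!(steps_preserve_bytes::<[u8; 4], u8>(1, 4));
    // array_chunks: 4 x u8 -> 1 x [u8; 4]
    assert!(steps_preserve_bytes::<u8, [u8; 4]>(4, 1));
    // a 1:1 map from u16 to u8 changes the byte count per step, so it is rejected
    assert!(!steps_preserve_bytes::<u16, u8>(1, 1));
}
```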
150 | 171 |
151 | | -impl<T> InPlaceIterableMarker for T where T: InPlaceIterable {} |
| 172 | +/// This provides a shorthand for the source type since local type aliases aren't a thing. |
| 173 | +#[rustc_specialization_trait] |
| 174 | +trait InPlaceCollect: SourceIter<Source: AsVecIntoIter> + InPlaceIterable { |
| 175 | + type Src; |
| 176 | +} |
| 177 | + |
| 178 | +impl<T> InPlaceCollect for T |
| 179 | +where |
| 180 | + T: SourceIter<Source: AsVecIntoIter> + InPlaceIterable, |
| 181 | +{ |
| 182 | + type Src = <<T as SourceIter>::Source as AsVecIntoIter>::Item; |
| 183 | +} |
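For orientation, a hedged example of how `Src` resolves for a concrete pipeline (the closure and types here are illustrative, not taken from the stdlib tests):

```rust
// For I = Map<vec::IntoIter<u8>, {closure}>, the SourceIter chain bottoms out at
// vec::IntoIter<u8>, so <I as InPlaceCollect>::Src is u8 while the collected item
// type T is u32. The layout check above then rejects the in-place path (sizes and
// alignments differ) and collect() falls back to the SpecFromIterNested path.
fn widen(v: Vec<u8>) -> Vec<u32> {
    v.into_iter().map(|x| x as u32).collect()
}
```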
152 | 184 |
153 | 185 | impl<T, I> SpecFromIter<T, I> for Vec<T> |
154 | 186 | where |
155 | | - I: Iterator<Item = T> + SourceIter<Source: AsVecIntoIter> + InPlaceIterableMarker, |
| 187 | + I: Iterator<Item = T> + InPlaceCollect, |
| 188 | + <I as SourceIter>::Source: AsVecIntoIter, |
156 | 189 | { |
157 | 190 | default fn from_iter(mut iterator: I) -> Self { |
158 | 191 | // See "Layout constraints" section in the module documentation. We rely on const |
159 | 192 | // optimization here since these conditions currently cannot be expressed as trait bounds |
160 | | - if T::IS_ZST |
161 | | - || mem::size_of::<T>() |
162 | | - != mem::size_of::<<<I as SourceIter>::Source as AsVecIntoIter>::Item>() |
163 | | - || mem::align_of::<T>() |
164 | | - != mem::align_of::<<<I as SourceIter>::Source as AsVecIntoIter>::Item>() |
165 | | - { |
| 193 | + if const { !in_place_collectible::<T, I::Src>(I::MERGE_BY, I::EXPAND_BY) } { |
166 | 194 | // fallback to more generic implementations |
167 | 195 | return SpecFromIterNested::from_iter(iterator); |
168 | 196 | } |
169 | 197 |
170 | | - let (src_buf, src_ptr, dst_buf, dst_end, cap) = unsafe { |
| 198 | + let (src_buf, src_ptr, src_cap, mut dst_buf, dst_end, dst_cap) = unsafe { |
171 | 199 | let inner = iterator.as_inner().as_into_iter(); |
172 | 200 | ( |
173 | 201 | inner.buf.as_ptr(), |
174 | 202 | inner.ptr, |
| 203 | + inner.cap, |
175 | 204 | inner.buf.as_ptr() as *mut T, |
176 | 205 | inner.end as *const T, |
177 | | - inner.cap, |
| 206 | + inner.cap * mem::size_of::<I::Src>() / mem::size_of::<T>(), |
178 | 207 | ) |
179 | 208 | }; |
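The `dst_cap` expression above measures the source allocation in units of `T` and rounds down. A small worked example with assumed sizes, a `Vec<u8>` source collected into a `Vec<[u8; 4]>`:

```rust
use core::mem::size_of;

fn main() {
    // Assumed numbers: a source vec::IntoIter<u8> whose buffer has capacity 17.
    let src_cap = 17usize;
    let dst_cap = src_cap * size_of::<u8>() / size_of::<[u8; 4]>();
    assert_eq!(dst_cap, 4); // 17 bytes hold four [u8; 4] chunks, one byte is left over
    // That leftover byte is what the shrink further down hands back to the allocator.
}
```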
180 | 209 |
@@ -203,11 +232,31 @@ where |
203 | 232 | // Note: This access to the source wouldn't be allowed by the TrustedRandomAccessNoCoerce
204 | 233 | // contract (used by SpecInPlaceCollect below). But see the "O(1) collect" section in the |
205 | 234 | // module documentation why this is ok anyway. |
206 | | - let dst_guard = InPlaceDstBufDrop { ptr: dst_buf, len, cap }; |
| 235 | + let dst_guard = InPlaceDstBufDrop { ptr: dst_buf, len, cap: dst_cap }; |
207 | 236 | src.forget_allocation_drop_remaining(); |
208 | 237 | mem::forget(dst_guard); |
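The guard-then-forget handover above is the usual panic-safety pattern: the destination buffer stays owned by a drop guard while the source's remaining elements are dropped (which may unwind), and only once both sides are consistent is the guard disarmed. A generic sketch of the pattern, not the stdlib's `InPlaceDstBufDrop`:

```rust
use core::mem;

// Hypothetical guard: if the fallible step panics, Drop runs and cleans up;
// on success the guard is forgotten and the caller keeps the buffer's contents.
struct CleanupGuard<'a> {
    buf: &'a mut Vec<u8>,
}

impl Drop for CleanupGuard<'_> {
    fn drop(&mut self) {
        self.buf.clear(); // stand-in for "release the half-initialized destination"
    }
}

fn handover(buf: &mut Vec<u8>, fallible_step: impl FnOnce()) {
    let guard = CleanupGuard { buf };
    fallible_step(); // corresponds to forget_allocation_drop_remaining(), which may unwind
    mem::forget(guard); // both sides consistent: disarm the guard
}
```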
209 | 238 |
210 | | - let vec = unsafe { Vec::from_raw_parts(dst_buf, len, cap) }; |
| 239 | + // Adjust the allocation size if the source had a capacity in bytes that wasn't a multiple |
| 240 | + // of the destination type size. |
| 241 | + // Since the discrepancy should generally be small, this should only result in some
| 242 | + // bookkeeping updates and no memmove. |
| 243 | + if const { mem::size_of::<T>() > mem::size_of::<I::Src>() } |
| 244 | + && src_cap * mem::size_of::<I::Src>() != dst_cap * mem::size_of::<T>() |
| 245 | + { |
| 246 | + let alloc = Global; |
| 247 | + unsafe { |
| 248 | + let new_layout = Layout::array::<T>(dst_cap).unwrap(); |
| 249 | + let result = alloc.shrink( |
| 250 | + NonNull::new_unchecked(dst_buf as *mut u8), |
| 251 | + Layout::array::<I::Src>(src_cap).unwrap(), |
| 252 | + new_layout, |
| 253 | + ); |
| 254 | + let Ok(reallocated) = result else { handle_alloc_error(new_layout) }; |
| 255 | + dst_buf = reallocated.as_ptr() as *mut T; |
| 256 | + } |
| 257 | + } |
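A minimal sketch of the two layouts this shrink works with, continuing the assumed `Vec<u8>` to `Vec<[u8; 4]>` example (capacities 17 and 4):

```rust
use core::alloc::Layout;

fn main() {
    let old_layout = Layout::array::<u8>(17).unwrap(); // 17 bytes, align 1
    let new_layout = Layout::array::<[u8; 4]>(4).unwrap(); // 16 bytes, align 1
    assert_eq!(old_layout.align(), new_layout.align()); // alignment never changes here
    assert!(new_layout.size() <= old_layout.size()); // required for Allocator::shrink
    // With equal alignment and a slightly smaller size, most allocators can treat this
    // as pure bookkeeping, without moving the block.
}
```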
| 258 | + |
| 259 | + let vec = unsafe { Vec::from_raw_parts(dst_buf, len, dst_cap) }; |
211 | 260 |
212 | 261 | vec |
213 | 262 | } |