diff --git a/Cargo.lock b/Cargo.lock index c2e635b4cfe65..564c76ee67969 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4196,6 +4196,7 @@ dependencies = [ name = "rustc_mir_build" version = "0.0.0" dependencies = [ + "either", "itertools", "rustc_abi", "rustc_apfloat", diff --git a/compiler/rustc_hir/src/lang_items.rs b/compiler/rustc_hir/src/lang_items.rs index 2e099a97b65be..f599be96ab5b2 100644 --- a/compiler/rustc_hir/src/lang_items.rs +++ b/compiler/rustc_hir/src/lang_items.rs @@ -440,6 +440,7 @@ language_item_table! { // Reborrowing related lang-items Reborrow, sym::reborrow, reborrow, Target::Trait, GenericRequirement::Exact(0); + PatCmp, sym::pat_cmp, pat_cmp, Target::Fn, GenericRequirement::Exact(1); } /// The requirement imposed on the generics of a lang item diff --git a/compiler/rustc_mir_build/Cargo.toml b/compiler/rustc_mir_build/Cargo.toml index f756f0a19ee9b..bf888b3f7add7 100644 --- a/compiler/rustc_mir_build/Cargo.toml +++ b/compiler/rustc_mir_build/Cargo.toml @@ -5,6 +5,7 @@ edition = "2024" [dependencies] # tidy-alphabetical-start +either = "1.15.0" itertools = "0.12" rustc_abi = { path = "../rustc_abi" } rustc_apfloat = "0.2.0" diff --git a/compiler/rustc_mir_build/src/builder/matches/match_pair.rs b/compiler/rustc_mir_build/src/builder/matches/match_pair.rs index 7a848536d0e33..5d396e606a5ef 100644 --- a/compiler/rustc_mir_build/src/builder/matches/match_pair.rs +++ b/compiler/rustc_mir_build/src/builder/matches/match_pair.rs @@ -1,6 +1,9 @@ +use std::ops; use std::sync::Arc; +use either::Either; use rustc_hir::ByRef; +use rustc_middle::bug; use rustc_middle::mir::*; use rustc_middle::thir::*; use rustc_middle::ty::{self, Ty, TypeVisitableExt}; @@ -34,6 +37,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { /// Used internally by [`MatchPairTree::for_pattern`]. fn prefix_slice_suffix( &mut self, + top_pattern: &Pat<'tcx>, match_pairs: &mut Vec>, extra_data: &mut PatternExtraData<'tcx>, place: &PlaceBuilder<'tcx>, @@ -56,11 +60,30 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { ((prefix.len() + suffix.len()).try_into().unwrap(), false) }; - for (idx, subpattern) in prefix.iter().enumerate() { - let elem = - ProjectionElem::ConstantIndex { offset: idx as u64, min_length, from_end: false }; - let place = place.clone_project(elem); - MatchPairTree::for_pattern(place, subpattern, self, match_pairs, extra_data) + if suffix.is_empty() { + // new + if !prefix.is_empty() { + self.build_slice_branch( + match_pairs, + extra_data, + false, + place, + top_pattern, + prefix, + min_length, + ); + } + } else { + // old + for (idx, subpattern) in prefix.iter().enumerate() { + let elem = ProjectionElem::ConstantIndex { + offset: idx as u64, + min_length, + from_end: false, + }; + let place = place.clone_project(elem); + MatchPairTree::for_pattern(place, subpattern, self, match_pairs, extra_data) + } } if let Some(subslice_pat) = opt_slice { @@ -84,6 +107,145 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { MatchPairTree::for_pattern(place, subpattern, self, match_pairs, extra_data) } } + + fn build_slice_branch<'b>( + &'b mut self, + match_pairs: &mut Vec>, + extra_data: &mut PatternExtraData<'tcx>, + _is_suffix: bool, + place: &'b PlaceBuilder<'tcx>, + top_pattern: &Pat<'tcx>, + pattern: &[Pat<'tcx>], + min_length: u64, + ) { + let entries = self.find_const_groups(pattern); + + entries.into_iter().for_each(move |entry| { + let mut build_single = |idx| { + let subpattern = &pattern[idx as usize]; + let place = place.clone_project(ProjectionElem::ConstantIndex { + offset: idx, + min_length, + from_end: false, + }); + + MatchPairTree::for_pattern(place, subpattern, self, match_pairs, extra_data); + }; + + match entry { + Either::Right(range) if range.end - range.start > 1 => { + let subpattern = &pattern[range.start as usize..range.end as usize]; + let elem_ty = subpattern[0].ty; + + let valtree = self.simplify_const_pattern_slice_into_valtree(subpattern); + + let place = if top_pattern.ty.is_slice() { + place + .clone_project(ProjectionElem::Subslice { + from: range.start, + to: pattern.len() as u64 - range.end, + from_end: true, + }) + .to_place(self) + } else { + place + .clone_project(ProjectionElem::Subslice { + from: range.start, + to: range.end, + from_end: false, + }) + .to_place(self) + }; + + let pair = self.valtree_to_match_pair(top_pattern, valtree, place, elem_ty); + + match_pairs.push(pair); + } + Either::Right(range) => build_single(range.start), + Either::Left(idx) => build_single(idx), + } + }) + } + + fn valtree_to_match_pair( + &mut self, + source_pattern: &Pat<'tcx>, + valtree: ty::ValTree<'tcx>, + place: Place<'tcx>, + elem_ty: Ty<'tcx>, + ) -> MatchPairTree<'tcx> { + let tcx = self.tcx; + let n = match &*valtree { + ty::ValTreeKind::Leaf(_) => unreachable!(), + ty::ValTreeKind::Branch(children) => children.len() as u64, + }; + + let ty = Ty::new_array(tcx, elem_ty, n); + let value = ty::Value { ty, valtree }; + + MatchPairTree { + place: Some(place), + test_case: TestCase::Constant { value }, + subpairs: vec![], + pattern_ty: ty, + pattern_span: source_pattern.span, + } + } + + fn find_const_groups(&self, pattern: &[Pat<'tcx>]) -> Vec>> { + let mut entries = Vec::new(); + let mut current_seq_start = None; + + for (idx, pat) in pattern.iter().enumerate() { + if self.is_constant_pattern(pat) { + if current_seq_start.is_none() { + current_seq_start = Some(idx as u64); + } else { + continue; + } + } else { + if let Some(start) = current_seq_start { + entries.push(Either::Right(start..idx as u64)); + current_seq_start = None; + } + entries.push(Either::Left(idx as u64)); + } + } + + if let Some(start) = current_seq_start { + entries.push(Either::Right(start..pattern.len() as u64)); + } + + entries + } + + fn is_constant_pattern(&self, pat: &Pat<'tcx>) -> bool { + if let PatKind::Constant { value } = pat.kind + && let ty::ValTreeKind::Leaf(_) = &*value.valtree + { + true + } else { + false + } + } + + fn extract_leaf(&self, pat: &Pat<'tcx>) -> ty::ValTree<'tcx> { + if let PatKind::Constant { value } = pat.kind + && matches!(&*value.valtree, ty::ValTreeKind::Leaf(_)) + { + value.valtree + } else { + bug!("expected constant pattern, got {:?}", pat) + } + } + + fn simplify_const_pattern_slice_into_valtree( + &self, + subslice: &[Pat<'tcx>], + ) -> ty::ValTree<'tcx> { + let leaves = subslice.iter().map(|p| self.extract_leaf(p)); + ty::ValTree::from_branches(self.tcx, leaves) + } } impl<'tcx> MatchPairTree<'tcx> { @@ -222,6 +384,7 @@ impl<'tcx> MatchPairTree<'tcx> { PatKind::Array { ref prefix, ref slice, ref suffix } => { cx.prefix_slice_suffix( + pattern, &mut subpairs, extra_data, &place_builder, @@ -233,6 +396,7 @@ impl<'tcx> MatchPairTree<'tcx> { } PatKind::Slice { ref prefix, ref slice, ref suffix } => { cx.prefix_slice_suffix( + pattern, &mut subpairs, extra_data, &place_builder, diff --git a/compiler/rustc_mir_build/src/builder/matches/test.rs b/compiler/rustc_mir_build/src/builder/matches/test.rs index 1b6d96e49f0c1..6025af1aaa14b 100644 --- a/compiler/rustc_mir_build/src/builder/matches/test.rs +++ b/compiler/rustc_mir_build/src/builder/matches/test.rs @@ -236,7 +236,28 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // (Interestingly this means that exhaustiveness analysis relies, for soundness, // on the `PartialEq` impl for `str` to b correct!) match *cast_ty.kind() { - ty::Ref(_, deref_ty, _) if deref_ty == self.tcx.types.str_ => {} + ty::Ref(_, deref_ty, _) if deref_ty == self.tcx.types.str_ => { + self.string_compare( + block, + success_block, + fail_block, + source_info, + expect, + Operand::Copy(place), + ); + } + ty::Array(elem_ty, n) if elem_ty.is_scalar() => { + self.scalar_array_compare( + block, + success_block, + fail_block, + source_info, + expect, + place, + elem_ty, + n, + ); + } _ => { span_bug!( source_info.span, @@ -244,14 +265,6 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { ) } }; - self.string_compare( - block, - success_block, - fail_block, - source_info, - expect, - Operand::Copy(place), - ); } else { self.compare( block, @@ -487,6 +500,109 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { ); } + fn scalar_array_compare( + &mut self, + block: BasicBlock, + success_block: BasicBlock, + fail_block: BasicBlock, + source_info: SourceInfo, + expect: Operand<'tcx>, + val: Place<'tcx>, + item_ty: Ty<'tcx>, + n: ty::Const<'tcx>, + ) { + let pat_cmp_def_id = self.tcx.require_lang_item(LangItem::PatCmp, source_info.span); + let array_ty = Ty::new_array_with_const_len(self.tcx, item_ty, n); + let slice_ty = Ty::new_slice(self.tcx, item_ty); + let func = + Operand::function_handle(self.tcx, pat_cmp_def_id, [array_ty.into()], source_info.span); + + let re_erased = self.tcx.lifetimes.re_erased; + let array_ref_ty = Ty::new_ref(self.tcx, re_erased, array_ty, ty::Mutability::Not); + //let slice_ref_ty = Ty::new_ref(self.tcx, re_erased, slice_ty, ty::Mutability::Not); + let array_ptr_ty = Ty::new_ptr(self.tcx, array_ty, ty::Mutability::Not); + let slice_ptr_ty = Ty::new_ptr(self.tcx, slice_ty, ty::Mutability::Not); + + let val_ref = match val.ty(&self.local_decls, self.tcx).ty.kind() { + ty::Array(_, _) => { + let val_ref = self.temp(array_ref_ty, source_info.span); + self.cfg.push_assign( + block, + source_info, + val_ref, + Rvalue::Ref(re_erased, BorrowKind::Shared, val), + ); + val_ref + } + ty::Slice(_) => { + let val_slice_ptr = self.temp(slice_ptr_ty, source_info.span); + self.cfg.push_assign( + block, + source_info, + val_slice_ptr, + Rvalue::RawPtr(RawPtrKind::Const, val), + ); + let val_array_ptr = self.temp(array_ptr_ty, source_info.span); + self.cfg.push_assign( + block, + source_info, + val_array_ptr, + Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(val_slice_ptr), array_ptr_ty), + ); + + let val_array = val_array_ptr.project_deeper(&[PlaceElem::Deref], self.tcx); + let val_ref = self.temp(array_ref_ty, source_info.span); + self.cfg.push_assign( + block, + source_info, + val_ref, + Rvalue::Ref(re_erased, BorrowKind::Shared, val_array), + ); + val_ref + } + _ => unreachable!(), + }; + + let expect_value = self.temp(array_ty, source_info.span); + self.cfg.push_assign(block, source_info, expect_value, Rvalue::Use(expect)); + let expect_ref = self.temp(array_ref_ty, source_info.span); + self.cfg.push_assign( + block, + source_info, + expect_ref, + Rvalue::Ref(re_erased, BorrowKind::Shared, expect_value), + ); + + let bool_ty = self.tcx.types.bool; + let eq_result = self.temp(bool_ty, source_info.span); + let eq_block = self.cfg.start_new_block(); + self.cfg.terminate( + block, + source_info, + TerminatorKind::Call { + func, + args: [ + Spanned { node: Operand::Copy(val_ref), span: DUMMY_SP }, + Spanned { node: Operand::Copy(expect_ref), span: DUMMY_SP }, + ] + .into(), + destination: eq_result, + target: Some(eq_block), + unwind: UnwindAction::Continue, + call_source: CallSource::MatchCmp, + fn_span: source_info.span, + }, + ); + self.diverge_from(block); + + // check the result + self.cfg.terminate( + eq_block, + source_info, + TerminatorKind::if_(Operand::Move(eq_result), success_block, fail_block), + ); + } + /// Given that we are performing `test` against `test_place`, this job /// sorts out what the status of `candidate` will be after the test. See /// `test_candidates` for the usage of this function. The candidate may diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 4fef65f46b1fd..f13c57fab2f0c 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -1633,6 +1633,7 @@ symbols! { partial_ord, passes, pat, + pat_cmp, pat_param, patchable_function_entry, path, diff --git a/library/core/src/cmp.rs b/library/core/src/cmp.rs index 7f369d19c3d12..62f9398b4952e 100644 --- a/library/core/src/cmp.rs +++ b/library/core/src/cmp.rs @@ -2259,3 +2259,11 @@ mod impls { } } } + +#[lang = "pat_cmp"] +#[rustc_const_stable_indirect] +#[rustc_allow_const_fn_unstable(const_cmp, const_trait_impl)] +#[inline(always)] +const fn pat_cmp(lhs: &T, rhs: &T) -> bool { + lhs.eq(rhs) +}