Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock
Original file line number Diff line number Diff line change
Expand Up @@ -4196,6 +4196,7 @@ dependencies = [
name = "rustc_mir_build"
version = "0.0.0"
dependencies = [
"either",
"itertools",
"rustc_abi",
"rustc_apfloat",
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_hir/src/lang_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ language_item_table! {

// Reborrowing related lang-items
Reborrow, sym::reborrow, reborrow, Target::Trait, GenericRequirement::Exact(0);
PatCmp, sym::pat_cmp, pat_cmp, Target::Fn, GenericRequirement::Exact(1);
}

/// The requirement imposed on the generics of a lang item
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_mir_build/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ edition = "2024"

[dependencies]
# tidy-alphabetical-start
either = "1.15.0"
itertools = "0.12"
rustc_abi = { path = "../rustc_abi" }
rustc_apfloat = "0.2.0"
Expand Down
174 changes: 169 additions & 5 deletions compiler/rustc_mir_build/src/builder/matches/match_pair.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::ops;
use std::sync::Arc;

use either::Either;
use rustc_hir::ByRef;
use rustc_middle::bug;
use rustc_middle::mir::*;
use rustc_middle::thir::*;
use rustc_middle::ty::{self, Ty, TypeVisitableExt};
Expand Down Expand Up @@ -34,6 +37,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
/// Used internally by [`MatchPairTree::for_pattern`].
fn prefix_slice_suffix(
&mut self,
top_pattern: &Pat<'tcx>,
match_pairs: &mut Vec<MatchPairTree<'tcx>>,
extra_data: &mut PatternExtraData<'tcx>,
place: &PlaceBuilder<'tcx>,
Expand All @@ -56,11 +60,30 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
((prefix.len() + suffix.len()).try_into().unwrap(), false)
};

for (idx, subpattern) in prefix.iter().enumerate() {
let elem =
ProjectionElem::ConstantIndex { offset: idx as u64, min_length, from_end: false };
let place = place.clone_project(elem);
MatchPairTree::for_pattern(place, subpattern, self, match_pairs, extra_data)
if suffix.is_empty() {
// new
if !prefix.is_empty() {
self.build_slice_branch(
match_pairs,
extra_data,
false,
place,
top_pattern,
prefix,
min_length,
);
}
} else {
// old
for (idx, subpattern) in prefix.iter().enumerate() {
let elem = ProjectionElem::ConstantIndex {
offset: idx as u64,
min_length,
from_end: false,
};
let place = place.clone_project(elem);
MatchPairTree::for_pattern(place, subpattern, self, match_pairs, extra_data)
}
}

if let Some(subslice_pat) = opt_slice {
Expand All @@ -84,6 +107,145 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
MatchPairTree::for_pattern(place, subpattern, self, match_pairs, extra_data)
}
}

fn build_slice_branch<'b>(
&'b mut self,
match_pairs: &mut Vec<MatchPairTree<'tcx>>,
extra_data: &mut PatternExtraData<'tcx>,
_is_suffix: bool,
place: &'b PlaceBuilder<'tcx>,
top_pattern: &Pat<'tcx>,
pattern: &[Pat<'tcx>],
min_length: u64,
) {
let entries = self.find_const_groups(pattern);

entries.into_iter().for_each(move |entry| {
let mut build_single = |idx| {
let subpattern = &pattern[idx as usize];
let place = place.clone_project(ProjectionElem::ConstantIndex {
offset: idx,
min_length,
from_end: false,
});

MatchPairTree::for_pattern(place, subpattern, self, match_pairs, extra_data);
};

match entry {
Either::Right(range) if range.end - range.start > 1 => {
let subpattern = &pattern[range.start as usize..range.end as usize];
let elem_ty = subpattern[0].ty;

let valtree = self.simplify_const_pattern_slice_into_valtree(subpattern);

let place = if top_pattern.ty.is_slice() {
place
.clone_project(ProjectionElem::Subslice {
from: range.start,
to: pattern.len() as u64 - range.end,
from_end: true,
})
.to_place(self)
} else {
place
.clone_project(ProjectionElem::Subslice {
from: range.start,
to: range.end,
from_end: false,
})
.to_place(self)
};

let pair = self.valtree_to_match_pair(top_pattern, valtree, place, elem_ty);

match_pairs.push(pair);
}
Either::Right(range) => build_single(range.start),
Either::Left(idx) => build_single(idx),
}
})
}

fn valtree_to_match_pair(
&mut self,
source_pattern: &Pat<'tcx>,
valtree: ty::ValTree<'tcx>,
place: Place<'tcx>,
elem_ty: Ty<'tcx>,
) -> MatchPairTree<'tcx> {
let tcx = self.tcx;
let n = match &*valtree {
ty::ValTreeKind::Leaf(_) => unreachable!(),
ty::ValTreeKind::Branch(children) => children.len() as u64,
};

let ty = Ty::new_array(tcx, elem_ty, n);
let value = ty::Value { ty, valtree };

MatchPairTree {
place: Some(place),
test_case: TestCase::Constant { value },
subpairs: vec![],
pattern_ty: ty,
pattern_span: source_pattern.span,
}
}

fn find_const_groups(&self, pattern: &[Pat<'tcx>]) -> Vec<Either<u64, ops::Range<u64>>> {
let mut entries = Vec::new();
let mut current_seq_start = None;

for (idx, pat) in pattern.iter().enumerate() {
if self.is_constant_pattern(pat) {
if current_seq_start.is_none() {
current_seq_start = Some(idx as u64);
} else {
continue;
}
} else {
if let Some(start) = current_seq_start {
entries.push(Either::Right(start..idx as u64));
current_seq_start = None;
}
entries.push(Either::Left(idx as u64));
}
}

if let Some(start) = current_seq_start {
entries.push(Either::Right(start..pattern.len() as u64));
}

entries
}

fn is_constant_pattern(&self, pat: &Pat<'tcx>) -> bool {
if let PatKind::Constant { value } = pat.kind
&& let ty::ValTreeKind::Leaf(_) = &*value.valtree
{
true
} else {
false
}
}

fn extract_leaf(&self, pat: &Pat<'tcx>) -> ty::ValTree<'tcx> {
if let PatKind::Constant { value } = pat.kind
&& matches!(&*value.valtree, ty::ValTreeKind::Leaf(_))
{
value.valtree
} else {
bug!("expected constant pattern, got {:?}", pat)
}
}

fn simplify_const_pattern_slice_into_valtree(
&self,
subslice: &[Pat<'tcx>],
) -> ty::ValTree<'tcx> {
let leaves = subslice.iter().map(|p| self.extract_leaf(p));
ty::ValTree::from_branches(self.tcx, leaves)
}
}

impl<'tcx> MatchPairTree<'tcx> {
Expand Down Expand Up @@ -222,6 +384,7 @@ impl<'tcx> MatchPairTree<'tcx> {

PatKind::Array { ref prefix, ref slice, ref suffix } => {
cx.prefix_slice_suffix(
pattern,
&mut subpairs,
extra_data,
&place_builder,
Expand All @@ -233,6 +396,7 @@ impl<'tcx> MatchPairTree<'tcx> {
}
PatKind::Slice { ref prefix, ref slice, ref suffix } => {
cx.prefix_slice_suffix(
pattern,
&mut subpairs,
extra_data,
&place_builder,
Expand Down
134 changes: 125 additions & 9 deletions compiler/rustc_mir_build/src/builder/matches/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,22 +236,35 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
// (Interestingly this means that exhaustiveness analysis relies, for soundness,
// on the `PartialEq` impl for `str` to b correct!)
match *cast_ty.kind() {
ty::Ref(_, deref_ty, _) if deref_ty == self.tcx.types.str_ => {}
ty::Ref(_, deref_ty, _) if deref_ty == self.tcx.types.str_ => {
self.string_compare(
block,
success_block,
fail_block,
source_info,
expect,
Operand::Copy(place),
);
}
ty::Array(elem_ty, n) if elem_ty.is_scalar() => {
self.scalar_array_compare(
block,
success_block,
fail_block,
source_info,
expect,
place,
elem_ty,
n,
);
}
_ => {
span_bug!(
source_info.span,
"invalid type for non-scalar compare: {cast_ty}"
)
}
};
self.string_compare(
block,
success_block,
fail_block,
source_info,
expect,
Operand::Copy(place),
);
} else {
self.compare(
block,
Expand Down Expand Up @@ -487,6 +500,109 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
);
}

fn scalar_array_compare(
&mut self,
block: BasicBlock,
success_block: BasicBlock,
fail_block: BasicBlock,
source_info: SourceInfo,
expect: Operand<'tcx>,
val: Place<'tcx>,
item_ty: Ty<'tcx>,
n: ty::Const<'tcx>,
) {
let pat_cmp_def_id = self.tcx.require_lang_item(LangItem::PatCmp, source_info.span);
let array_ty = Ty::new_array_with_const_len(self.tcx, item_ty, n);
let slice_ty = Ty::new_slice(self.tcx, item_ty);
let func =
Operand::function_handle(self.tcx, pat_cmp_def_id, [array_ty.into()], source_info.span);

let re_erased = self.tcx.lifetimes.re_erased;
let array_ref_ty = Ty::new_ref(self.tcx, re_erased, array_ty, ty::Mutability::Not);
//let slice_ref_ty = Ty::new_ref(self.tcx, re_erased, slice_ty, ty::Mutability::Not);
let array_ptr_ty = Ty::new_ptr(self.tcx, array_ty, ty::Mutability::Not);
let slice_ptr_ty = Ty::new_ptr(self.tcx, slice_ty, ty::Mutability::Not);

let val_ref = match val.ty(&self.local_decls, self.tcx).ty.kind() {
ty::Array(_, _) => {
let val_ref = self.temp(array_ref_ty, source_info.span);
self.cfg.push_assign(
block,
source_info,
val_ref,
Rvalue::Ref(re_erased, BorrowKind::Shared, val),
);
val_ref
}
ty::Slice(_) => {
let val_slice_ptr = self.temp(slice_ptr_ty, source_info.span);
self.cfg.push_assign(
block,
source_info,
val_slice_ptr,
Rvalue::RawPtr(RawPtrKind::Const, val),
);
let val_array_ptr = self.temp(array_ptr_ty, source_info.span);
self.cfg.push_assign(
block,
source_info,
val_array_ptr,
Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(val_slice_ptr), array_ptr_ty),
);

let val_array = val_array_ptr.project_deeper(&[PlaceElem::Deref], self.tcx);
let val_ref = self.temp(array_ref_ty, source_info.span);
self.cfg.push_assign(
block,
source_info,
val_ref,
Rvalue::Ref(re_erased, BorrowKind::Shared, val_array),
);
val_ref
}
_ => unreachable!(),
};

let expect_value = self.temp(array_ty, source_info.span);
self.cfg.push_assign(block, source_info, expect_value, Rvalue::Use(expect));
let expect_ref = self.temp(array_ref_ty, source_info.span);
self.cfg.push_assign(
block,
source_info,
expect_ref,
Rvalue::Ref(re_erased, BorrowKind::Shared, expect_value),
);

let bool_ty = self.tcx.types.bool;
let eq_result = self.temp(bool_ty, source_info.span);
let eq_block = self.cfg.start_new_block();
self.cfg.terminate(
block,
source_info,
TerminatorKind::Call {
func,
args: [
Spanned { node: Operand::Copy(val_ref), span: DUMMY_SP },
Spanned { node: Operand::Copy(expect_ref), span: DUMMY_SP },
]
.into(),
destination: eq_result,
target: Some(eq_block),
unwind: UnwindAction::Continue,
call_source: CallSource::MatchCmp,
fn_span: source_info.span,
},
);
self.diverge_from(block);

// check the result
self.cfg.terminate(
eq_block,
source_info,
TerminatorKind::if_(Operand::Move(eq_result), success_block, fail_block),
);
}

/// Given that we are performing `test` against `test_place`, this job
/// sorts out what the status of `candidate` will be after the test. See
/// `test_candidates` for the usage of this function. The candidate may
Expand Down
Loading
Loading