11use crate :: deriving:: generic:: ty:: * ;
22use crate :: deriving:: generic:: * ;
33use crate :: deriving:: { path_std, pathvec_std} ;
4- use rustc_ast:: MetaItem ;
4+ use rustc_ast:: { ExprKind , ItemKind , MetaItem , PatKind } ;
55use rustc_expand:: base:: { Annotatable , ExtCtxt } ;
66use rustc_span:: symbol:: { sym, Ident } ;
77use rustc_span:: Span ;
@@ -21,6 +21,27 @@ pub fn expand_deriving_partial_ord(
2121
2222 let attrs = thin_vec ! [ cx. attr_word( sym:: inline, span) ] ;
2323
24+ // Order in which to perform matching
25+ let tag_then_data = if let Annotatable :: Item ( item) = item
26+ && let ItemKind :: Enum ( def, _) = & item. kind {
27+ let dataful: Vec < bool > = def. variants . iter ( ) . map ( |v| !v. data . fields ( ) . is_empty ( ) ) . collect ( ) ;
28+ match dataful. iter ( ) . filter ( |& & b| b) . count ( ) {
29+ // No data, placing the tag check first makes codegen simpler
30+ 0 => true ,
31+ 1 ..=2 => false ,
32+ _ => {
33+ ( 0 ..dataful. len ( ) -1 ) . any ( |i| {
34+ if dataful[ i] && let Some ( idx) = dataful[ i+1 ..] . iter ( ) . position ( |v| * v) {
35+ idx >= 2
36+ } else {
37+ false
38+ }
39+ } )
40+ }
41+ }
42+ } else {
43+ true
44+ } ;
2445 let partial_cmp_def = MethodDef {
2546 name : sym:: partial_cmp,
2647 generics : Bounds :: empty ( ) ,
@@ -30,7 +51,7 @@ pub fn expand_deriving_partial_ord(
3051 attributes : attrs,
3152 unify_fieldless_variants : true ,
3253 combine_substructure : combine_substructure ( Box :: new ( |cx, span, substr| {
33- cs_partial_cmp ( cx, span, substr)
54+ cs_partial_cmp ( cx, span, substr, tag_then_data )
3455 } ) ) ,
3556 } ;
3657
@@ -47,7 +68,12 @@ pub fn expand_deriving_partial_ord(
4768 trait_def. expand ( cx, mitem, item, push)
4869}
4970
50- pub fn cs_partial_cmp ( cx : & mut ExtCtxt < ' _ > , span : Span , substr : & Substructure < ' _ > ) -> BlockOrExpr {
71+ fn cs_partial_cmp (
72+ cx : & mut ExtCtxt < ' _ > ,
73+ span : Span ,
74+ substr : & Substructure < ' _ > ,
75+ tag_then_data : bool ,
76+ ) -> BlockOrExpr {
5177 let test_id = Ident :: new ( sym:: cmp, span) ;
5278 let equal_path = cx. path_global ( span, cx. std_path ( & [ sym:: cmp, sym:: Ordering , sym:: Equal ] ) ) ;
5379 let partial_cmp_path = cx. std_path ( & [ sym:: cmp, sym:: PartialOrd , sym:: partial_cmp] ) ;
@@ -74,12 +100,50 @@ pub fn cs_partial_cmp(cx: &mut ExtCtxt<'_>, span: Span, substr: &Substructure<'_
74100 let args = vec ! [ field. self_expr. clone( ) , other_expr. clone( ) ] ;
75101 cx. expr_call_global ( field. span , partial_cmp_path. clone ( ) , args)
76102 }
77- CsFold :: Combine ( span, expr1, expr2) => {
78- let eq_arm =
79- cx. arm ( span, cx. pat_some ( span, cx. pat_path ( span, equal_path. clone ( ) ) ) , expr1) ;
80- let neq_arm =
81- cx. arm ( span, cx. pat_ident ( span, test_id) , cx. expr_ident ( span, test_id) ) ;
82- cx. expr_match ( span, expr2, vec ! [ eq_arm, neq_arm] )
103+ CsFold :: Combine ( span, mut expr1, expr2) => {
104+ // When the item is an enum, this expands to
105+ // ```
106+ // match (expr2) {
107+ // Some(Ordering::Equal) => expr1,
108+ // cmp => cmp
109+ // }
110+ // ```
111+ // where `expr2` is `partial_cmp(self_tag, other_tag)`, and `expr1` is a `match`
112+ // against the enum variants. This means that we begin by comparing the enum tags,
113+ // before either inspecting their contents (if they match), or returning
114+ // the `cmp::Ordering` of comparing the enum tags.
115+ // ```
116+ // match partial_cmp(self_tag, other_tag) {
117+ // Some(Ordering::Equal) => match (self, other) {
118+ // (Self::A(self_0), Self::A(other_0)) => partial_cmp(self_0, other_0),
119+ // (Self::B(self_0), Self::B(other_0)) => partial_cmp(self_0, other_0),
120+ // _ => Some(Ordering::Equal)
121+ // }
122+ // cmp => cmp
123+ // }
124+ // ```
125+ // If we have any certain enum layouts, flipping this results in better codegen
126+ // ```
127+ // match (self, other) {
128+ // (Self::A(self_0), Self::A(other_0)) => partial_cmp(self_0, other_0),
129+ // _ => partial_cmp(self_tag, other_tag)
130+ // }
131+ // ```
132+ // Reference: https://github.com/rust-lang/rust/pull/103659#issuecomment-1328126354
133+
134+ if !tag_then_data
135+ && let ExprKind :: Match ( _, arms) = & mut expr1. kind
136+ && let Some ( last) = arms. last_mut ( )
137+ && let PatKind :: Wild = last. pat . kind {
138+ last. body = expr2;
139+ expr1
140+ } else {
141+ let eq_arm =
142+ cx. arm ( span, cx. pat_some ( span, cx. pat_path ( span, equal_path. clone ( ) ) ) , expr1) ;
143+ let neq_arm =
144+ cx. arm ( span, cx. pat_ident ( span, test_id) , cx. expr_ident ( span, test_id) ) ;
145+ cx. expr_match ( span, expr2, vec ! [ eq_arm, neq_arm] )
146+ }
83147 }
84148 CsFold :: Fieldless => cx. expr_some ( span, cx. expr_path ( equal_path. clone ( ) ) ) ,
85149 } ,
0 commit comments