1+ use hir:: Semantics ;
2+ use ide_db:: RootDatabase ;
3+ use stdx:: format_to;
14use syntax:: ast:: { self , AstNode } ;
25
36use crate :: { AssistContext , AssistId , AssistKind , Assists } ;
@@ -24,6 +27,7 @@ pub(crate) fn convert_two_arm_bool_match_to_matches_macro(
2427 acc : & mut Assists ,
2528 ctx : & AssistContext < ' _ > ,
2629) -> Option < ( ) > {
30+ use ArmBodyExpression :: * ;
2731 let match_expr = ctx. find_node_at_offset :: < ast:: MatchExpr > ( ) ?;
2832 let match_arm_list = match_expr. match_arm_list ( ) ?;
2933 let mut arms = match_arm_list. arms ( ) ;
@@ -33,21 +37,20 @@ pub(crate) fn convert_two_arm_bool_match_to_matches_macro(
3337 cov_mark:: hit!( non_two_arm_match) ;
3438 return None ;
3539 }
36- let first_arm_expr = first_arm. expr ( ) ;
37- let second_arm_expr = second_arm. expr ( ) ;
40+ let first_arm_expr = first_arm. expr ( ) ?;
41+ let second_arm_expr = second_arm. expr ( ) ?;
42+ let first_arm_body = is_bool_literal_expr ( & ctx. sema , & first_arm_expr) ?;
43+ let second_arm_body = is_bool_literal_expr ( & ctx. sema , & second_arm_expr) ?;
3844
39- let invert_matches = if is_bool_literal_expr ( & first_arm_expr, true )
40- && is_bool_literal_expr ( & second_arm_expr, false )
41- {
42- false
43- } else if is_bool_literal_expr ( & first_arm_expr, false )
44- && is_bool_literal_expr ( & second_arm_expr, true )
45- {
46- true
47- } else {
45+ if !matches ! (
46+ ( & first_arm_body, & second_arm_body) ,
47+ ( Literal ( true ) , Literal ( false ) )
48+ | ( Literal ( false ) , Literal ( true ) )
49+ | ( Expression ( _) , Literal ( false ) )
50+ ) {
4851 cov_mark:: hit!( non_invert_bool_literal_arms) ;
4952 return None ;
50- } ;
53+ }
5154
5255 let target_range = ctx. sema . original_range ( match_expr. syntax ( ) ) . range ;
5356 let expr = match_expr. expr ( ) ?;
@@ -59,28 +62,55 @@ pub(crate) fn convert_two_arm_bool_match_to_matches_macro(
5962 |builder| {
6063 let mut arm_str = String :: new ( ) ;
6164 if let Some ( pat) = & first_arm. pat ( ) {
62- arm_str += & pat. to_string ( ) ;
65+ format_to ! ( arm_str, "{ pat}" ) ;
6366 }
6467 if let Some ( guard) = & first_arm. guard ( ) {
6568 arm_str += & format ! ( " {guard}" ) ;
6669 }
67- if invert_matches {
68- builder. replace ( target_range, format ! ( "!matches!({expr}, {arm_str})" ) ) ;
69- } else {
70- builder. replace ( target_range, format ! ( "matches!({expr}, {arm_str})" ) ) ;
71- }
70+
71+ let replace_with = match ( first_arm_body, second_arm_body) {
72+ ( Literal ( true ) , Literal ( false ) ) => {
73+ format ! ( "matches!({expr}, {arm_str})" )
74+ }
75+ ( Literal ( false ) , Literal ( true ) ) => {
76+ format ! ( "!matches!({expr}, {arm_str})" )
77+ }
78+ ( Expression ( body_expr) , Literal ( false ) ) => {
79+ arm_str. push_str ( match & first_arm. guard ( ) {
80+ Some ( _) => " && " ,
81+ _ => " if " ,
82+ } ) ;
83+ format ! ( "matches!({expr}, {arm_str}{body_expr})" )
84+ }
85+ _ => {
86+ unreachable ! ( )
87+ }
88+ } ;
89+ builder. replace ( target_range, replace_with) ;
7290 } ,
7391 )
7492}
7593
76- fn is_bool_literal_expr ( expr : & Option < ast:: Expr > , expect_bool : bool ) -> bool {
77- if let Some ( ast:: Expr :: Literal ( lit) ) = expr {
94+ enum ArmBodyExpression {
95+ Literal ( bool ) ,
96+ Expression ( ast:: Expr ) ,
97+ }
98+
99+ fn is_bool_literal_expr (
100+ sema : & Semantics < ' _ , RootDatabase > ,
101+ expr : & ast:: Expr ,
102+ ) -> Option < ArmBodyExpression > {
103+ if let ast:: Expr :: Literal ( lit) = expr {
78104 if let ast:: LiteralKind :: Bool ( b) = lit. kind ( ) {
79- return b == expect_bool ;
105+ return Some ( ArmBodyExpression :: Literal ( b ) ) ;
80106 }
81107 }
82108
83- return false ;
109+ if !sema. type_of_expr ( expr) ?. original . is_bool ( ) {
110+ return None ;
111+ }
112+
113+ Some ( ArmBodyExpression :: Expression ( expr. clone ( ) ) )
84114}
85115
86116#[ cfg( test) ]
@@ -121,21 +151,6 @@ fn foo(a: Option<u32>) -> bool {
121151 ) ;
122152 }
123153
124- #[ test]
125- fn not_applicable_non_bool_literal_arms ( ) {
126- cov_mark:: check!( non_invert_bool_literal_arms) ;
127- check_assist_not_applicable (
128- convert_two_arm_bool_match_to_matches_macro,
129- r#"
130- fn foo(a: Option<u32>) -> bool {
131- match a$0 {
132- Some(val) => val == 3,
133- _ => false
134- }
135- }
136- "# ,
137- ) ;
138- }
139154 #[ test]
140155 fn not_applicable_both_false_arms ( ) {
141156 cov_mark:: check!( non_invert_bool_literal_arms) ;
@@ -291,4 +306,40 @@ fn main() {
291306 }" ,
292307 ) ;
293308 }
309+
310+ #[ test]
311+ fn convert_non_literal_bool ( ) {
312+ check_assist (
313+ convert_two_arm_bool_match_to_matches_macro,
314+ r#"
315+ fn main() {
316+ match 0$0 {
317+ a @ 0..15 => a == 0,
318+ _ => false,
319+ }
320+ }
321+ "# ,
322+ r#"
323+ fn main() {
324+ matches!(0, a @ 0..15 if a == 0)
325+ }
326+ "# ,
327+ ) ;
328+ check_assist (
329+ convert_two_arm_bool_match_to_matches_macro,
330+ r#"
331+ fn main() {
332+ match 0$0 {
333+ a @ 0..15 if thing() => a == 0,
334+ _ => false,
335+ }
336+ }
337+ "# ,
338+ r#"
339+ fn main() {
340+ matches!(0, a @ 0..15 if thing() && a == 0)
341+ }
342+ "# ,
343+ ) ;
344+ }
294345}
0 commit comments