11use std:: iter:: once;
22
3- use ide_db:: {
4- syntax_helpers:: node_ext:: { is_pattern_cond, single_let} ,
5- ty_filter:: TryEnum ,
6- } ;
3+ use either:: Either ;
4+ use hir:: { Semantics , TypeInfo } ;
5+ use ide_db:: { RootDatabase , ty_filter:: TryEnum } ;
76use syntax:: {
87 AstNode ,
9- SyntaxKind :: { FN , FOR_EXPR , LOOP_EXPR , WHILE_EXPR , WHITESPACE } ,
10- T ,
8+ SyntaxKind :: { CLOSURE_EXPR , FN , FOR_EXPR , LOOP_EXPR , WHILE_EXPR , WHITESPACE } ,
9+ SyntaxNode , T ,
1110 ast:: {
1211 self ,
1312 edit:: { AstNodeEdit , IndentLevel } ,
@@ -44,12 +43,9 @@ use crate::{
4443// }
4544// ```
4645pub ( crate ) fn convert_to_guarded_return ( acc : & mut Assists , ctx : & AssistContext < ' _ > ) -> Option < ( ) > {
47- if let Some ( let_stmt) = ctx. find_node_at_offset ( ) {
48- let_stmt_to_guarded_return ( let_stmt, acc, ctx)
49- } else if let Some ( if_expr) = ctx. find_node_at_offset ( ) {
50- if_expr_to_guarded_return ( if_expr, acc, ctx)
51- } else {
52- None
46+ match ctx. find_node_at_offset :: < Either < ast:: LetStmt , ast:: IfExpr > > ( ) ? {
47+ Either :: Left ( let_stmt) => let_stmt_to_guarded_return ( let_stmt, acc, ctx) ,
48+ Either :: Right ( if_expr) => if_expr_to_guarded_return ( if_expr, acc, ctx) ,
5349 }
5450}
5551
@@ -73,13 +69,7 @@ fn if_expr_to_guarded_return(
7369 return None ;
7470 }
7571
76- // Check if there is an IfLet that we can handle.
77- let ( if_let_pat, cond_expr) = if is_pattern_cond ( cond. clone ( ) ) {
78- let let_ = single_let ( cond) ?;
79- ( Some ( let_. pat ( ) ?) , let_. expr ( ) ?)
80- } else {
81- ( None , cond)
82- } ;
72+ let let_chains = flat_let_chain ( cond) ;
8373
8474 let then_block = if_expr. then_branch ( ) ?;
8575 let then_block = then_block. stmt_list ( ) ?;
@@ -106,11 +96,7 @@ fn if_expr_to_guarded_return(
10696
10797 let parent_container = parent_block. syntax ( ) . parent ( ) ?;
10898
109- let early_expression: ast:: Expr = match parent_container. kind ( ) {
110- WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make:: expr_continue ( None ) ,
111- FN => make:: expr_return ( None ) ,
112- _ => return None ,
113- } ;
99+ let early_expression: ast:: Expr = early_expression ( parent_container, & ctx. sema ) ?;
114100
115101 then_block. syntax ( ) . first_child_or_token ( ) . map ( |t| t. kind ( ) == T ! [ '{' ] ) ?;
116102
@@ -132,32 +118,42 @@ fn if_expr_to_guarded_return(
132118 target,
133119 |edit| {
134120 let if_indent_level = IndentLevel :: from_node ( if_expr. syntax ( ) ) ;
135- let replacement = match if_let_pat {
136- None => {
137- // If.
138- let new_expr = {
139- let then_branch =
140- make:: block_expr ( once ( make:: expr_stmt ( early_expression) . into ( ) ) , None ) ;
141- let cond = invert_boolean_expression_legacy ( cond_expr) ;
142- make:: expr_if ( cond, then_branch, None ) . indent ( if_indent_level)
143- } ;
144- new_expr. syntax ( ) . clone ( )
145- }
146- Some ( pat) => {
121+ let replacement = let_chains. into_iter ( ) . map ( |expr| {
122+ if let ast:: Expr :: LetExpr ( let_expr) = & expr
123+ && let ( Some ( pat) , Some ( expr) ) = ( let_expr. pat ( ) , let_expr. expr ( ) )
124+ {
147125 // If-let.
148126 let let_else_stmt = make:: let_else_stmt (
149127 pat,
150128 None ,
151- cond_expr ,
152- ast:: make:: tail_only_block_expr ( early_expression) ,
129+ expr ,
130+ ast:: make:: tail_only_block_expr ( early_expression. clone ( ) ) ,
153131 ) ;
154132 let let_else_stmt = let_else_stmt. indent ( if_indent_level) ;
155133 let_else_stmt. syntax ( ) . clone ( )
134+ } else {
135+ // If.
136+ let new_expr = {
137+ let then_branch = make:: block_expr (
138+ once ( make:: expr_stmt ( early_expression. clone ( ) ) . into ( ) ) ,
139+ None ,
140+ ) ;
141+ let cond = invert_boolean_expression_legacy ( expr) ;
142+ make:: expr_if ( cond, then_branch, None ) . indent ( if_indent_level)
143+ } ;
144+ new_expr. syntax ( ) . clone ( )
156145 }
157- } ;
146+ } ) ;
158147
148+ let newline = & format ! ( "\n {if_indent_level}" ) ;
159149 let then_statements = replacement
160- . children_with_tokens ( )
150+ . enumerate ( )
151+ . flat_map ( |( i, node) | {
152+ ( i != 0 )
153+ . then ( || make:: tokens:: whitespace ( newline) . into ( ) )
154+ . into_iter ( )
155+ . chain ( node. children_with_tokens ( ) )
156+ } )
161157 . chain (
162158 then_block_items
163159 . syntax ( )
@@ -201,11 +197,7 @@ fn let_stmt_to_guarded_return(
201197 let_stmt. syntax ( ) . parent ( ) ?. ancestors ( ) . find_map ( ast:: BlockExpr :: cast) ?;
202198 let parent_container = parent_block. syntax ( ) . parent ( ) ?;
203199
204- match parent_container. kind ( ) {
205- WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make:: expr_continue ( None ) ,
206- FN => make:: expr_return ( None ) ,
207- _ => return None ,
208- }
200+ early_expression ( parent_container, & ctx. sema ) ?
209201 } ;
210202
211203 acc. add (
@@ -232,6 +224,54 @@ fn let_stmt_to_guarded_return(
232224 )
233225}
234226
227+ fn early_expression (
228+ parent_container : SyntaxNode ,
229+ sema : & Semantics < ' _ , RootDatabase > ,
230+ ) -> Option < ast:: Expr > {
231+ let return_none_expr = || {
232+ let none_expr = make:: expr_path ( make:: ext:: ident_path ( "None" ) ) ;
233+ make:: expr_return ( Some ( none_expr) )
234+ } ;
235+ if let Some ( fn_) = ast:: Fn :: cast ( parent_container. clone ( ) )
236+ && let Some ( fn_def) = sema. to_def ( & fn_)
237+ && let Some ( TryEnum :: Option ) = TryEnum :: from_ty ( sema, & fn_def. ret_type ( sema. db ) )
238+ {
239+ return Some ( return_none_expr ( ) ) ;
240+ }
241+ if let Some ( body) = ast:: ClosureExpr :: cast ( parent_container. clone ( ) ) . and_then ( |it| it. body ( ) )
242+ && let Some ( ret_ty) = sema. type_of_expr ( & body) . map ( TypeInfo :: original)
243+ && let Some ( TryEnum :: Option ) = TryEnum :: from_ty ( sema, & ret_ty)
244+ {
245+ return Some ( return_none_expr ( ) ) ;
246+ }
247+
248+ Some ( match parent_container. kind ( ) {
249+ WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make:: expr_continue ( None ) ,
250+ FN | CLOSURE_EXPR => make:: expr_return ( None ) ,
251+ _ => return None ,
252+ } )
253+ }
254+
255+ fn flat_let_chain ( mut expr : ast:: Expr ) -> Vec < ast:: Expr > {
256+ let mut chains = vec ! [ ] ;
257+
258+ while let ast:: Expr :: BinExpr ( bin_expr) = & expr
259+ && bin_expr. op_kind ( ) == Some ( ast:: BinaryOp :: LogicOp ( ast:: LogicOp :: And ) )
260+ && let ( Some ( lhs) , Some ( rhs) ) = ( bin_expr. lhs ( ) , bin_expr. rhs ( ) )
261+ {
262+ if let Some ( last) = chains. pop_if ( |last| !matches ! ( last, ast:: Expr :: LetExpr ( _) ) ) {
263+ chains. push ( make:: expr_bin_op ( rhs, ast:: BinaryOp :: LogicOp ( ast:: LogicOp :: And ) , last) ) ;
264+ } else {
265+ chains. push ( rhs) ;
266+ }
267+ expr = lhs;
268+ }
269+
270+ chains. push ( expr) ;
271+ chains. reverse ( ) ;
272+ chains
273+ }
274+
235275#[ cfg( test) ]
236276mod tests {
237277 use crate :: tests:: { check_assist, check_assist_not_applicable} ;
@@ -268,6 +308,71 @@ fn main() {
268308 ) ;
269309 }
270310
311+ #[ test]
312+ fn convert_inside_fn_return_option ( ) {
313+ check_assist (
314+ convert_to_guarded_return,
315+ r#"
316+ //- minicore: option
317+ fn ret_option() -> Option<()> {
318+ bar();
319+ if$0 true {
320+ foo();
321+
322+ // comment
323+ bar();
324+ }
325+ }
326+ "# ,
327+ r#"
328+ fn ret_option() -> Option<()> {
329+ bar();
330+ if false {
331+ return None;
332+ }
333+ foo();
334+
335+ // comment
336+ bar();
337+ }
338+ "# ,
339+ ) ;
340+ }
341+
342+ #[ test]
343+ fn convert_inside_closure ( ) {
344+ check_assist (
345+ convert_to_guarded_return,
346+ r#"
347+ fn main() {
348+ let _f = || {
349+ bar();
350+ if$0 true {
351+ foo();
352+
353+ // comment
354+ bar();
355+ }
356+ }
357+ }
358+ "# ,
359+ r#"
360+ fn main() {
361+ let _f = || {
362+ bar();
363+ if false {
364+ return;
365+ }
366+ foo();
367+
368+ // comment
369+ bar();
370+ }
371+ }
372+ "# ,
373+ ) ;
374+ }
375+
271376 #[ test]
272377 fn convert_let_inside_fn ( ) {
273378 check_assist (
@@ -316,6 +421,82 @@ fn main() {
316421 ) ;
317422 }
318423
424+ #[ test]
425+ fn convert_if_let_result_inside_let ( ) {
426+ check_assist (
427+ convert_to_guarded_return,
428+ r#"
429+ fn main() {
430+ let _x = loop {
431+ if$0 let Ok(x) = Err(92) {
432+ foo(x);
433+ }
434+ };
435+ }
436+ "# ,
437+ r#"
438+ fn main() {
439+ let _x = loop {
440+ let Ok(x) = Err(92) else { continue };
441+ foo(x);
442+ };
443+ }
444+ "# ,
445+ ) ;
446+ }
447+
448+ #[ test]
449+ fn convert_if_let_chain_result ( ) {
450+ check_assist (
451+ convert_to_guarded_return,
452+ r#"
453+ fn main() {
454+ if$0 let Ok(x) = Err(92)
455+ && x < 30
456+ && let Some(y) = Some(8)
457+ {
458+ foo(x, y);
459+ }
460+ }
461+ "# ,
462+ r#"
463+ fn main() {
464+ let Ok(x) = Err(92) else { return };
465+ if x >= 30 {
466+ return;
467+ }
468+ let Some(y) = Some(8) else { return };
469+ foo(x, y);
470+ }
471+ "# ,
472+ ) ;
473+
474+ check_assist (
475+ convert_to_guarded_return,
476+ r#"
477+ fn main() {
478+ if$0 let Ok(x) = Err(92)
479+ && x < 30
480+ && y < 20
481+ && let Some(y) = Some(8)
482+ {
483+ foo(x, y);
484+ }
485+ }
486+ "# ,
487+ r#"
488+ fn main() {
489+ let Ok(x) = Err(92) else { return };
490+ if !(x < 30 && y < 20) {
491+ return;
492+ }
493+ let Some(y) = Some(8) else { return };
494+ foo(x, y);
495+ }
496+ "# ,
497+ ) ;
498+ }
499+
319500 #[ test]
320501 fn convert_let_ok_inside_fn ( ) {
321502 check_assist (
@@ -560,6 +741,32 @@ fn main() {
560741 ) ;
561742 }
562743
744+ #[ test]
745+ fn convert_let_stmt_inside_fn_return_option ( ) {
746+ check_assist (
747+ convert_to_guarded_return,
748+ r#"
749+ //- minicore: option
750+ fn foo() -> Option<i32> {
751+ None
752+ }
753+
754+ fn ret_option() -> Option<i32> {
755+ let x$0 = foo();
756+ }
757+ "# ,
758+ r#"
759+ fn foo() -> Option<i32> {
760+ None
761+ }
762+
763+ fn ret_option() -> Option<i32> {
764+ let Some(x) = foo() else { return None };
765+ }
766+ "# ,
767+ ) ;
768+ }
769+
563770 #[ test]
564771 fn convert_let_stmt_inside_loop ( ) {
565772 check_assist (
0 commit comments