@@ -11,7 +11,9 @@ use ide_db::{
1111 helpers:: mod_path_to_ast,
1212 imports:: insert_use:: { insert_use, ImportScope } ,
1313 search:: { FileReference , ReferenceCategory , SearchScope } ,
14- syntax_helpers:: node_ext:: { preorder_expr, walk_expr, walk_pat, walk_patterns_in_expr} ,
14+ syntax_helpers:: node_ext:: {
15+ for_each_tail_expr, preorder_expr, walk_expr, walk_pat, walk_patterns_in_expr,
16+ } ,
1517 FxIndexSet , RootDatabase ,
1618} ;
1719use itertools:: Itertools ;
@@ -78,7 +80,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
7880 } ;
7981
8082 let body = extraction_target ( & node, range) ?;
81- let container_info = body. analyze_container ( & ctx. sema ) ?;
83+ let ( container_info, contains_tail_expr ) = body. analyze_container ( & ctx. sema ) ?;
8284
8385 let ( locals_used, self_param) = body. analyze ( & ctx. sema ) ;
8486
@@ -119,6 +121,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
119121 ret_ty,
120122 body,
121123 outliving_locals,
124+ contains_tail_expr,
122125 mods : container_info,
123126 } ;
124127
@@ -245,6 +248,8 @@ struct Function {
245248 ret_ty : RetType ,
246249 body : FunctionBody ,
247250 outliving_locals : Vec < OutlivedLocal > ,
251+ /// Whether at least one of the container's tail expr is contained in the range we're extracting.
252+ contains_tail_expr : bool ,
248253 mods : ContainerInfo ,
249254}
250255
@@ -265,7 +270,7 @@ enum ParamKind {
265270 MutRef ,
266271}
267272
268- #[ derive( Debug , Eq , PartialEq ) ]
273+ #[ derive( Debug ) ]
269274enum FunType {
270275 Unit ,
271276 Single ( hir:: Type ) ,
@@ -294,7 +299,6 @@ struct ControlFlow {
294299#[ derive( Clone , Debug ) ]
295300struct ContainerInfo {
296301 is_const : bool ,
297- is_in_tail : bool ,
298302 parent_loop : Option < SyntaxNode > ,
299303 /// The function's return type, const's type etc.
300304 ret_type : Option < hir:: Type > ,
@@ -743,7 +747,10 @@ impl FunctionBody {
743747 ( res, self_param)
744748 }
745749
746- fn analyze_container ( & self , sema : & Semantics < ' _ , RootDatabase > ) -> Option < ContainerInfo > {
750+ fn analyze_container (
751+ & self ,
752+ sema : & Semantics < ' _ , RootDatabase > ,
753+ ) -> Option < ( ContainerInfo , bool ) > {
747754 let mut ancestors = self . parent ( ) ?. ancestors ( ) ;
748755 let infer_expr_opt = |expr| sema. type_of_expr ( & expr?) . map ( TypeInfo :: adjusted) ;
749756 let mut parent_loop = None ;
@@ -815,28 +822,36 @@ impl FunctionBody {
815822 }
816823 } ;
817824 } ;
818- let container_tail = match expr? {
819- ast:: Expr :: BlockExpr ( block) => block. tail_expr ( ) ,
820- expr => Some ( expr) ,
821- } ;
822- let is_in_tail =
823- container_tail. zip ( self . tail_expr ( ) ) . map_or ( false , |( container_tail, body_tail) | {
824- container_tail. syntax ( ) . text_range ( ) . contains_range ( body_tail. syntax ( ) . text_range ( ) )
825+
826+ let expr = expr?;
827+ let contains_tail_expr = if let Some ( body_tail) = self . tail_expr ( ) {
828+ let mut contains_tail_expr = false ;
829+ let tail_expr_range = body_tail. syntax ( ) . text_range ( ) ;
830+ for_each_tail_expr ( & expr, & mut |e| {
831+ if tail_expr_range. contains_range ( e. syntax ( ) . text_range ( ) ) {
832+ contains_tail_expr = true ;
833+ }
825834 } ) ;
835+ contains_tail_expr
836+ } else {
837+ false
838+ } ;
826839
827840 let parent = self . parent ( ) ?;
828841 let parents = generic_parents ( & parent) ;
829842 let generic_param_lists = parents. iter ( ) . filter_map ( |it| it. generic_param_list ( ) ) . collect ( ) ;
830843 let where_clauses = parents. iter ( ) . filter_map ( |it| it. where_clause ( ) ) . collect ( ) ;
831844
832- Some ( ContainerInfo {
833- is_in_tail,
834- is_const,
835- parent_loop,
836- ret_type : ty,
837- generic_param_lists,
838- where_clauses,
839- } )
845+ Some ( (
846+ ContainerInfo {
847+ is_const,
848+ parent_loop,
849+ ret_type : ty,
850+ generic_param_lists,
851+ where_clauses,
852+ } ,
853+ contains_tail_expr,
854+ ) )
840855 }
841856
842857 fn return_ty ( & self , ctx : & AssistContext < ' _ > ) -> Option < RetType > {
@@ -1368,7 +1383,7 @@ impl FlowHandler {
13681383 None => FlowHandler :: None ,
13691384 Some ( flow_kind) => {
13701385 let action = flow_kind. clone ( ) ;
1371- if * ret_ty == FunType :: Unit {
1386+ if let FunType :: Unit = ret_ty {
13721387 match flow_kind {
13731388 FlowKind :: Return ( None )
13741389 | FlowKind :: Break ( _, None )
@@ -1633,7 +1648,7 @@ impl Function {
16331648
16341649 fn make_ret_ty ( & self , ctx : & AssistContext < ' _ > , module : hir:: Module ) -> Option < ast:: RetType > {
16351650 let fun_ty = self . return_type ( ctx) ;
1636- let handler = if self . mods . is_in_tail {
1651+ let handler = if self . contains_tail_expr {
16371652 FlowHandler :: None
16381653 } else {
16391654 FlowHandler :: from_ret_ty ( self , & fun_ty)
@@ -1707,7 +1722,7 @@ fn make_body(
17071722 fun : & Function ,
17081723) -> ast:: BlockExpr {
17091724 let ret_ty = fun. return_type ( ctx) ;
1710- let handler = if fun. mods . is_in_tail {
1725+ let handler = if fun. contains_tail_expr {
17111726 FlowHandler :: None
17121727 } else {
17131728 FlowHandler :: from_ret_ty ( fun, & ret_ty)
@@ -1946,7 +1961,7 @@ fn update_external_control_flow(handler: &FlowHandler, syntax: &SyntaxNode) {
19461961 if nested_scope. is_none ( ) {
19471962 if let Some ( expr) = ast:: Expr :: cast ( e. clone ( ) ) {
19481963 match expr {
1949- ast:: Expr :: ReturnExpr ( return_expr) if nested_scope . is_none ( ) => {
1964+ ast:: Expr :: ReturnExpr ( return_expr) => {
19501965 let expr = return_expr. expr ( ) ;
19511966 if let Some ( replacement) = make_rewritten_flow ( handler, expr) {
19521967 ted:: replace ( return_expr. syntax ( ) , replacement. syntax ( ) )
@@ -5582,6 +5597,153 @@ impl <T, U> Struct<T, U> where T: Into<i32> + Copy, U: Debug {
55825597fn $0fun_name<T, V>(t: T, v: V) -> i32 where T: Into<i32> + Copy, V: Into<i32> {
55835598 t.into() + v.into()
55845599}
5600+ "# ,
5601+ ) ;
5602+ }
5603+
5604+ #[ test]
5605+ fn non_tail_expr_of_tail_expr_loop ( ) {
5606+ check_assist (
5607+ extract_function,
5608+ r#"
5609+ pub fn f() {
5610+ loop {
5611+ $0if true {
5612+ continue;
5613+ }$0
5614+
5615+ if false {
5616+ break;
5617+ }
5618+ }
5619+ }
5620+ "# ,
5621+ r#"
5622+ pub fn f() {
5623+ loop {
5624+ if let ControlFlow::Break(_) = fun_name() {
5625+ continue;
5626+ }
5627+
5628+ if false {
5629+ break;
5630+ }
5631+ }
5632+ }
5633+
5634+ fn $0fun_name() -> ControlFlow<()> {
5635+ if true {
5636+ return ControlFlow::Break(());
5637+ }
5638+ ControlFlow::Continue(())
5639+ }
5640+ "# ,
5641+ ) ;
5642+ }
5643+
5644+ #[ test]
5645+ fn non_tail_expr_of_tail_if_block ( ) {
5646+ // FIXME: double semicolon
5647+ check_assist (
5648+ extract_function,
5649+ r#"
5650+ //- minicore: option, try
5651+ impl<T> core::ops::Try for Option<T> {
5652+ type Output = T;
5653+ type Residual = Option<!>;
5654+ }
5655+ impl<T> core::ops::FromResidual for Option<T> {}
5656+
5657+ fn f() -> Option<()> {
5658+ if true {
5659+ let a = $0if true {
5660+ Some(())?
5661+ } else {
5662+ ()
5663+ }$0;
5664+ Some(a)
5665+ } else {
5666+ None
5667+ }
5668+ }
5669+ "# ,
5670+ r#"
5671+ impl<T> core::ops::Try for Option<T> {
5672+ type Output = T;
5673+ type Residual = Option<!>;
5674+ }
5675+ impl<T> core::ops::FromResidual for Option<T> {}
5676+
5677+ fn f() -> Option<()> {
5678+ if true {
5679+ let a = fun_name()?;;
5680+ Some(a)
5681+ } else {
5682+ None
5683+ }
5684+ }
5685+
5686+ fn $0fun_name() -> Option<()> {
5687+ Some(if true {
5688+ Some(())?
5689+ } else {
5690+ ()
5691+ })
5692+ }
5693+ "# ,
5694+ ) ;
5695+ }
5696+
5697+ #[ test]
5698+ fn tail_expr_of_tail_block_nested ( ) {
5699+ check_assist (
5700+ extract_function,
5701+ r#"
5702+ //- minicore: option, try
5703+ impl<T> core::ops::Try for Option<T> {
5704+ type Output = T;
5705+ type Residual = Option<!>;
5706+ }
5707+ impl<T> core::ops::FromResidual for Option<T> {}
5708+
5709+ fn f() -> Option<()> {
5710+ if true {
5711+ $0{
5712+ let a = if true {
5713+ Some(())?
5714+ } else {
5715+ ()
5716+ };
5717+ Some(a)
5718+ }$0
5719+ } else {
5720+ None
5721+ }
5722+ }
5723+ "# ,
5724+ r#"
5725+ impl<T> core::ops::Try for Option<T> {
5726+ type Output = T;
5727+ type Residual = Option<!>;
5728+ }
5729+ impl<T> core::ops::FromResidual for Option<T> {}
5730+
5731+ fn f() -> Option<()> {
5732+ if true {
5733+ fun_name()?
5734+ } else {
5735+ None
5736+ }
5737+ }
5738+
5739+ fn $0fun_name() -> Option<()> {
5740+ let a = if true {
5741+ Some(())?
5742+ } else {
5743+ ()
5744+ };
5745+ Some(a)
5746+ }
55855747"# ,
55865748 ) ;
55875749 }
0 commit comments