1- use std:: iter;
1+ use std:: { hash :: BuildHasherDefault , iter} ;
22
33use ast:: make;
44use either:: Either ;
5- use hir:: { HirDisplay , Local } ;
5+ use hir:: { HirDisplay , Local , Semantics } ;
66use ide_db:: {
77 defs:: { Definition , NameRefClass } ,
88 search:: { FileReference , ReferenceAccess , SearchScope } ,
9+ RootDatabase ,
910} ;
10- use itertools :: Itertools ;
11+ use rustc_hash :: FxHasher ;
1112use stdx:: format_to;
1213use syntax:: {
1314 ast:: {
@@ -25,6 +26,8 @@ use crate::{
2526 AssistId ,
2627} ;
2728
29+ type FxIndexSet < T > = indexmap:: IndexSet < T , BuildHasherDefault < FxHasher > > ;
30+
2831// Assist: extract_function
2932//
3033// Extracts selected statements into new function.
@@ -51,7 +54,8 @@ use crate::{
5154// }
5255// ```
5356pub ( crate ) fn extract_function ( acc : & mut Assists , ctx : & AssistContext ) -> Option < ( ) > {
54- if ctx. frange . range . is_empty ( ) {
57+ let range = ctx. frange . range ;
58+ if range. is_empty ( ) {
5559 return None ;
5660 }
5761
@@ -65,11 +69,9 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
6569 syntax:: NodeOrToken :: Node ( n) => n,
6670 syntax:: NodeOrToken :: Token ( t) => t. parent ( ) ?,
6771 } ;
72+ let body = extraction_target ( & node, range) ?;
6873
69- let body = extraction_target ( & node, ctx. frange . range ) ?;
70-
71- let vars_used_in_body = vars_used_in_body ( ctx, & body) ;
72- let self_param = self_param_from_usages ( ctx, & body, & vars_used_in_body) ;
74+ let ( locals_used, has_await, self_param) = analyze_body ( & ctx. sema , & body) ;
7375
7476 let anchor = if self_param. is_some ( ) { Anchor :: Method } else { Anchor :: Freestanding } ;
7577 let insert_after = scope_for_fn_insertion ( & body, anchor) ?;
@@ -95,7 +97,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
9597 "Extract into function" ,
9698 target_range,
9799 move |builder| {
98- let params = extracted_function_params ( ctx, & body, & vars_used_in_body ) ;
100+ let params = extracted_function_params ( ctx, & body, locals_used . iter ( ) . copied ( ) ) ;
99101
100102 let fun = Function {
101103 name : "fun_name" . to_string ( ) ,
@@ -109,15 +111,10 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
109111
110112 let new_indent = IndentLevel :: from_node ( & insert_after) ;
111113 let old_indent = fun. body . indent_level ( ) ;
112- let body_contains_await = body_contains_await ( & fun. body ) ;
113114
114- builder. replace (
115- target_range,
116- format_replacement ( ctx, & fun, old_indent, body_contains_await) ,
117- ) ;
115+ builder. replace ( target_range, format_replacement ( ctx, & fun, old_indent, has_await) ) ;
118116
119- let fn_def =
120- format_function ( ctx, module, & fun, old_indent, new_indent, body_contains_await) ;
117+ let fn_def = format_function ( ctx, module, & fun, old_indent, new_indent, has_await) ;
121118 let insert_offset = insert_after. text_range ( ) . end ( ) ;
122119 match ctx. config . snippet_cap {
123120 Some ( cap) => builder. insert_snippet ( cap, insert_offset, fn_def) ,
@@ -500,15 +497,59 @@ impl FunctionBody {
500497 }
501498 }
502499
503- fn descendants ( & self ) -> impl Iterator < Item = SyntaxNode > + ' _ {
500+ fn walk_expr ( & self , cb : & mut dyn FnMut ( ast:: Expr ) ) {
501+ match self {
502+ FunctionBody :: Expr ( expr) => expr. walk ( cb) ,
503+ FunctionBody :: Span { parent, text_range } => {
504+ parent
505+ . statements ( )
506+ . filter ( |stmt| text_range. contains_range ( stmt. syntax ( ) . text_range ( ) ) )
507+ . filter_map ( |stmt| match stmt {
508+ ast:: Stmt :: ExprStmt ( expr_stmt) => expr_stmt. expr ( ) ,
509+ ast:: Stmt :: Item ( _) => None ,
510+ ast:: Stmt :: LetStmt ( stmt) => stmt. initializer ( ) ,
511+ } )
512+ . for_each ( |expr| expr. walk ( cb) ) ;
513+ if let Some ( expr) = parent
514+ . tail_expr ( )
515+ . filter ( |it| text_range. contains_range ( it. syntax ( ) . text_range ( ) ) )
516+ {
517+ expr. walk ( cb) ;
518+ }
519+ }
520+ }
521+ }
522+
523+ fn walk_pat ( & self , cb : & mut dyn FnMut ( ast:: Pat ) ) {
504524 match self {
505- FunctionBody :: Expr ( expr) => Either :: Right ( expr. syntax ( ) . descendants ( ) ) ,
506- FunctionBody :: Span { parent, text_range } => Either :: Left (
525+ FunctionBody :: Expr ( expr) => expr. walk_patterns ( cb ) ,
526+ FunctionBody :: Span { parent, text_range } => {
507527 parent
508- . syntax ( )
509- . descendants ( )
510- . filter ( move |it| text_range. contains_range ( it. text_range ( ) ) ) ,
511- ) ,
528+ . statements ( )
529+ . filter ( |stmt| text_range. contains_range ( stmt. syntax ( ) . text_range ( ) ) )
530+ . for_each ( |stmt| match stmt {
531+ ast:: Stmt :: ExprStmt ( expr_stmt) => {
532+ if let Some ( expr) = expr_stmt. expr ( ) {
533+ expr. walk_patterns ( cb)
534+ }
535+ }
536+ ast:: Stmt :: Item ( _) => ( ) ,
537+ ast:: Stmt :: LetStmt ( stmt) => {
538+ if let Some ( pat) = stmt. pat ( ) {
539+ pat. walk ( cb) ;
540+ }
541+ if let Some ( expr) = stmt. initializer ( ) {
542+ expr. walk_patterns ( cb) ;
543+ }
544+ }
545+ } ) ;
546+ if let Some ( expr) = parent
547+ . tail_expr ( )
548+ . filter ( |it| text_range. contains_range ( it. syntax ( ) . text_range ( ) ) )
549+ {
550+ expr. walk_patterns ( cb) ;
551+ }
552+ }
512553 }
513554 }
514555
@@ -622,58 +663,48 @@ fn extraction_target(node: &SyntaxNode, selection_range: TextRange) -> Option<Fu
622663 node. ancestors ( ) . find_map ( ast:: Expr :: cast) . and_then ( FunctionBody :: from_expr)
623664}
624665
625- /// list local variables that are referenced in `body`
626- fn vars_used_in_body ( ctx : & AssistContext , body : & FunctionBody ) -> Vec < Local > {
666+ /// Analyzes a function body, returning the used local variables that are referenced in it as well as
667+ /// whether it contains an await expression.
668+ fn analyze_body (
669+ sema : & Semantics < RootDatabase > ,
670+ body : & FunctionBody ,
671+ ) -> ( FxIndexSet < Local > , bool , Option < ( Local , ast:: SelfParam ) > ) {
627672 // FIXME: currently usages inside macros are not found
628- body. descendants ( )
629- . filter_map ( ast:: NameRef :: cast)
630- . filter_map ( |name_ref| NameRefClass :: classify ( & ctx. sema , & name_ref) )
631- . map ( |name_kind| match name_kind {
632- NameRefClass :: Definition ( def) => def,
633- NameRefClass :: FieldShorthand { local_ref, field_ref : _ } => {
634- Definition :: Local ( local_ref)
673+ let mut has_await = false ;
674+ let mut self_param = None ;
675+ let mut res = FxIndexSet :: default ( ) ;
676+ body. walk_expr ( & mut |expr| {
677+ has_await |= matches ! ( expr, ast:: Expr :: AwaitExpr ( _) ) ;
678+ let name_ref = match expr {
679+ ast:: Expr :: PathExpr ( path_expr) => {
680+ path_expr. path ( ) . and_then ( |it| it. as_single_name_ref ( ) )
635681 }
636- } )
637- . filter_map ( |definition| match definition {
638- Definition :: Local ( local) => Some ( local) ,
639- _ => None ,
640- } )
641- . unique ( )
642- . collect ( )
643- }
644-
645- fn body_contains_await ( body : & FunctionBody ) -> bool {
646- body. descendants ( ) . any ( |d| matches ! ( d. kind( ) , SyntaxKind :: AWAIT_EXPR ) )
647- }
648-
649- /// find `self` param, that was not defined inside `body`
650- ///
651- /// It should skip `self` params from impls inside `body`
652- fn self_param_from_usages (
653- ctx : & AssistContext ,
654- body : & FunctionBody ,
655- vars_used_in_body : & [ Local ] ,
656- ) -> Option < ( Local , ast:: SelfParam ) > {
657- let mut iter = vars_used_in_body
658- . iter ( )
659- . filter ( |var| var. is_self ( ctx. db ( ) ) )
660- . map ( |var| ( var, var. source ( ctx. db ( ) ) ) )
661- . filter ( |( _, src) | is_defined_before ( ctx, body, src) )
662- . filter_map ( |( & node, src) | match src. value {
663- Either :: Right ( it) => Some ( ( node, it) ) ,
664- Either :: Left ( _) => {
665- stdx:: never!( false , "Local::is_self returned true, but source is IdentPat" ) ;
666- None
682+ _ => return ,
683+ } ;
684+ if let Some ( name_ref) = name_ref {
685+ if let Some (
686+ NameRefClass :: Definition ( Definition :: Local ( local_ref) )
687+ | NameRefClass :: FieldShorthand { local_ref, field_ref : _ } ,
688+ ) = NameRefClass :: classify ( sema, & name_ref)
689+ {
690+ res. insert ( local_ref) ;
691+ if local_ref. is_self ( sema. db ) {
692+ match local_ref. source ( sema. db ) . value {
693+ Either :: Right ( it) => {
694+ stdx:: always!(
695+ self_param. replace( ( local_ref, it) ) . is_none( ) ,
696+ "body references two different self params"
697+ ) ;
698+ }
699+ Either :: Left ( _) => {
700+ stdx:: never!( "Local::is_self returned true, but source is IdentPat" ) ;
701+ }
702+ }
703+ }
667704 }
668- } ) ;
669-
670- let self_param = iter. next ( ) ;
671- stdx:: always!(
672- iter. next( ) . is_none( ) ,
673- "body references two different self params, both defined outside"
674- ) ;
675-
676- self_param
705+ }
706+ } ) ;
707+ ( res, has_await, self_param)
677708}
678709
679710/// find variables that should be extracted as params
@@ -682,16 +713,15 @@ fn self_param_from_usages(
682713fn extracted_function_params (
683714 ctx : & AssistContext ,
684715 body : & FunctionBody ,
685- vars_used_in_body : & [ Local ] ,
716+ locals : impl Iterator < Item = Local > ,
686717) -> Vec < Param > {
687- vars_used_in_body
688- . iter ( )
689- . filter ( |var| !var. is_self ( ctx. db ( ) ) )
690- . map ( |node| ( node, node. source ( ctx. db ( ) ) ) )
691- . filter ( |( _, src) | is_defined_before ( ctx, body, src) )
692- . filter_map ( |( & node, src) | {
718+ locals
719+ . filter ( |local| !local. is_self ( ctx. db ( ) ) )
720+ . map ( |local| ( local, local. source ( ctx. db ( ) ) ) )
721+ . filter ( |( _, src) | is_defined_outside_of_body ( ctx, body, src) )
722+ . filter_map ( |( local, src) | {
693723 if src. value . is_left ( ) {
694- Some ( node )
724+ Some ( local )
695725 } else {
696726 stdx:: never!( false , "Local::is_self returned false, but source is SelfParam" ) ;
697727 None
@@ -838,14 +868,18 @@ fn path_element_of_reference(
838868}
839869
840870/// list local variables defined inside `body`
841- fn vars_defined_in_body ( body : & FunctionBody , ctx : & AssistContext ) -> Vec < Local > {
871+ fn locals_defined_in_body ( body : & FunctionBody , ctx : & AssistContext ) -> FxIndexSet < Local > {
842872 // FIXME: this doesn't work well with macros
843873 // see https://github.com/rust-analyzer/rust-analyzer/pull/7535#discussion_r570048550
844- body. descendants ( )
845- . filter_map ( ast:: IdentPat :: cast)
846- . filter_map ( |let_stmt| ctx. sema . to_def ( & let_stmt) )
847- . unique ( )
848- . collect ( )
874+ let mut res = FxIndexSet :: default ( ) ;
875+ body. walk_pat ( & mut |pat| {
876+ if let ast:: Pat :: IdentPat ( pat) = pat {
877+ if let Some ( local) = ctx. sema . to_def ( & pat) {
878+ res. insert ( local) ;
879+ }
880+ }
881+ } ) ;
882+ res
849883}
850884
851885/// list local variables defined inside `body` that should be returned from extracted function
@@ -854,15 +888,15 @@ fn vars_defined_in_body_and_outlive(
854888 body : & FunctionBody ,
855889 parent : & SyntaxNode ,
856890) -> Vec < OutlivedLocal > {
857- let vars_defined_in_body = vars_defined_in_body ( body, ctx) ;
891+ let vars_defined_in_body = locals_defined_in_body ( body, ctx) ;
858892 vars_defined_in_body
859893 . into_iter ( )
860894 . filter_map ( |var| var_outlives_body ( ctx, body, var, parent) )
861895 . collect ( )
862896}
863897
864898/// checks if the relevant local was defined before(outside of) body
865- fn is_defined_before (
899+ fn is_defined_outside_of_body (
866900 ctx : & AssistContext ,
867901 body : & FunctionBody ,
868902 src : & hir:: InFile < Either < ast:: IdentPat , ast:: SelfParam > > ,
0 commit comments