@@ -2,7 +2,7 @@ use std::iter;
22
33use ast:: make;
44use either:: Either ;
5- use hir:: { HirDisplay , InFile , Local , ModuleDef , Semantics , TypeInfo } ;
5+ use hir:: { HasSource , HirDisplay , InFile , Local , ModuleDef , Semantics , TypeInfo } ;
66use ide_db:: {
77 defs:: { Definition , NameRefClass } ,
88 famous_defs:: FamousDefs ,
@@ -27,6 +27,7 @@ use syntax::{
2727
2828use crate :: {
2929 assist_context:: { AssistContext , Assists , TreeMutator } ,
30+ utils:: generate_impl_text,
3031 AssistId ,
3132} ;
3233
@@ -106,6 +107,8 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
106107 let params =
107108 body. extracted_function_params ( ctx, & container_info, locals_used. iter ( ) . copied ( ) ) ;
108109
110+ let extracted_from_trait_impl = body. extracted_from_trait_impl ( ) ;
111+
109112 let name = make_function_name ( & semantics_scope) ;
110113
111114 let fun = Function {
@@ -124,8 +127,13 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
124127
125128 builder. replace ( target_range, make_call ( ctx, & fun, old_indent) ) ;
126129
127- let fn_def = format_function ( ctx, module, & fun, old_indent, new_indent) ;
128- let insert_offset = insert_after. text_range ( ) . end ( ) ;
130+ let fn_def = match fun. self_param_adt ( ctx) {
131+ Some ( adt) if extracted_from_trait_impl => {
132+ let fn_def = format_function ( ctx, module, & fun, old_indent, new_indent + 1 ) ;
133+ generate_impl_text ( & adt, & fn_def) . replace ( "{\n \n " , "{" )
134+ }
135+ _ => format_function ( ctx, module, & fun, old_indent, new_indent) ,
136+ } ;
129137
130138 if fn_def. contains ( "ControlFlow" ) {
131139 let scope = match scope {
@@ -150,6 +158,8 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
150158 }
151159 }
152160
161+ let insert_offset = insert_after. text_range ( ) . end ( ) ;
162+
153163 match ctx. config . snippet_cap {
154164 Some ( cap) => builder. insert_snippet ( cap, insert_offset, fn_def) ,
155165 None => builder. insert ( insert_offset, fn_def) ,
@@ -381,6 +391,14 @@ impl Function {
381391 } ,
382392 }
383393 }
394+
395+ fn self_param_adt ( & self , ctx : & AssistContext ) -> Option < ast:: Adt > {
396+ let self_param = self . self_param . as_ref ( ) ?;
397+ let def = ctx. sema . to_def ( self_param) ?;
398+ let adt = def. ty ( ctx. db ( ) ) . strip_references ( ) . as_adt ( ) ?;
399+ let InFile { file_id : _, value } = adt. source ( ctx. db ( ) ) ?;
400+ Some ( value)
401+ }
384402}
385403
386404impl ParamKind {
@@ -485,6 +503,20 @@ impl FunctionBody {
485503 }
486504 }
487505
506+ fn node ( & self ) -> & SyntaxNode {
507+ match self {
508+ FunctionBody :: Expr ( e) => e. syntax ( ) ,
509+ FunctionBody :: Span { parent, .. } => parent. syntax ( ) ,
510+ }
511+ }
512+
513+ fn extracted_from_trait_impl ( & self ) -> bool {
514+ match self . node ( ) . ancestors ( ) . find_map ( ast:: Impl :: cast) {
515+ Some ( c) => return c. trait_ ( ) . is_some ( ) ,
516+ None => false ,
517+ }
518+ }
519+
488520 fn from_expr ( expr : ast:: Expr ) -> Option < Self > {
489521 match expr {
490522 ast:: Expr :: BreakExpr ( it) => it. expr ( ) . map ( Self :: Expr ) ,
@@ -1111,10 +1143,7 @@ fn either_syntax(value: &Either<ast::IdentPat, ast::SelfParam>) -> &SyntaxNode {
11111143///
11121144/// Function should be put right after returned node
11131145fn node_to_insert_after ( body : & FunctionBody , anchor : Anchor ) -> Option < SyntaxNode > {
1114- let node = match body {
1115- FunctionBody :: Expr ( e) => e. syntax ( ) ,
1116- FunctionBody :: Span { parent, .. } => parent. syntax ( ) ,
1117- } ;
1146+ let node = body. node ( ) ;
11181147 let mut ancestors = node. ancestors ( ) . peekable ( ) ;
11191148 let mut last_ancestor = None ;
11201149 while let Some ( next_ancestor) = ancestors. next ( ) {
@@ -1126,9 +1155,8 @@ fn node_to_insert_after(body: &FunctionBody, anchor: Anchor) -> Option<SyntaxNod
11261155 break ;
11271156 }
11281157 }
1129- SyntaxKind :: ASSOC_ITEM_LIST if !matches ! ( anchor, Anchor :: Method ) => {
1130- continue ;
1131- }
1158+ SyntaxKind :: ASSOC_ITEM_LIST if !matches ! ( anchor, Anchor :: Method ) => continue ,
1159+ SyntaxKind :: ASSOC_ITEM_LIST if body. extracted_from_trait_impl ( ) => continue ,
11321160 SyntaxKind :: ASSOC_ITEM_LIST => {
11331161 if ancestors. peek ( ) . map ( SyntaxNode :: kind) == Some ( SyntaxKind :: IMPL ) {
11341162 break ;
@@ -4777,6 +4805,43 @@ fn fun_name() {
47774805fn $0fun_name2() {
47784806 let x = 0;
47794807}
4808+ "# ,
4809+ ) ;
4810+ }
4811+
4812+ #[ test]
4813+ fn extract_method_from_trait_impl ( ) {
4814+ check_assist (
4815+ extract_function,
4816+ r#"
4817+ struct Struct(i32);
4818+ trait Trait {
4819+ fn bar(&self) -> i32;
4820+ }
4821+
4822+ impl Trait for Struct {
4823+ fn bar(&self) -> i32 {
4824+ $0self.0 + 2$0
4825+ }
4826+ }
4827+ "# ,
4828+ r#"
4829+ struct Struct(i32);
4830+ trait Trait {
4831+ fn bar(&self) -> i32;
4832+ }
4833+
4834+ impl Trait for Struct {
4835+ fn bar(&self) -> i32 {
4836+ self.fun_name()
4837+ }
4838+ }
4839+
4840+ impl Struct {
4841+ fn $0fun_name(&self) -> i32 {
4842+ self.0 + 2
4843+ }
4844+ }
47804845"# ,
47814846 ) ;
47824847 }
0 commit comments