@@ -4,7 +4,7 @@ use syntax::{
44 algo,
55 ast:: { self , make, AstNode } ,
66 ted:: { self , Position } ,
7- AstToken , NodeOrToken , SyntaxToken , TextRange , T ,
7+ NodeOrToken , SyntaxToken , TextRange , T ,
88} ;
99
1010use crate :: { AssistContext , AssistId , AssistKind , Assists } ;
@@ -55,39 +55,46 @@ fn attempt_get_derive(attr: ast::Attr, ident: SyntaxToken) -> WrapUnwrapOption {
5555 syntax:: Direction :: Next ,
5656 ) ?;
5757 if ( prev. kind ( ) == T ! [ , ] || prev. kind ( ) == T ! [ '(' ] )
58- && ( following. kind ( ) == T ! [ , ] || following. kind ( ) == T ! [ '( ' ] )
58+ && ( following. kind ( ) == T ! [ , ] || following. kind ( ) == T ! [ ') ' ] )
5959 {
6060 // This would be a single ident such as Debug. As no path is present
6161 if following. kind ( ) == T ! [ , ] {
6262 derive = derive. cover ( following. text_range ( ) ) ;
63+ } else if following. kind ( ) == T ! [ ')' ] && prev. kind ( ) == T ! [ , ] {
64+ derive = derive. cover ( prev. text_range ( ) ) ;
6365 }
6466
6567 Some ( WrapUnwrapOption :: WrapDerive { derive, attr : attr. clone ( ) } )
6668 } else {
69+ let mut consumed_comma = false ;
6770 // Collect the path
68-
6971 while let Some ( prev_token) = algo:: skip_trivia_token ( prev, syntax:: Direction :: Prev )
7072 {
7173 let kind = prev_token. kind ( ) ;
72- if kind == T ! [ , ] || kind == T ! [ '(' ] {
74+ if kind == T ! [ , ] {
75+ consumed_comma = true ;
76+ derive = derive. cover ( prev_token. text_range ( ) ) ;
77+ break ;
78+ } else if kind == T ! [ '(' ] {
7379 break ;
80+ } else {
81+ derive = derive. cover ( prev_token. text_range ( ) ) ;
7482 }
75- derive = derive. cover ( prev_token. text_range ( ) ) ;
7683 prev = prev_token. prev_sibling_or_token ( ) ?. into_token ( ) ?;
7784 }
7885 while let Some ( next_token) =
7986 algo:: skip_trivia_token ( following. clone ( ) , syntax:: Direction :: Next )
8087 {
8188 let kind = next_token. kind ( ) ;
82- if kind != T ! [ ')' ] {
83- // We also want to consume a following comma
84- derive = derive. cover ( next_token. text_range ( ) ) ;
89+ match kind {
90+ T ! [ , ] if !consumed_comma => {
91+ derive = derive. cover ( next_token. text_range ( ) ) ;
92+ break ;
93+ }
94+ T ! [ ')' ] | T ! [ , ] => break ,
95+ _ => derive = derive. cover ( next_token. text_range ( ) ) ,
8596 }
8697 following = next_token. next_sibling_or_token ( ) ?. into_token ( ) ?;
87-
88- if kind == T ! [ , ] || kind == T ! [ ')' ] {
89- break ;
90- }
9198 }
9299 Some ( WrapUnwrapOption :: WrapDerive { derive, attr : attr. clone ( ) } )
93100 }
@@ -103,14 +110,15 @@ fn attempt_get_derive(attr: ast::Attr, ident: SyntaxToken) -> WrapUnwrapOption {
103110}
104111pub ( crate ) fn wrap_unwrap_cfg_attr ( acc : & mut Assists , ctx : & AssistContext < ' _ > ) -> Option < ( ) > {
105112 let option = if ctx. has_empty_selection ( ) {
106- let ident = ctx. find_token_at_offset :: < ast :: Ident > ( ) . map ( |v| v . syntax ( ) . clone ( ) ) ;
113+ let ident = ctx. find_token_syntax_at_offset ( T ! [ ident ] ) ;
107114 let attr = ctx. find_node_at_offset :: < ast:: Attr > ( ) ;
108115 match ( attr, ident) {
109116 ( Some ( attr) , Some ( ident) )
110117 if attr. simple_name ( ) . map ( |v| v. eq ( "derive" ) ) . unwrap_or_default ( ) =>
111118 {
112119 Some ( attempt_get_derive ( attr. clone ( ) , ident) )
113120 }
121+
114122 ( Some ( attr) , _) => Some ( WrapUnwrapOption :: WrapAttr ( attr) ) ,
115123 _ => None ,
116124 }
@@ -156,7 +164,7 @@ fn wrap_derive(
156164 }
157165
158166 if derive_element. contains_range ( token. text_range ( ) ) {
159- if token. kind ( ) != T ! [ , ] {
167+ if token. kind ( ) != T ! [ , ] && token . kind ( ) != syntax :: SyntaxKind :: WHITESPACE {
160168 path_text. push_str ( token. text ( ) ) ;
161169 cfg_derive_tokens. push ( NodeOrToken :: Token ( token) ) ;
162170 }
@@ -527,7 +535,42 @@ mod tests {
527535 }
528536 "# ,
529537 r#"
530- #[derive(Clone, Copy)]
538+ #[derive(Clone, Copy)]
539+ #[cfg_attr($0, derive(std::fmt::Debug))]
540+ pub struct Test {
541+ test: u32,
542+ }
543+ "# ,
544+ ) ;
545+ }
546+ #[ test]
547+ fn test_derive_wrap_at_end ( ) {
548+ check_assist (
549+ wrap_unwrap_cfg_attr,
550+ r#"
551+ #[derive(std::fmt::Debug, Clone, Cop$0y)]
552+ pub struct Test {
553+ test: u32,
554+ }
555+ "# ,
556+ r#"
557+ #[derive(std::fmt::Debug, Clone)]
558+ #[cfg_attr($0, derive(Copy))]
559+ pub struct Test {
560+ test: u32,
561+ }
562+ "# ,
563+ ) ;
564+ check_assist (
565+ wrap_unwrap_cfg_attr,
566+ r#"
567+ #[derive(Clone, Copy, std::fmt::D$0ebug)]
568+ pub struct Test {
569+ test: u32,
570+ }
571+ "# ,
572+ r#"
573+ #[derive(Clone, Copy)]
531574 #[cfg_attr($0, derive(std::fmt::Debug))]
532575 pub struct Test {
533576 test: u32,
0 commit comments