@@ -2,13 +2,17 @@ use hir::Semantics;
22use ide_db:: {
33 base_db:: { FileId , FileRange } ,
44 defs:: Definition ,
5- search:: SearchScope ,
5+ search:: { SearchScope , UsageSearchResult } ,
66 RootDatabase ,
77} ;
88use syntax:: {
9- ast:: { self , make:: impl_trait_type, HasGenericParams , HasName , HasTypeBounds } ,
10- ted, AstNode ,
9+ ast:: {
10+ self , make:: impl_trait_type, HasGenericParams , HasName , HasTypeBounds , Name , NameLike ,
11+ PathType ,
12+ } ,
13+ match_ast, ted, AstNode ,
1114} ;
15+ use text_edit:: TextRange ;
1216
1317use crate :: { AssistContext , AssistId , AssistKind , Assists } ;
1418
@@ -36,87 +40,131 @@ pub(crate) fn replace_named_generic_with_impl(
3640 let type_bound_list = type_param. type_bound_list ( ) ?;
3741
3842 let fn_ = type_param. syntax ( ) . ancestors ( ) . find_map ( ast:: Fn :: cast) ?;
39- let params = fn_
40- . param_list ( ) ?
41- . params ( )
42- . filter_map ( |param| {
43- // function parameter type needs to match generic type name
44- if let ast:: Type :: PathType ( path_type) = param. ty ( ) ? {
45- let left = path_type. path ( ) ?. segment ( ) ?. name_ref ( ) ?. ident_token ( ) ?. to_string ( ) ;
46- let right = type_param_name. to_string ( ) ;
47- if left == right {
48- Some ( param)
49- } else {
50- None
51- }
52- } else {
53- None
54- }
55- } )
56- . collect :: < Vec < _ > > ( ) ;
57-
58- if params. is_empty ( ) {
59- return None ;
60- }
43+ let param_list_text_range = fn_. param_list ( ) ?. syntax ( ) . text_range ( ) ;
6144
6245 let type_param_hir_def = ctx. sema . to_def ( & type_param) ?;
6346 let type_param_def = Definition :: GenericParam ( hir:: GenericParam :: TypeParam ( type_param_hir_def) ) ;
6447
65- if is_referenced_outside ( & ctx. sema , type_param_def, & fn_, ctx. file_id ( ) ) {
48+ // get all usage references for the type param
49+ let usage_refs = find_usages ( & ctx. sema , & fn_, type_param_def, ctx. file_id ( ) ) ;
50+ if usage_refs. is_empty ( ) {
6651 return None ;
6752 }
6853
54+ // All usage references need to be valid (inside the function param list)
55+ if !check_valid_usages ( & usage_refs, param_list_text_range) {
56+ return None ;
57+ }
58+
59+ let mut path_types_to_replace = Vec :: new ( ) ;
60+ for ( _a, refs) in usage_refs. iter ( ) {
61+ for usage_ref in refs {
62+ let param_node = find_path_type ( & ctx. sema , & type_param_name, & usage_ref. name ) ?;
63+ path_types_to_replace. push ( param_node) ;
64+ }
65+ }
66+
6967 let target = type_param. syntax ( ) . text_range ( ) ;
7068
7169 acc. add (
7270 AssistId ( "replace_named_generic_with_impl" , AssistKind :: RefactorRewrite ) ,
73- "Replace named generic with impl" ,
71+ "Replace named generic with impl trait " ,
7472 target,
7573 |edit| {
7674 let type_param = edit. make_mut ( type_param) ;
7775 let fn_ = edit. make_mut ( fn_) ;
7876
79- // get all params
80- let param_types = params
81- . iter ( )
82- . filter_map ( |param| match param. ty ( ) {
83- Some ( ast:: Type :: PathType ( param_type) ) => Some ( edit. make_mut ( param_type) ) ,
84- _ => None ,
85- } )
77+ let path_types_to_replace = path_types_to_replace
78+ . into_iter ( )
79+ . map ( |param| edit. make_mut ( param) )
8680 . collect :: < Vec < _ > > ( ) ;
8781
82+ // remove trait from generic param list
8883 if let Some ( generic_params) = fn_. generic_param_list ( ) {
8984 generic_params. remove_generic_param ( ast:: GenericParam :: TypeParam ( type_param) ) ;
9085 if generic_params. generic_params ( ) . count ( ) == 0 {
9186 ted:: remove ( generic_params. syntax ( ) ) ;
9287 }
9388 }
9489
95- // get type bounds in signature type: `P` -> `impl AsRef<Path>`
9690 let new_bounds = impl_trait_type ( type_bound_list) ;
97- for param_type in param_types . iter ( ) . rev ( ) {
98- ted:: replace ( param_type . syntax ( ) , new_bounds. clone_for_update ( ) . syntax ( ) ) ;
91+ for path_type in path_types_to_replace . iter ( ) . rev ( ) {
92+ ted:: replace ( path_type . syntax ( ) , new_bounds. clone_for_update ( ) . syntax ( ) ) ;
9993 }
10094 } ,
10195 )
10296}
10397
104- fn is_referenced_outside (
98+ fn find_path_type (
99+ sema : & Semantics < ' _ , RootDatabase > ,
100+ type_param_name : & Name ,
101+ param : & NameLike ,
102+ ) -> Option < PathType > {
103+ let path_type =
104+ sema. ancestors_with_macros ( param. syntax ( ) . clone ( ) ) . find_map ( ast:: PathType :: cast) ?;
105+
106+ // Ignore any path types that look like `P::Assoc`
107+ if path_type. path ( ) ?. as_single_name_ref ( ) ?. text ( ) != type_param_name. text ( ) {
108+ return None ;
109+ }
110+
111+ let ancestors = sema. ancestors_with_macros ( path_type. syntax ( ) . clone ( ) ) ;
112+
113+ let mut in_generic_arg_list = false ;
114+ let mut is_associated_type = false ;
115+
116+ // walking the ancestors checks them in a heuristic way until the `Fn` node is reached.
117+ for ancestor in ancestors {
118+ match_ast ! {
119+ match ancestor {
120+ ast:: PathSegment ( ps) => {
121+ match ps. kind( ) ? {
122+ ast:: PathSegmentKind :: Name ( _name_ref) => ( ) ,
123+ ast:: PathSegmentKind :: Type { .. } => return None ,
124+ _ => return None ,
125+ }
126+ } ,
127+ ast:: GenericArgList ( _) => {
128+ in_generic_arg_list = true ;
129+ } ,
130+ ast:: AssocTypeArg ( _) => {
131+ is_associated_type = true ;
132+ } ,
133+ ast:: ImplTraitType ( _) => {
134+ if in_generic_arg_list && !is_associated_type {
135+ return None ;
136+ }
137+ } ,
138+ ast:: DynTraitType ( _) => {
139+ if !is_associated_type {
140+ return None ;
141+ }
142+ } ,
143+ ast:: Fn ( _) => return Some ( path_type) ,
144+ _ => ( ) ,
145+ }
146+ }
147+ }
148+
149+ None
150+ }
151+
152+ /// Returns all usage references for the given type parameter definition.
153+ fn find_usages (
105154 sema : & Semantics < ' _ , RootDatabase > ,
106- type_param_def : Definition ,
107155 fn_ : & ast:: Fn ,
156+ type_param_def : Definition ,
108157 file_id : FileId ,
109- ) -> bool {
110- // limit search scope to function body & return type
111- let search_ranges = vec ! [
112- fn_. body( ) . map( |body| body. syntax( ) . text_range( ) ) ,
113- fn_. ret_type( ) . map( |ret_type| ret_type. syntax( ) . text_range( ) ) ,
114- ] ;
115-
116- search_ranges. into_iter ( ) . flatten ( ) . any ( |search_range| {
117- let file_range = FileRange { file_id, range : search_range } ;
118- !type_param_def. usages ( sema) . in_scope ( SearchScope :: file_range ( file_range) ) . all ( ) . is_empty ( )
119- } )
158+ ) -> UsageSearchResult {
159+ let file_range = FileRange { file_id, range : fn_. syntax ( ) . text_range ( ) } ;
160+ type_param_def. usages ( sema) . in_scope ( SearchScope :: file_range ( file_range) ) . all ( )
161+ }
162+
163+ fn check_valid_usages ( usages : & UsageSearchResult , param_list_range : TextRange ) -> bool {
164+ usages
165+ . iter ( )
166+ . flat_map ( |( _, usage_refs) | usage_refs)
167+ . all ( |usage_ref| param_list_range. contains_range ( usage_ref. range ) )
120168}
121169
122170#[ cfg( test) ]
@@ -152,6 +200,96 @@ mod tests {
152200 ) ;
153201 }
154202
203+ #[ test]
204+ fn replace_generic_trait_applies_to_generic_arguments_in_params ( ) {
205+ check_assist (
206+ replace_named_generic_with_impl,
207+ r#"
208+ fn foo<P$0: Trait>(
209+ _: P,
210+ _: Option<P>,
211+ _: Option<Option<P>>,
212+ _: impl Iterator<Item = P>,
213+ _: &dyn Iterator<Item = P>,
214+ ) {}
215+ "# ,
216+ r#"
217+ fn foo(
218+ _: impl Trait,
219+ _: Option<impl Trait>,
220+ _: Option<Option<impl Trait>>,
221+ _: impl Iterator<Item = impl Trait>,
222+ _: &dyn Iterator<Item = impl Trait>,
223+ ) {}
224+ "# ,
225+ ) ;
226+ }
227+
228+ #[ test]
229+ fn replace_generic_not_applicable_when_one_param_type_is_invalid ( ) {
230+ check_assist_not_applicable (
231+ replace_named_generic_with_impl,
232+ r#"
233+ fn foo<P$0: Trait>(
234+ _: i32,
235+ _: Option<P>,
236+ _: Option<Option<P>>,
237+ _: impl Iterator<Item = P>,
238+ _: &dyn Iterator<Item = P>,
239+ _: <P as Trait>::Assoc,
240+ ) {}
241+ "# ,
242+ ) ;
243+ }
244+
245+ #[ test]
246+ fn replace_generic_not_applicable_when_referenced_in_where_clause ( ) {
247+ check_assist_not_applicable (
248+ replace_named_generic_with_impl,
249+ r#"fn foo<P$0: Trait, I>() where I: FromRef<P> {}"# ,
250+ ) ;
251+ }
252+
253+ #[ test]
254+ fn replace_generic_not_applicable_when_used_with_type_alias ( ) {
255+ check_assist_not_applicable (
256+ replace_named_generic_with_impl,
257+ r#"fn foo<P$0: Trait>(p: <P as Trait>::Assoc) {}"# ,
258+ ) ;
259+ }
260+
261+ #[ test]
262+ fn replace_generic_not_applicable_when_used_as_argument_in_outer_trait_alias ( ) {
263+ check_assist_not_applicable (
264+ replace_named_generic_with_impl,
265+ r#"fn foo<P$0: Trait>(_: <() as OtherTrait<P>>::Assoc) {}"# ,
266+ ) ;
267+ }
268+
269+ #[ test]
270+ fn replace_generic_not_applicable_with_inner_associated_type ( ) {
271+ check_assist_not_applicable (
272+ replace_named_generic_with_impl,
273+ r#"fn foo<P$0: Trait>(_: P::Assoc) {}"# ,
274+ ) ;
275+ }
276+
277+ #[ test]
278+ fn replace_generic_not_applicable_when_passed_into_outer_impl_trait ( ) {
279+ check_assist_not_applicable (
280+ replace_named_generic_with_impl,
281+ r#"fn foo<P$0: Trait>(_: impl OtherTrait<P>) {}"# ,
282+ ) ;
283+ }
284+
285+ #[ test]
286+ fn replace_generic_not_applicable_when_used_in_passed_function_parameter ( ) {
287+ check_assist_not_applicable (
288+ replace_named_generic_with_impl,
289+ r#"fn foo<P$0: Trait>(_: &dyn Fn(P)) {}"# ,
290+ ) ;
291+ }
292+
155293 #[ test]
156294 fn replace_generic_with_multiple_generic_params ( ) {
157295 check_assist (
0 commit comments