@@ -16,7 +16,7 @@ use syntax::{
1616 edit_in_place:: { AttrsOwnerEdit , Indent } ,
1717 make, HasName ,
1818 } ,
19- ted , AstNode , NodeOrToken , SyntaxKind , SyntaxNode , T ,
19+ AstNode , NodeOrToken , SyntaxKind , SyntaxNode , T ,
2020} ;
2121use text_edit:: TextRange ;
2222
@@ -73,7 +73,7 @@ pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option
7373
7474 let usages = definition. usages ( & ctx. sema ) . all ( ) ;
7575 add_enum_def ( edit, ctx, & usages, target_node, & target_module) ;
76- replace_usages ( edit, ctx, & usages, definition, & target_module) ;
76+ replace_usages ( edit, ctx, usages, definition, & target_module) ;
7777 } ,
7878 )
7979}
@@ -192,58 +192,55 @@ fn bool_expr_to_enum_expr(expr: ast::Expr) -> ast::Expr {
192192fn replace_usages (
193193 edit : & mut SourceChangeBuilder ,
194194 ctx : & AssistContext < ' _ > ,
195- usages : & UsageSearchResult ,
195+ usages : UsageSearchResult ,
196196 target_definition : Definition ,
197197 target_module : & hir:: Module ,
198198) {
199- for ( file_id, references) in usages. iter ( ) {
200- edit. edit_file ( * file_id) ;
199+ for ( file_id, references) in usages {
200+ edit. edit_file ( file_id) ;
201201
202- let refs_with_imports =
203- augment_references_with_imports ( edit, ctx, references, target_module) ;
202+ let refs_with_imports = augment_references_with_imports ( ctx, references, target_module) ;
204203
205204 refs_with_imports. into_iter ( ) . rev ( ) . for_each (
206- |FileReferenceWithImport { range, old_name , new_name , import_data } | {
205+ |FileReferenceWithImport { range, name , import_data } | {
207206 // replace the usages in patterns and expressions
208- if let Some ( ident_pat) = old_name. syntax ( ) . ancestors ( ) . find_map ( ast:: IdentPat :: cast)
209- {
207+ if let Some ( ident_pat) = name. syntax ( ) . ancestors ( ) . find_map ( ast:: IdentPat :: cast) {
210208 cov_mark:: hit!( replaces_record_pat_shorthand) ;
211209
212210 let definition = ctx. sema . to_def ( & ident_pat) . map ( Definition :: Local ) ;
213211 if let Some ( def) = definition {
214212 replace_usages (
215213 edit,
216214 ctx,
217- & def. usages ( & ctx. sema ) . all ( ) ,
215+ def. usages ( & ctx. sema ) . all ( ) ,
218216 target_definition,
219217 target_module,
220218 )
221219 }
222- } else if let Some ( initializer) = find_assignment_usage ( & new_name ) {
220+ } else if let Some ( initializer) = find_assignment_usage ( & name ) {
223221 cov_mark:: hit!( replaces_assignment) ;
224222
225223 replace_bool_expr ( edit, initializer) ;
226- } else if let Some ( ( prefix_expr, inner_expr) ) = find_negated_usage ( & new_name ) {
224+ } else if let Some ( ( prefix_expr, inner_expr) ) = find_negated_usage ( & name ) {
227225 cov_mark:: hit!( replaces_negation) ;
228226
229227 edit. replace (
230228 prefix_expr. syntax ( ) . text_range ( ) ,
231229 format ! ( "{} == Bool::False" , inner_expr) ,
232230 ) ;
233- } else if let Some ( ( record_field, initializer) ) = old_name
231+ } else if let Some ( ( record_field, initializer) ) = name
234232 . as_name_ref ( )
235233 . and_then ( ast:: RecordExprField :: for_field_name)
236234 . and_then ( |record_field| ctx. sema . resolve_record_field ( & record_field) )
237235 . and_then ( |( got_field, _, _) | {
238- find_record_expr_usage ( & new_name , got_field, target_definition)
236+ find_record_expr_usage ( & name , got_field, target_definition)
239237 } )
240238 {
241239 cov_mark:: hit!( replaces_record_expr) ;
242240
243- let record_field = edit. make_mut ( record_field) ;
244241 let enum_expr = bool_expr_to_enum_expr ( initializer) ;
245- record_field . replace_expr ( enum_expr) ;
246- } else if let Some ( pat) = find_record_pat_field_usage ( & old_name ) {
242+ replace_record_field_expr ( edit , record_field , enum_expr) ;
243+ } else if let Some ( pat) = find_record_pat_field_usage ( & name ) {
247244 match pat {
248245 ast:: Pat :: IdentPat ( ident_pat) => {
249246 cov_mark:: hit!( replaces_record_pat) ;
@@ -253,7 +250,7 @@ fn replace_usages(
253250 replace_usages (
254251 edit,
255252 ctx,
256- & def. usages ( & ctx. sema ) . all ( ) ,
253+ def. usages ( & ctx. sema ) . all ( ) ,
257254 target_definition,
258255 target_module,
259256 )
@@ -270,79 +267,94 @@ fn replace_usages(
270267 }
271268 _ => ( ) ,
272269 }
273- } else if let Some ( ( ty_annotation, initializer) ) = find_assoc_const_usage ( & new_name)
274- {
270+ } else if let Some ( ( ty_annotation, initializer) ) = find_assoc_const_usage ( & name) {
275271 edit. replace ( ty_annotation. syntax ( ) . text_range ( ) , "Bool" ) ;
276272 replace_bool_expr ( edit, initializer) ;
277- } else if let Some ( receiver) = find_method_call_expr_usage ( & new_name ) {
273+ } else if let Some ( receiver) = find_method_call_expr_usage ( & name ) {
278274 edit. replace (
279275 receiver. syntax ( ) . text_range ( ) ,
280276 format ! ( "({} == Bool::True)" , receiver) ,
281277 ) ;
282- } else if new_name . syntax ( ) . ancestors ( ) . find_map ( ast:: UseTree :: cast) . is_none ( ) {
278+ } else if name . syntax ( ) . ancestors ( ) . find_map ( ast:: UseTree :: cast) . is_none ( ) {
283279 // for any other usage in an expression, replace it with a check that it is the true variant
284- if let Some ( ( record_field, expr) ) = new_name
285- . as_name_ref ( )
286- . and_then ( ast:: RecordExprField :: for_field_name)
287- . and_then ( |record_field| {
288- record_field. expr ( ) . map ( |expr| ( record_field, expr) )
289- } )
280+ if let Some ( ( record_field, expr) ) =
281+ name. as_name_ref ( ) . and_then ( ast:: RecordExprField :: for_field_name) . and_then (
282+ |record_field| record_field. expr ( ) . map ( |expr| ( record_field, expr) ) ,
283+ )
290284 {
291- record_field. replace_expr (
285+ replace_record_field_expr (
286+ edit,
287+ record_field,
292288 make:: expr_bin_op (
293289 expr,
294290 ast:: BinaryOp :: CmpOp ( ast:: CmpOp :: Eq { negated : false } ) ,
295291 make:: expr_path ( make:: path_from_text ( "Bool::True" ) ) ,
296- )
297- . clone_for_update ( ) ,
292+ ) ,
298293 ) ;
299294 } else {
300- edit. replace ( range, format ! ( "{} == Bool::True" , new_name . text( ) ) ) ;
295+ edit. replace ( range, format ! ( "{} == Bool::True" , name . text( ) ) ) ;
301296 }
302297 }
303298
304299 // add imports across modules where needed
305300 if let Some ( ( import_scope, path) ) = import_data {
306- insert_use ( & import_scope, path, & ctx. config . insert_use ) ;
301+ let scope = match import_scope. clone ( ) {
302+ ImportScope :: File ( it) => ImportScope :: File ( edit. make_mut ( it) ) ,
303+ ImportScope :: Module ( it) => ImportScope :: Module ( edit. make_mut ( it) ) ,
304+ ImportScope :: Block ( it) => ImportScope :: Block ( edit. make_mut ( it) ) ,
305+ } ;
306+ insert_use ( & scope, path, & ctx. config . insert_use ) ;
307307 }
308308 } ,
309309 )
310310 }
311311}
312312
313+ /// Replaces the record expression, handling field shorthands.
314+ fn replace_record_field_expr (
315+ edit : & mut SourceChangeBuilder ,
316+ record_field : ast:: RecordExprField ,
317+ initializer : ast:: Expr ,
318+ ) {
319+ if let Some ( ast:: Expr :: PathExpr ( path_expr) ) = record_field. expr ( ) {
320+ // replace field shorthand
321+ edit. insert (
322+ path_expr. syntax ( ) . text_range ( ) . end ( ) ,
323+ format ! ( ": {}" , initializer. syntax( ) . text( ) ) ,
324+ )
325+ } else if let Some ( expr) = record_field. expr ( ) {
326+ // just replace expr
327+ edit. replace_ast ( expr, initializer) ;
328+ }
329+ }
330+
313331struct FileReferenceWithImport {
314332 range : TextRange ,
315- old_name : ast:: NameLike ,
316- new_name : ast:: NameLike ,
333+ name : ast:: NameLike ,
317334 import_data : Option < ( ImportScope , ast:: Path ) > ,
318335}
319336
320337fn augment_references_with_imports (
321- edit : & mut SourceChangeBuilder ,
322338 ctx : & AssistContext < ' _ > ,
323- references : & [ FileReference ] ,
339+ references : Vec < FileReference > ,
324340 target_module : & hir:: Module ,
325341) -> Vec < FileReferenceWithImport > {
326342 let mut visited_modules = FxHashSet :: default ( ) ;
327343
328344 references
329- . iter ( )
345+ . into_iter ( )
330346 . filter_map ( |FileReference { range, name, .. } | {
331347 let name = name. clone ( ) . into_name_like ( ) ?;
332- ctx. sema . scope ( name. syntax ( ) ) . map ( |scope| ( * range, name, scope. module ( ) ) )
348+ ctx. sema . scope ( name. syntax ( ) ) . map ( |scope| ( range, name, scope. module ( ) ) )
333349 } )
334350 . map ( |( range, name, ref_module) | {
335- let old_name = name. clone ( ) ;
336- let new_name = edit. make_mut ( name. clone ( ) ) ;
337-
338351 // if the referenced module is not the same as the target one and has not been seen before, add an import
339352 let import_data = if ref_module. nearest_non_block_module ( ctx. db ( ) ) != * target_module
340353 && !visited_modules. contains ( & ref_module)
341354 {
342355 visited_modules. insert ( ref_module) ;
343356
344- let import_scope =
345- ImportScope :: find_insert_use_container ( new_name. syntax ( ) , & ctx. sema ) ;
357+ let import_scope = ImportScope :: find_insert_use_container ( name. syntax ( ) , & ctx. sema ) ;
346358 let path = ref_module
347359 . find_use_path_prefixed (
348360 ctx. sema . db ,
@@ -360,7 +372,7 @@ fn augment_references_with_imports(
360372 None
361373 } ;
362374
363- FileReferenceWithImport { range, old_name , new_name , import_data }
375+ FileReferenceWithImport { range, name , import_data }
364376 } )
365377 . collect ( )
366378}
@@ -465,12 +477,9 @@ fn add_enum_def(
465477 let indent = IndentLevel :: from_node ( & insert_before) ;
466478 enum_def. reindent_to ( indent) ;
467479
468- ted:: insert_all (
469- ted:: Position :: before ( & edit. make_syntax_mut ( insert_before) ) ,
470- vec ! [
471- enum_def. syntax( ) . clone( ) . into( ) ,
472- make:: tokens:: whitespace( & format!( "\n \n {indent}" ) ) . into( ) ,
473- ] ,
480+ edit. insert (
481+ insert_before. text_range ( ) . start ( ) ,
482+ format ! ( "{}\n \n {indent}" , enum_def. syntax( ) . text( ) ) ,
474483 ) ;
475484}
476485
0 commit comments