@@ -16,11 +16,14 @@ use syntax::{
1616 edit_in_place:: { AttrsOwnerEdit , Indent } ,
1717 make, HasName ,
1818 } ,
19- ted , AstNode , NodeOrToken , SyntaxKind , SyntaxNode , T ,
19+ AstNode , NodeOrToken , SyntaxKind , SyntaxNode , T ,
2020} ;
2121use text_edit:: TextRange ;
2222
23- use crate :: assist_context:: { AssistContext , Assists } ;
23+ use crate :: {
24+ assist_context:: { AssistContext , Assists } ,
25+ utils,
26+ } ;
2427
2528// Assist: bool_to_enum
2629//
@@ -73,7 +76,7 @@ pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option
7376
7477 let usages = definition. usages ( & ctx. sema ) . all ( ) ;
7578 add_enum_def ( edit, ctx, & usages, target_node, & target_module) ;
76- replace_usages ( edit, ctx, & usages, definition, & target_module) ;
79+ replace_usages ( edit, ctx, usages, definition, & target_module) ;
7780 } ,
7881 )
7982}
@@ -169,8 +172,8 @@ fn replace_bool_expr(edit: &mut SourceChangeBuilder, expr: ast::Expr) {
169172
170173/// Converts an expression of type `bool` to one of the new enum type.
171174fn bool_expr_to_enum_expr ( expr : ast:: Expr ) -> ast:: Expr {
172- let true_expr = make:: expr_path ( make:: path_from_text ( "Bool::True" ) ) . clone_for_update ( ) ;
173- let false_expr = make:: expr_path ( make:: path_from_text ( "Bool::False" ) ) . clone_for_update ( ) ;
175+ let true_expr = make:: expr_path ( make:: path_from_text ( "Bool::True" ) ) ;
176+ let false_expr = make:: expr_path ( make:: path_from_text ( "Bool::False" ) ) ;
174177
175178 if let ast:: Expr :: Literal ( literal) = & expr {
176179 match literal. kind ( ) {
@@ -184,66 +187,62 @@ fn bool_expr_to_enum_expr(expr: ast::Expr) -> ast::Expr {
184187 make:: tail_only_block_expr ( true_expr) ,
185188 Some ( ast:: ElseBranch :: Block ( make:: tail_only_block_expr ( false_expr) ) ) ,
186189 )
187- . clone_for_update ( )
188190 }
189191}
190192
191193/// Replaces all usages of the target identifier, both when read and written to.
192194fn replace_usages (
193195 edit : & mut SourceChangeBuilder ,
194196 ctx : & AssistContext < ' _ > ,
195- usages : & UsageSearchResult ,
197+ usages : UsageSearchResult ,
196198 target_definition : Definition ,
197199 target_module : & hir:: Module ,
198200) {
199- for ( file_id, references) in usages. iter ( ) {
200- edit. edit_file ( * file_id) ;
201+ for ( file_id, references) in usages {
202+ edit. edit_file ( file_id) ;
201203
202- let refs_with_imports =
203- augment_references_with_imports ( edit, ctx, references, target_module) ;
204+ let refs_with_imports = augment_references_with_imports ( ctx, references, target_module) ;
204205
205206 refs_with_imports. into_iter ( ) . rev ( ) . for_each (
206- |FileReferenceWithImport { range, old_name , new_name , import_data } | {
207+ |FileReferenceWithImport { range, name , import_data } | {
207208 // replace the usages in patterns and expressions
208- if let Some ( ident_pat) = old_name. syntax ( ) . ancestors ( ) . find_map ( ast:: IdentPat :: cast)
209- {
209+ if let Some ( ident_pat) = name. syntax ( ) . ancestors ( ) . find_map ( ast:: IdentPat :: cast) {
210210 cov_mark:: hit!( replaces_record_pat_shorthand) ;
211211
212212 let definition = ctx. sema . to_def ( & ident_pat) . map ( Definition :: Local ) ;
213213 if let Some ( def) = definition {
214214 replace_usages (
215215 edit,
216216 ctx,
217- & def. usages ( & ctx. sema ) . all ( ) ,
217+ def. usages ( & ctx. sema ) . all ( ) ,
218218 target_definition,
219219 target_module,
220220 )
221221 }
222- } else if let Some ( initializer) = find_assignment_usage ( & new_name ) {
222+ } else if let Some ( initializer) = find_assignment_usage ( & name ) {
223223 cov_mark:: hit!( replaces_assignment) ;
224224
225225 replace_bool_expr ( edit, initializer) ;
226- } else if let Some ( ( prefix_expr, inner_expr) ) = find_negated_usage ( & new_name ) {
226+ } else if let Some ( ( prefix_expr, inner_expr) ) = find_negated_usage ( & name ) {
227227 cov_mark:: hit!( replaces_negation) ;
228228
229229 edit. replace (
230230 prefix_expr. syntax ( ) . text_range ( ) ,
231231 format ! ( "{} == Bool::False" , inner_expr) ,
232232 ) ;
233- } else if let Some ( ( record_field, initializer) ) = old_name
233+ } else if let Some ( ( record_field, initializer) ) = name
234234 . as_name_ref ( )
235235 . and_then ( ast:: RecordExprField :: for_field_name)
236236 . and_then ( |record_field| ctx. sema . resolve_record_field ( & record_field) )
237237 . and_then ( |( got_field, _, _) | {
238- find_record_expr_usage ( & new_name , got_field, target_definition)
238+ find_record_expr_usage ( & name , got_field, target_definition)
239239 } )
240240 {
241241 cov_mark:: hit!( replaces_record_expr) ;
242242
243- let record_field = edit. make_mut ( record_field) ;
244243 let enum_expr = bool_expr_to_enum_expr ( initializer) ;
245- record_field . replace_expr ( enum_expr) ;
246- } else if let Some ( pat) = find_record_pat_field_usage ( & old_name ) {
244+ utils :: replace_record_field_expr ( ctx , edit , record_field , enum_expr) ;
245+ } else if let Some ( pat) = find_record_pat_field_usage ( & name ) {
247246 match pat {
248247 ast:: Pat :: IdentPat ( ident_pat) => {
249248 cov_mark:: hit!( replaces_record_pat) ;
@@ -253,7 +252,7 @@ fn replace_usages(
253252 replace_usages (
254253 edit,
255254 ctx,
256- & def. usages ( & ctx. sema ) . all ( ) ,
255+ def. usages ( & ctx. sema ) . all ( ) ,
257256 target_definition,
258257 target_module,
259258 )
@@ -270,40 +269,44 @@ fn replace_usages(
270269 }
271270 _ => ( ) ,
272271 }
273- } else if let Some ( ( ty_annotation, initializer) ) = find_assoc_const_usage ( & new_name)
274- {
272+ } else if let Some ( ( ty_annotation, initializer) ) = find_assoc_const_usage ( & name) {
275273 edit. replace ( ty_annotation. syntax ( ) . text_range ( ) , "Bool" ) ;
276274 replace_bool_expr ( edit, initializer) ;
277- } else if let Some ( receiver) = find_method_call_expr_usage ( & new_name ) {
275+ } else if let Some ( receiver) = find_method_call_expr_usage ( & name ) {
278276 edit. replace (
279277 receiver. syntax ( ) . text_range ( ) ,
280278 format ! ( "({} == Bool::True)" , receiver) ,
281279 ) ;
282- } else if new_name . syntax ( ) . ancestors ( ) . find_map ( ast:: UseTree :: cast) . is_none ( ) {
280+ } else if name . syntax ( ) . ancestors ( ) . find_map ( ast:: UseTree :: cast) . is_none ( ) {
283281 // for any other usage in an expression, replace it with a check that it is the true variant
284- if let Some ( ( record_field, expr) ) = new_name
285- . as_name_ref ( )
286- . and_then ( ast:: RecordExprField :: for_field_name)
287- . and_then ( |record_field| {
288- record_field. expr ( ) . map ( |expr| ( record_field, expr) )
289- } )
282+ if let Some ( ( record_field, expr) ) =
283+ name. as_name_ref ( ) . and_then ( ast:: RecordExprField :: for_field_name) . and_then (
284+ |record_field| record_field. expr ( ) . map ( |expr| ( record_field, expr) ) ,
285+ )
290286 {
291- record_field. replace_expr (
287+ utils:: replace_record_field_expr (
288+ ctx,
289+ edit,
290+ record_field,
292291 make:: expr_bin_op (
293292 expr,
294293 ast:: BinaryOp :: CmpOp ( ast:: CmpOp :: Eq { negated : false } ) ,
295294 make:: expr_path ( make:: path_from_text ( "Bool::True" ) ) ,
296- )
297- . clone_for_update ( ) ,
295+ ) ,
298296 ) ;
299297 } else {
300- edit. replace ( range, format ! ( "{} == Bool::True" , new_name . text( ) ) ) ;
298+ edit. replace ( range, format ! ( "{} == Bool::True" , name . text( ) ) ) ;
301299 }
302300 }
303301
304302 // add imports across modules where needed
305303 if let Some ( ( import_scope, path) ) = import_data {
306- insert_use ( & import_scope, path, & ctx. config . insert_use ) ;
304+ let scope = match import_scope. clone ( ) {
305+ ImportScope :: File ( it) => ImportScope :: File ( edit. make_mut ( it) ) ,
306+ ImportScope :: Module ( it) => ImportScope :: Module ( edit. make_mut ( it) ) ,
307+ ImportScope :: Block ( it) => ImportScope :: Block ( edit. make_mut ( it) ) ,
308+ } ;
309+ insert_use ( & scope, path, & ctx. config . insert_use ) ;
307310 }
308311 } ,
309312 )
@@ -312,37 +315,31 @@ fn replace_usages(
312315
313316struct FileReferenceWithImport {
314317 range : TextRange ,
315- old_name : ast:: NameLike ,
316- new_name : ast:: NameLike ,
318+ name : ast:: NameLike ,
317319 import_data : Option < ( ImportScope , ast:: Path ) > ,
318320}
319321
320322fn augment_references_with_imports (
321- edit : & mut SourceChangeBuilder ,
322323 ctx : & AssistContext < ' _ > ,
323- references : & [ FileReference ] ,
324+ references : Vec < FileReference > ,
324325 target_module : & hir:: Module ,
325326) -> Vec < FileReferenceWithImport > {
326327 let mut visited_modules = FxHashSet :: default ( ) ;
327328
328329 references
329- . iter ( )
330+ . into_iter ( )
330331 . filter_map ( |FileReference { range, name, .. } | {
331332 let name = name. clone ( ) . into_name_like ( ) ?;
332- ctx. sema . scope ( name. syntax ( ) ) . map ( |scope| ( * range, name, scope. module ( ) ) )
333+ ctx. sema . scope ( name. syntax ( ) ) . map ( |scope| ( range, name, scope. module ( ) ) )
333334 } )
334335 . map ( |( range, name, ref_module) | {
335- let old_name = name. clone ( ) ;
336- let new_name = edit. make_mut ( name. clone ( ) ) ;
337-
338336 // if the referenced module is not the same as the target one and has not been seen before, add an import
339337 let import_data = if ref_module. nearest_non_block_module ( ctx. db ( ) ) != * target_module
340338 && !visited_modules. contains ( & ref_module)
341339 {
342340 visited_modules. insert ( ref_module) ;
343341
344- let import_scope =
345- ImportScope :: find_insert_use_container ( new_name. syntax ( ) , & ctx. sema ) ;
342+ let import_scope = ImportScope :: find_insert_use_container ( name. syntax ( ) , & ctx. sema ) ;
346343 let path = ref_module
347344 . find_use_path_prefixed (
348345 ctx. sema . db ,
@@ -360,7 +357,7 @@ fn augment_references_with_imports(
360357 None
361358 } ;
362359
363- FileReferenceWithImport { range, old_name , new_name , import_data }
360+ FileReferenceWithImport { range, name , import_data }
364361 } )
365362 . collect ( )
366363}
@@ -405,13 +402,12 @@ fn find_record_expr_usage(
405402 let record_field = ast:: RecordExprField :: for_field_name ( name_ref) ?;
406403 let initializer = record_field. expr ( ) ?;
407404
408- if let Definition :: Field ( expected_field ) = target_definition {
409- if got_field != expected_field {
410- return None ;
405+ match target_definition {
406+ Definition :: Field ( expected_field ) if got_field == expected_field => {
407+ Some ( ( record_field , initializer ) )
411408 }
409+ _ => None ,
412410 }
413-
414- Some ( ( record_field, initializer) )
415411}
416412
417413fn find_record_pat_field_usage ( name : & ast:: NameLike ) -> Option < ast:: Pat > {
@@ -466,12 +462,9 @@ fn add_enum_def(
466462 let indent = IndentLevel :: from_node ( & insert_before) ;
467463 enum_def. reindent_to ( indent) ;
468464
469- ted:: insert_all (
470- ted:: Position :: before ( & edit. make_syntax_mut ( insert_before) ) ,
471- vec ! [
472- enum_def. syntax( ) . clone( ) . into( ) ,
473- make:: tokens:: whitespace( & format!( "\n \n {indent}" ) ) . into( ) ,
474- ] ,
465+ edit. insert (
466+ insert_before. text_range ( ) . start ( ) ,
467+ format ! ( "{}\n \n {indent}" , enum_def. syntax( ) . text( ) ) ,
475468 ) ;
476469}
477470
@@ -800,6 +793,78 @@ fn main() {
800793 )
801794 }
802795
796+ #[ test]
797+ fn local_var_init_struct_usage ( ) {
798+ check_assist (
799+ bool_to_enum,
800+ r#"
801+ struct Foo {
802+ foo: bool,
803+ }
804+
805+ fn main() {
806+ let $0foo = true;
807+ let s = Foo { foo };
808+ }
809+ "# ,
810+ r#"
811+ struct Foo {
812+ foo: bool,
813+ }
814+
815+ #[derive(PartialEq, Eq)]
816+ enum Bool { True, False }
817+
818+ fn main() {
819+ let foo = Bool::True;
820+ let s = Foo { foo: foo == Bool::True };
821+ }
822+ "# ,
823+ )
824+ }
825+
826+ #[ test]
827+ fn local_var_init_struct_usage_in_macro ( ) {
828+ check_assist (
829+ bool_to_enum,
830+ r#"
831+ struct Struct {
832+ boolean: bool,
833+ }
834+
835+ macro_rules! identity {
836+ ($body:expr) => {
837+ $body
838+ }
839+ }
840+
841+ fn new() -> Struct {
842+ let $0boolean = true;
843+ identity![Struct { boolean }]
844+ }
845+ "# ,
846+ r#"
847+ struct Struct {
848+ boolean: bool,
849+ }
850+
851+ macro_rules! identity {
852+ ($body:expr) => {
853+ $body
854+ }
855+ }
856+
857+ #[derive(PartialEq, Eq)]
858+ enum Bool { True, False }
859+
860+ fn new() -> Struct {
861+ let boolean = Bool::True;
862+ identity![Struct { boolean: boolean == Bool::True }]
863+ }
864+ "# ,
865+ )
866+ }
867+
803868 #[ test]
804869 fn field_struct_basic ( ) {
805870 cov_mark:: check!( replaces_record_expr) ;
@@ -1321,6 +1386,46 @@ fn main() {
13211386 )
13221387 }
13231388
1389+ #[ test]
1390+ fn field_in_macro ( ) {
1391+ check_assist (
1392+ bool_to_enum,
1393+ r#"
1394+ struct Struct {
1395+ $0boolean: bool,
1396+ }
1397+
1398+ fn boolean(x: Struct) {
1399+ let Struct { boolean } = x;
1400+ }
1401+
1402+ macro_rules! identity { ($body:expr) => { $body } }
1403+
1404+ fn new() -> Struct {
1405+ identity!(Struct { boolean: true })
1406+ }
1407+ "# ,
1408+ r#"
1409+ #[derive(PartialEq, Eq)]
1410+ enum Bool { True, False }
1411+
1412+ struct Struct {
1413+ boolean: Bool,
1414+ }
1415+
1416+ fn boolean(x: Struct) {
1417+ let Struct { boolean } = x;
1418+ }
1419+
1420+ macro_rules! identity { ($body:expr) => { $body } }
1421+
1422+ fn new() -> Struct {
1423+ identity!(Struct { boolean: Bool::True })
1424+ }
1425+ "# ,
1426+ )
1427+ }
1428+
13241429 #[ test]
13251430 fn field_non_bool ( ) {
13261431 cov_mark:: check!( not_applicable_non_bool_field) ;
0 commit comments