1+ use hir:: ModuleDef ;
12use ide_db:: {
23 assists:: { AssistId , AssistKind } ,
34 defs:: Definition ,
4- search:: { FileReference , SearchScope , UsageSearchResult } ,
5+ helpers:: mod_path_to_ast,
6+ imports:: insert_use:: { insert_use, ImportScope } ,
7+ search:: { FileReference , UsageSearchResult } ,
58 source_change:: SourceChangeBuilder ,
69} ;
10+ use itertools:: Itertools ;
711use syntax:: {
812 ast:: {
913 self ,
@@ -48,6 +52,7 @@ use crate::assist_context::{AssistContext, Assists};
4852pub ( crate ) fn bool_to_enum ( acc : & mut Assists , ctx : & AssistContext < ' _ > ) -> Option < ( ) > {
4953 let BoolNodeData { target_node, name, ty_annotation, initializer, definition } =
5054 find_bool_node ( ctx) ?;
55+ let target_module = ctx. sema . scope ( & target_node) ?. module ( ) ;
5156
5257 let target = name. syntax ( ) . text_range ( ) ;
5358 acc. add (
@@ -64,13 +69,10 @@ pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option
6469 replace_bool_expr ( edit, initializer) ;
6570 }
6671
67- let usages = definition
68- . usages ( & ctx. sema )
69- . in_scope ( & SearchScope :: single_file ( ctx. file_id ( ) ) )
70- . all ( ) ;
71- replace_usages ( edit, & usages) ;
72+ let usages = definition. usages ( & ctx. sema ) . all ( ) ;
7273
73- add_enum_def ( edit, ctx, & usages, target_node) ;
74+ add_enum_def ( edit, ctx, & usages, target_node, & target_module) ;
75+ replace_usages ( edit, ctx, & usages, & target_module) ;
7476 } ,
7577 )
7678}
@@ -186,8 +188,45 @@ fn bool_expr_to_enum_expr(expr: ast::Expr) -> ast::Expr {
186188}
187189
188190/// Replaces all usages of the target identifier, both when read and written to.
189- fn replace_usages ( edit : & mut SourceChangeBuilder , usages : & UsageSearchResult ) {
190- for ( _, references) in usages. iter ( ) {
191+ fn replace_usages (
192+ edit : & mut SourceChangeBuilder ,
193+ ctx : & AssistContext < ' _ > ,
194+ usages : & UsageSearchResult ,
195+ target_module : & hir:: Module ,
196+ ) {
197+ for ( file_id, references) in usages. iter ( ) {
198+ edit. edit_file ( * file_id) ;
199+
200+ // add imports across modules where needed
201+ references
202+ . iter ( )
203+ . filter_map ( |FileReference { name, .. } | {
204+ ctx. sema . scope ( name. syntax ( ) ) . map ( |scope| ( name, scope. module ( ) ) )
205+ } )
206+ . unique_by ( |name_and_module| name_and_module. 1 )
207+ . filter ( |( _, module) | module != target_module)
208+ . filter_map ( |( name, module) | {
209+ let import_scope = ImportScope :: find_insert_use_container ( name. syntax ( ) , & ctx. sema ) ;
210+ let mod_path = module. find_use_path_prefixed (
211+ ctx. sema . db ,
212+ ModuleDef :: Module ( * target_module) ,
213+ ctx. config . insert_use . prefix_kind ,
214+ ctx. config . prefer_no_std ,
215+ ) ;
216+ import_scope. zip ( mod_path)
217+ } )
218+ . for_each ( |( import_scope, mod_path) | {
219+ let import_scope = match import_scope {
220+ ImportScope :: File ( it) => ImportScope :: File ( edit. make_mut ( it) ) ,
221+ ImportScope :: Module ( it) => ImportScope :: Module ( edit. make_mut ( it) ) ,
222+ ImportScope :: Block ( it) => ImportScope :: Block ( edit. make_mut ( it) ) ,
223+ } ;
224+ let path =
225+ make:: path_concat ( mod_path_to_ast ( & mod_path) , make:: path_from_text ( "Bool" ) ) ;
226+ insert_use ( & import_scope, path, & ctx. config . insert_use ) ;
227+ } ) ;
228+
229+ // replace the usages in expressions
191230 references
192231 . into_iter ( )
193232 . filter_map ( |FileReference { range, name, .. } | match name {
@@ -213,7 +252,7 @@ fn replace_usages(edit: &mut SourceChangeBuilder, usages: &UsageSearchResult) {
213252 let record_field = edit. make_mut ( record_field) ;
214253 let enum_expr = bool_expr_to_enum_expr ( initializer) ;
215254 record_field. replace_expr ( enum_expr) ;
216- } else if name_ref. syntax ( ) . ancestors ( ) . find_map ( ast:: Expr :: cast) . is_some ( ) {
255+ } else if name_ref. syntax ( ) . ancestors ( ) . find_map ( ast:: UseTree :: cast) . is_none ( ) {
217256 // for any other usage in an expression, replace it with a check that it is the true variant
218257 edit. replace ( range, format ! ( "{} == Bool::True" , name_ref. text( ) ) ) ;
219258 }
@@ -255,8 +294,15 @@ fn add_enum_def(
255294 ctx : & AssistContext < ' _ > ,
256295 usages : & UsageSearchResult ,
257296 target_node : SyntaxNode ,
297+ target_module : & hir:: Module ,
258298) {
259- let make_enum_pub = usages. iter ( ) . any ( |( file_id, _) | file_id != & ctx. file_id ( ) ) ;
299+ let make_enum_pub = usages
300+ . iter ( )
301+ . flat_map ( |( _, refs) | refs)
302+ . filter_map ( |FileReference { name, .. } | {
303+ ctx. sema . scope ( name. syntax ( ) ) . map ( |scope| scope. module ( ) )
304+ } )
305+ . any ( |module| & module != target_module) ;
260306 let enum_def = make_bool_enum ( make_enum_pub) ;
261307
262308 let indent = IndentLevel :: from_node ( & target_node) ;
@@ -649,7 +695,7 @@ fn main() {
649695"# ,
650696 r#"
651697#[derive(PartialEq, Eq)]
652- enum $0Bool { True, False }
698+ enum Bool { True, False }
653699
654700struct Foo {
655701 bar: Bool,
@@ -713,6 +759,162 @@ fn main() {
713759 )
714760 }
715761
762+ #[ test]
763+ fn const_in_module ( ) {
764+ check_assist (
765+ bool_to_enum,
766+ r#"
767+ fn main() {
768+ if foo::FOO {
769+ println!("foo");
770+ }
771+ }
772+
773+ mod foo {
774+ pub const $0FOO: bool = true;
775+ }
776+ "# ,
777+ r#"
778+ use foo::Bool;
779+
780+ fn main() {
781+ if foo::FOO == Bool::True {
782+ println!("foo");
783+ }
784+ }
785+
786+ mod foo {
787+ #[derive(PartialEq, Eq)]
788+ pub enum Bool { True, False }
789+
790+ pub const FOO: Bool = Bool::True;
791+ }
792+ "# ,
793+ )
794+ }
795+
796+ #[ test]
797+ fn const_in_module_with_import ( ) {
798+ check_assist (
799+ bool_to_enum,
800+ r#"
801+ fn main() {
802+ use foo::FOO;
803+
804+ if FOO {
805+ println!("foo");
806+ }
807+ }
808+
809+ mod foo {
810+ pub const $0FOO: bool = true;
811+ }
812+ "# ,
813+ r#"
814+ use crate::foo::Bool;
815+
816+ fn main() {
817+ use foo::FOO;
818+
819+ if FOO == Bool::True {
820+ println!("foo");
821+ }
822+ }
823+
824+ mod foo {
825+ #[derive(PartialEq, Eq)]
826+ pub enum Bool { True, False }
827+
828+ pub const FOO: Bool = Bool::True;
829+ }
830+ "# ,
831+ )
832+ }
833+
834+ #[ test]
835+ fn const_cross_file ( ) {
836+ check_assist (
837+ bool_to_enum,
838+ r#"
839+ //- /main.rs
840+ mod foo;
841+
842+ fn main() {
843+ if foo::FOO {
844+ println!("foo");
845+ }
846+ }
847+
848+ //- /foo.rs
849+ pub const $0FOO: bool = true;
850+ "# ,
851+ r#"
852+ //- /main.rs
853+ use foo::Bool;
854+
855+ mod foo;
856+
857+ fn main() {
858+ if foo::FOO == Bool::True {
859+ println!("foo");
860+ }
861+ }
862+
863+ //- /foo.rs
864+ #[derive(PartialEq, Eq)]
865+ pub enum Bool { True, False }
866+
867+ pub const FOO: Bool = Bool::True;
868+ "# ,
869+ )
870+ }
871+
872+ #[ test]
873+ fn const_cross_file_and_module ( ) {
874+ check_assist (
875+ bool_to_enum,
876+ r#"
877+ //- /main.rs
878+ mod foo;
879+
880+ fn main() {
881+ use foo::bar;
882+
883+ if bar::BAR {
884+ println!("foo");
885+ }
886+ }
887+
888+ //- /foo.rs
889+ pub mod bar {
890+ pub const $0BAR: bool = false;
891+ }
892+ "# ,
893+ r#"
894+ //- /main.rs
895+ use crate::foo::bar::Bool;
896+
897+ mod foo;
898+
899+ fn main() {
900+ use foo::bar;
901+
902+ if bar::BAR == Bool::True {
903+ println!("foo");
904+ }
905+ }
906+
907+ //- /foo.rs
908+ pub mod bar {
909+ #[derive(PartialEq, Eq)]
910+ pub enum Bool { True, False }
911+
912+ pub const BAR: Bool = Bool::False;
913+ }
914+ "# ,
915+ )
916+ }
917+
716918 #[ test]
717919 fn const_non_bool ( ) {
718920 cov_mark:: check!( not_applicable_non_bool_const) ;
0 commit comments