@@ -16,7 +16,7 @@ use syntax::{
1616 edit_in_place:: { AttrsOwnerEdit , Indent } ,
1717 make, HasName ,
1818 } ,
19- ted, AstNode , NodeOrToken , SyntaxNode , T ,
19+ match_ast , ted, AstNode , NodeOrToken , SyntaxNode , T ,
2020} ;
2121use text_edit:: TextRange ;
2222
@@ -40,10 +40,10 @@ use crate::assist_context::{AssistContext, Assists};
4040// ```
4141// ->
4242// ```
43- // fn main() {
44- // #[derive(PartialEq, Eq)]
45- // enum Bool { True, False }
43+ // #[derive(PartialEq, Eq)]
44+ // enum Bool { True, False }
4645//
46+ // fn main() {
4747// let bool = Bool::True;
4848//
4949// if bool == Bool::True {
@@ -270,6 +270,10 @@ fn replace_usages(
270270 }
271271 _ => ( ) ,
272272 }
273+ } else if let Some ( ( ty_annotation, initializer) ) = find_assoc_const_usage ( & new_name)
274+ {
275+ edit. replace ( ty_annotation. syntax ( ) . text_range ( ) , "Bool" ) ;
276+ replace_bool_expr ( edit, initializer) ;
273277 } else if new_name. syntax ( ) . ancestors ( ) . find_map ( ast:: UseTree :: cast) . is_none ( ) {
274278 // for any other usage in an expression, replace it with a check that it is the true variant
275279 if let Some ( ( record_field, expr) ) = new_name
@@ -413,6 +417,15 @@ fn find_record_pat_field_usage(name: &ast::NameLike) -> Option<ast::Pat> {
413417 }
414418}
415419
420+ fn find_assoc_const_usage ( name : & ast:: NameLike ) -> Option < ( ast:: Type , ast:: Expr ) > {
421+ let const_ = name. syntax ( ) . parent ( ) . and_then ( ast:: Const :: cast) ?;
422+ if const_. syntax ( ) . parent ( ) . and_then ( ast:: AssocItemList :: cast) . is_none ( ) {
423+ return None ;
424+ }
425+
426+ Some ( ( const_. ty ( ) ?, const_. body ( ) ?) )
427+ }
428+
416429/// Adds the definition of the new enum before the target node.
417430fn add_enum_def (
418431 edit : & mut SourceChangeBuilder ,
@@ -430,18 +443,48 @@ fn add_enum_def(
430443 . any ( |module| module. nearest_non_block_module ( ctx. db ( ) ) != * target_module) ;
431444 let enum_def = make_bool_enum ( make_enum_pub) ;
432445
433- let indent = IndentLevel :: from_node ( & target_node) ;
446+ let insert_before = node_to_insert_before ( target_node) ;
447+ let indent = IndentLevel :: from_node ( & insert_before) ;
434448 enum_def. reindent_to ( indent) ;
435449
436450 ted:: insert_all (
437- ted:: Position :: before ( & edit. make_syntax_mut ( target_node ) ) ,
451+ ted:: Position :: before ( & edit. make_syntax_mut ( insert_before ) ) ,
438452 vec ! [
439453 enum_def. syntax( ) . clone( ) . into( ) ,
440454 make:: tokens:: whitespace( & format!( "\n \n {indent}" ) ) . into( ) ,
441455 ] ,
442456 ) ;
443457}
444458
459+ /// Finds where to put the new enum definition, at the nearest module or at top-level.
460+ fn node_to_insert_before ( mut target_node : SyntaxNode ) -> SyntaxNode {
461+ let mut ancestors = target_node. ancestors ( ) ;
462+
463+ while let Some ( ancestor) = ancestors. next ( ) {
464+ match_ast ! {
465+ match ancestor {
466+ ast:: Item ( item) => {
467+ if item
468+ . syntax( )
469+ . parent( )
470+ . and_then( |item_list| item_list. parent( ) )
471+ . and_then( ast:: Module :: cast)
472+ . is_some( )
473+ {
474+ return ancestor;
475+ }
476+ } ,
477+ ast:: SourceFile ( _) => break ,
478+ _ => ( ) ,
479+ }
480+ }
481+
482+ target_node = ancestor;
483+ }
484+
485+ target_node
486+ }
487+
445488fn make_bool_enum ( make_pub : bool ) -> ast:: Enum {
446489 let enum_def = make:: enum_ (
447490 if make_pub { Some ( make:: visibility_pub ( ) ) } else { None } ,
@@ -491,10 +534,10 @@ fn main() {
491534}
492535"# ,
493536 r#"
494- fn main() {
495- #[derive(PartialEq, Eq)]
496- enum Bool { True, False }
537+ #[derive(PartialEq, Eq)]
538+ enum Bool { True, False }
497539
540+ fn main() {
498541 let foo = Bool::True;
499542
500543 if foo == Bool::True {
@@ -520,10 +563,10 @@ fn main() {
520563}
521564"# ,
522565 r#"
523- fn main() {
524- #[derive(PartialEq, Eq)]
525- enum Bool { True, False }
566+ #[derive(PartialEq, Eq)]
567+ enum Bool { True, False }
526568
569+ fn main() {
527570 let foo = Bool::True;
528571
529572 if foo == Bool::False {
@@ -545,10 +588,10 @@ fn main() {
545588}
546589"# ,
547590 r#"
548- fn main() {
549- #[derive(PartialEq, Eq)]
550- enum Bool { True, False }
591+ #[derive(PartialEq, Eq)]
592+ enum Bool { True, False }
551593
594+ fn main() {
552595 let foo: Bool = Bool::False;
553596}
554597"# ,
@@ -565,10 +608,10 @@ fn main() {
565608}
566609"# ,
567610 r#"
568- fn main() {
569- #[derive(PartialEq, Eq)]
570- enum Bool { True, False }
611+ #[derive(PartialEq, Eq)]
612+ enum Bool { True, False }
571613
614+ fn main() {
572615 let foo = if 1 == 2 { Bool::True } else { Bool::False };
573616}
574617"# ,
@@ -590,10 +633,10 @@ fn main() {
590633}
591634"# ,
592635 r#"
593- fn main() {
594- #[derive(PartialEq, Eq)]
595- enum Bool { True, False }
636+ #[derive(PartialEq, Eq)]
637+ enum Bool { True, False }
596638
639+ fn main() {
597640 let foo = Bool::False;
598641 let bar = true;
599642
@@ -619,10 +662,10 @@ fn main() {
619662}
620663"# ,
621664 r#"
622- fn main() {
623- #[derive(PartialEq, Eq)]
624- enum Bool { True, False }
665+ #[derive(PartialEq, Eq)]
666+ enum Bool { True, False }
625667
668+ fn main() {
626669 let foo = Bool::True;
627670
628671 if *&foo == Bool::True {
@@ -645,10 +688,10 @@ fn main() {
645688}
646689"# ,
647690 r#"
648- fn main() {
649- #[derive(PartialEq, Eq)]
650- enum Bool { True, False }
691+ #[derive(PartialEq, Eq)]
692+ enum Bool { True, False }
651693
694+ fn main() {
652695 let foo: Bool;
653696 foo = Bool::True;
654697}
@@ -671,10 +714,10 @@ fn main() {
671714}
672715"# ,
673716 r#"
674- fn main() {
675- #[derive(PartialEq, Eq)]
676- enum Bool { True, False }
717+ #[derive(PartialEq, Eq)]
718+ enum Bool { True, False }
677719
720+ fn main() {
678721 let foo = Bool::True;
679722 let bar = foo == Bool::False;
680723
@@ -702,11 +745,11 @@ fn main() {
702745}
703746"# ,
704747 r#"
748+ #[derive(PartialEq, Eq)]
749+ enum Bool { True, False }
750+
705751fn main() {
706752 if !"foo".chars().any(|c| {
707- #[derive(PartialEq, Eq)]
708- enum Bool { True, False }
709-
710753 let foo = Bool::True;
711754 foo == Bool::True
712755 }) {
@@ -1445,6 +1488,90 @@ pub mod bar {
14451488 )
14461489 }
14471490
1491+ #[ test]
1492+ fn const_in_impl_cross_file ( ) {
1493+ check_assist (
1494+ bool_to_enum,
1495+ r#"
1496+ //- /main.rs
1497+ mod foo;
1498+
1499+ struct Foo;
1500+
1501+ impl Foo {
1502+ pub const $0BOOL: bool = true;
1503+ }
1504+
1505+ //- /foo.rs
1506+ use crate::Foo;
1507+
1508+ fn foo() -> bool {
1509+ Foo::BOOL
1510+ }
1511+ "# ,
1512+ r#"
1513+ //- /main.rs
1514+ mod foo;
1515+
1516+ struct Foo;
1517+
1518+ #[derive(PartialEq, Eq)]
1519+ pub enum Bool { True, False }
1520+
1521+ impl Foo {
1522+ pub const BOOL: Bool = Bool::True;
1523+ }
1524+
1525+ //- /foo.rs
1526+ use crate::{Foo, Bool};
1527+
1528+ fn foo() -> bool {
1529+ Foo::BOOL == Bool::True
1530+ }
1531+ "# ,
1532+ )
1533+ }
1534+
1535+ #[ test]
1536+ fn const_in_trait ( ) {
1537+ check_assist (
1538+ bool_to_enum,
1539+ r#"
1540+ trait Foo {
1541+ const $0BOOL: bool;
1542+ }
1543+
1544+ impl Foo for usize {
1545+ const BOOL: bool = true;
1546+ }
1547+
1548+ fn main() {
1549+ if <usize as Foo>::BOOL {
1550+ println!("foo");
1551+ }
1552+ }
1553+ "# ,
1554+ r#"
1555+ #[derive(PartialEq, Eq)]
1556+ enum Bool { True, False }
1557+
1558+ trait Foo {
1559+ const BOOL: Bool;
1560+ }
1561+
1562+ impl Foo for usize {
1563+ const BOOL: Bool = Bool::True;
1564+ }
1565+
1566+ fn main() {
1567+ if <usize as Foo>::BOOL == Bool::True {
1568+ println!("foo");
1569+ }
1570+ }
1571+ "# ,
1572+ )
1573+ }
1574+
14481575 #[ test]
14491576 fn const_non_bool ( ) {
14501577 cov_mark:: check!( not_applicable_non_bool_const) ;
0 commit comments