@@ -11,7 +11,7 @@ use crate::{
1111 ted:: { self , Position } ,
1212 AstNode , AstToken , Direction ,
1313 SyntaxKind :: { ATTR , COMMENT , WHITESPACE } ,
14- SyntaxNode ,
14+ SyntaxNode , SyntaxToken ,
1515} ;
1616
1717use super :: HasName ;
@@ -506,19 +506,7 @@ impl ast::RecordExprFieldList {
506506
507507 let position = match self . fields ( ) . last ( ) {
508508 Some ( last_field) => {
509- let comma = match last_field
510- . syntax ( )
511- . siblings_with_tokens ( Direction :: Next )
512- . filter_map ( |it| it. into_token ( ) )
513- . find ( |it| it. kind ( ) == T ! [ , ] )
514- {
515- Some ( it) => it,
516- None => {
517- let comma = ast:: make:: token ( T ! [ , ] ) ;
518- ted:: insert ( Position :: after ( last_field. syntax ( ) ) , & comma) ;
519- comma
520- }
521- } ;
509+ let comma = get_or_insert_comma_after ( last_field. syntax ( ) ) ;
522510 Position :: after ( comma)
523511 }
524512 None => match self . l_curly_token ( ) {
@@ -579,19 +567,8 @@ impl ast::RecordPatFieldList {
579567
580568 let position = match self . fields ( ) . last ( ) {
581569 Some ( last_field) => {
582- let comma = match last_field
583- . syntax ( )
584- . siblings_with_tokens ( Direction :: Next )
585- . filter_map ( |it| it. into_token ( ) )
586- . find ( |it| it. kind ( ) == T ! [ , ] )
587- {
588- Some ( it) => it,
589- None => {
590- let comma = ast:: make:: token ( T ! [ , ] ) ;
591- ted:: insert ( Position :: after ( last_field. syntax ( ) ) , & comma) ;
592- comma
593- }
594- } ;
570+ let syntax = last_field. syntax ( ) ;
571+ let comma = get_or_insert_comma_after ( syntax) ;
595572 Position :: after ( comma)
596573 }
597574 None => match self . l_curly_token ( ) {
@@ -606,12 +583,53 @@ impl ast::RecordPatFieldList {
606583 }
607584 }
608585}
586+
587+ fn get_or_insert_comma_after ( syntax : & SyntaxNode ) -> SyntaxToken {
588+ let comma = match syntax
589+ . siblings_with_tokens ( Direction :: Next )
590+ . filter_map ( |it| it. into_token ( ) )
591+ . find ( |it| it. kind ( ) == T ! [ , ] )
592+ {
593+ Some ( it) => it,
594+ None => {
595+ let comma = ast:: make:: token ( T ! [ , ] ) ;
596+ ted:: insert ( Position :: after ( syntax) , & comma) ;
597+ comma
598+ }
599+ } ;
600+ comma
601+ }
602+
609603impl ast:: StmtList {
610604 pub fn push_front ( & self , statement : ast:: Stmt ) {
611605 ted:: insert ( Position :: after ( self . l_curly_token ( ) . unwrap ( ) ) , statement. syntax ( ) ) ;
612606 }
613607}
614608
609+ impl ast:: VariantList {
610+ pub fn add_variant ( & self , variant : ast:: Variant ) {
611+ let ( indent, position) = match self . variants ( ) . last ( ) {
612+ Some ( last_item) => (
613+ IndentLevel :: from_node ( last_item. syntax ( ) ) ,
614+ Position :: after ( get_or_insert_comma_after ( last_item. syntax ( ) ) ) ,
615+ ) ,
616+ None => match self . l_curly_token ( ) {
617+ Some ( l_curly) => {
618+ normalize_ws_between_braces ( self . syntax ( ) ) ;
619+ ( IndentLevel :: from_token ( & l_curly) + 1 , Position :: after ( & l_curly) )
620+ }
621+ None => ( IndentLevel :: single ( ) , Position :: last_child_of ( self . syntax ( ) ) ) ,
622+ } ,
623+ } ;
624+ let elements: Vec < SyntaxElement < _ > > = vec ! [
625+ make:: tokens:: whitespace( & format!( "{}{}" , "\n " , indent) ) . into( ) ,
626+ variant. syntax( ) . clone( ) . into( ) ,
627+ ast:: make:: token( T ![ , ] ) . into( ) ,
628+ ] ;
629+ ted:: insert_all ( position, elements) ;
630+ }
631+ }
632+
615633fn normalize_ws_between_braces ( node : & SyntaxNode ) -> Option < ( ) > {
616634 let l = node
617635 . children_with_tokens ( )
@@ -661,6 +679,9 @@ impl<N: AstNode + Clone> Indent for N {}
661679mod tests {
662680 use std:: fmt;
663681
682+ use stdx:: trim_indent;
683+ use test_utils:: assert_eq_text;
684+
664685 use crate :: SourceFile ;
665686
666687 use super :: * ;
@@ -714,4 +735,100 @@ mod tests {
714735 }" ,
715736 ) ;
716737 }
738+
739+ #[ test]
740+ fn add_variant_to_empty_enum ( ) {
741+ let variant = make:: variant ( make:: name ( "Bar" ) , None ) . clone_for_update ( ) ;
742+
743+ check_add_variant (
744+ r#"
745+ enum Foo {}
746+ "# ,
747+ r#"
748+ enum Foo {
749+ Bar,
750+ }
751+ "# ,
752+ variant,
753+ ) ;
754+ }
755+
756+ #[ test]
757+ fn add_variant_to_non_empty_enum ( ) {
758+ let variant = make:: variant ( make:: name ( "Baz" ) , None ) . clone_for_update ( ) ;
759+
760+ check_add_variant (
761+ r#"
762+ enum Foo {
763+ Bar,
764+ }
765+ "# ,
766+ r#"
767+ enum Foo {
768+ Bar,
769+ Baz,
770+ }
771+ "# ,
772+ variant,
773+ ) ;
774+ }
775+
776+ #[ test]
777+ fn add_variant_with_tuple_field_list ( ) {
778+ let variant = make:: variant (
779+ make:: name ( "Baz" ) ,
780+ Some ( ast:: FieldList :: TupleFieldList ( make:: tuple_field_list ( std:: iter:: once (
781+ make:: tuple_field ( None , make:: ty ( "bool" ) ) ,
782+ ) ) ) ) ,
783+ )
784+ . clone_for_update ( ) ;
785+
786+ check_add_variant (
787+ r#"
788+ enum Foo {
789+ Bar,
790+ }
791+ "# ,
792+ r#"
793+ enum Foo {
794+ Bar,
795+ Baz(bool),
796+ }
797+ "# ,
798+ variant,
799+ ) ;
800+ }
801+
802+ #[ test]
803+ fn add_variant_with_record_field_list ( ) {
804+ let variant = make:: variant (
805+ make:: name ( "Baz" ) ,
806+ Some ( ast:: FieldList :: RecordFieldList ( make:: record_field_list ( std:: iter:: once (
807+ make:: record_field ( None , make:: name ( "x" ) , make:: ty ( "bool" ) ) ,
808+ ) ) ) ) ,
809+ )
810+ . clone_for_update ( ) ;
811+
812+ check_add_variant (
813+ r#"
814+ enum Foo {
815+ Bar,
816+ }
817+ "# ,
818+ r#"
819+ enum Foo {
820+ Bar,
821+ Baz { x: bool },
822+ }
823+ "# ,
824+ variant,
825+ ) ;
826+ }
827+
828+ fn check_add_variant ( before : & str , expected : & str , variant : ast:: Variant ) {
829+ let enum_ = ast_mut_from_text :: < ast:: Enum > ( before) ;
830+ enum_. variant_list ( ) . map ( |it| it. add_variant ( variant) ) ;
831+ let after = enum_. to_string ( ) ;
832+ assert_eq_text ! ( & trim_indent( expected. trim( ) ) , & trim_indent( & after. trim( ) ) ) ;
833+ }
717834}
0 commit comments