@@ -17,8 +17,8 @@ use std::fmt;
1717
1818use hir_def:: {
1919 data:: adt:: VariantData , db:: DefDatabase , hir:: Pat , src:: HasSource , AdtId , AttrDefId , ConstId ,
20- EnumId , FunctionId , ItemContainerId , Lookup , ModuleDefId , ModuleId , StaticId , StructId ,
21- TraitId , TypeAliasId ,
20+ EnumId , EnumVariantId , FunctionId , ItemContainerId , Lookup , ModuleDefId , ModuleId , StaticId ,
21+ StructId , TraitId , TypeAliasId ,
2222} ;
2323use hir_expand:: {
2424 name:: { AsName , Name } ,
@@ -353,17 +353,16 @@ impl<'a> DeclValidator<'a> {
353353 continue ;
354354 } ;
355355
356- let is_param = ast:: Param :: can_cast ( parent. kind ( ) ) ;
357- // We have to check that it's either `let var = ...` or `var @ Variant(_)` statement,
358- // because e.g. match arms are patterns as well.
359- // In other words, we check that it's a named variable binding.
360- let is_binding = ast:: LetStmt :: can_cast ( parent. kind ( ) )
361- || ( ast:: MatchArm :: can_cast ( parent. kind ( ) ) && ident_pat. at_token ( ) . is_some ( ) ) ;
362- if !( is_param || is_binding) {
363- // This pattern is not an actual variable declaration, e.g. `Some(val) => {..}` match arm.
356+ let is_shorthand = ast:: RecordPatField :: cast ( parent. clone ( ) )
357+ . map ( |parent| parent. name_ref ( ) . is_none ( ) )
358+ . unwrap_or_default ( ) ;
359+ if is_shorthand {
360+ // We don't check shorthand field patterns, such as 'field' in `Thing { field }`,
361+ // since the shorthand isn't the declaration.
364362 continue ;
365363 }
366364
365+ let is_param = ast:: Param :: can_cast ( parent. kind ( ) ) ;
367366 let ident_type = if is_param { IdentType :: Parameter } else { IdentType :: Variable } ;
368367
369368 self . create_incorrect_case_diagnostic_for_ast_node (
@@ -489,6 +488,11 @@ impl<'a> DeclValidator<'a> {
489488 /// Check incorrect names for enum variants.
490489 fn validate_enum_variants ( & mut self , enum_id : EnumId ) {
491490 let data = self . db . enum_data ( enum_id) ;
491+
492+ for ( variant_id, _) in data. variants . iter ( ) {
493+ self . validate_enum_variant_fields ( * variant_id) ;
494+ }
495+
492496 let mut enum_variants_replacements = data
493497 . variants
494498 . iter ( )
@@ -551,6 +555,75 @@ impl<'a> DeclValidator<'a> {
551555 }
552556 }
553557
558+ /// Check incorrect names for fields of enum variant.
559+ fn validate_enum_variant_fields ( & mut self , variant_id : EnumVariantId ) {
560+ let variant_data = self . db . enum_variant_data ( variant_id) ;
561+ let VariantData :: Record ( fields) = variant_data. variant_data . as_ref ( ) else {
562+ return ;
563+ } ;
564+ let mut variant_field_replacements = fields
565+ . iter ( )
566+ . filter_map ( |( _, field) | {
567+ to_lower_snake_case ( & field. name . to_smol_str ( ) ) . map ( |new_name| Replacement {
568+ current_name : field. name . clone ( ) ,
569+ suggested_text : new_name,
570+ expected_case : CaseType :: LowerSnakeCase ,
571+ } )
572+ } )
573+ . peekable ( ) ;
574+
575+ // XXX: only look at sources if we do have incorrect names
576+ if variant_field_replacements. peek ( ) . is_none ( ) {
577+ return ;
578+ }
579+
580+ let variant_loc = variant_id. lookup ( self . db . upcast ( ) ) ;
581+ let variant_src = variant_loc. source ( self . db . upcast ( ) ) ;
582+
583+ let Some ( ast:: FieldList :: RecordFieldList ( variant_fields_list) ) =
584+ variant_src. value . field_list ( )
585+ else {
586+ always ! (
587+ variant_field_replacements. peek( ) . is_none( ) ,
588+ "Replacements ({:?}) were generated for an enum variant \
589+ which had no fields list: {:?}",
590+ variant_field_replacements. collect:: <Vec <_>>( ) ,
591+ variant_src
592+ ) ;
593+ return ;
594+ } ;
595+ let mut variant_variants_iter = variant_fields_list. fields ( ) ;
596+ for field_replacement in variant_field_replacements {
597+ // We assume that parameters in replacement are in the same order as in the
598+ // actual params list, but just some of them (ones that named correctly) are skipped.
599+ let field = loop {
600+ if let Some ( field) = variant_variants_iter. next ( ) {
601+ let Some ( field_name) = field. name ( ) else {
602+ continue ;
603+ } ;
604+ if field_name. as_name ( ) == field_replacement. current_name {
605+ break field;
606+ }
607+ } else {
608+ never ! (
609+ "Replacement ({:?}) was generated for an enum variant field \
610+ which was not found: {:?}",
611+ field_replacement,
612+ variant_src
613+ ) ;
614+ return ;
615+ }
616+ } ;
617+
618+ self . create_incorrect_case_diagnostic_for_ast_node (
619+ field_replacement,
620+ variant_src. file_id ,
621+ & field,
622+ IdentType :: Field ,
623+ ) ;
624+ }
625+ }
626+
554627 fn validate_const ( & mut self , const_id : ConstId ) {
555628 let container = const_id. lookup ( self . db . upcast ( ) ) . container ;
556629 if self . is_trait_impl_container ( container) {
0 commit comments