@@ -16,19 +16,22 @@ use std::env;
1616use std:: sync:: LazyLock ;
1717
1818use base_db:: SourceDatabaseFileInputExt as _;
19+ use either:: Either ;
1920use expect_test:: Expect ;
2021use hir_def:: {
2122 db:: DefDatabase ,
2223 expr_store:: { Body , BodySourceMap } ,
2324 hir:: { ExprId , Pat , PatId } ,
2425 item_scope:: ItemScope ,
2526 nameres:: DefMap ,
26- src:: HasSource ,
27- AssocItemId , DefWithBodyId , HasModule , LocalModuleId , Lookup , ModuleDefId , SyntheticSyntax ,
27+ src:: { HasChildSource , HasSource } ,
28+ AdtId , AssocItemId , DefWithBodyId , FieldId , HasModule , LocalModuleId , Lookup , ModuleDefId ,
29+ SyntheticSyntax ,
2830} ;
2931use hir_expand:: { db:: ExpandDatabase , FileRange , InFile } ;
3032use itertools:: Itertools ;
3133use rustc_hash:: FxHashMap ;
34+ use span:: TextSize ;
3235use stdx:: format_to;
3336use syntax:: {
3437 ast:: { self , AstNode , HasName } ,
@@ -132,14 +135,40 @@ fn check_impl(
132135 None => continue ,
133136 } ;
134137 let def_map = module. def_map ( & db) ;
135- visit_module ( & db, & def_map, module. local_id , & mut |it| {
136- defs. push ( match it {
137- ModuleDefId :: FunctionId ( it) => it. into ( ) ,
138- ModuleDefId :: EnumVariantId ( it) => it. into ( ) ,
139- ModuleDefId :: ConstId ( it) => it. into ( ) ,
140- ModuleDefId :: StaticId ( it) => it. into ( ) ,
141- _ => return ,
142- } )
138+ visit_module ( & db, & def_map, module. local_id , & mut |it| match it {
139+ ModuleDefId :: FunctionId ( it) => defs. push ( it. into ( ) ) ,
140+ ModuleDefId :: EnumVariantId ( it) => {
141+ defs. push ( it. into ( ) ) ;
142+ let variant_id = it. into ( ) ;
143+ let vd = db. variant_data ( variant_id) ;
144+ defs. extend ( vd. fields ( ) . iter ( ) . filter_map ( |( local_id, fd) | {
145+ if fd. has_default {
146+ let field = FieldId { parent : variant_id, local_id, has_default : true } ;
147+ Some ( DefWithBodyId :: FieldId ( field) )
148+ } else {
149+ None
150+ }
151+ } ) ) ;
152+ }
153+ ModuleDefId :: ConstId ( it) => defs. push ( it. into ( ) ) ,
154+ ModuleDefId :: StaticId ( it) => defs. push ( it. into ( ) ) ,
155+ ModuleDefId :: AdtId ( it) => {
156+ let variant_id = match it {
157+ AdtId :: StructId ( it) => it. into ( ) ,
158+ AdtId :: UnionId ( it) => it. into ( ) ,
159+ AdtId :: EnumId ( _) => return ,
160+ } ;
161+ let vd = db. variant_data ( variant_id) ;
162+ defs. extend ( vd. fields ( ) . iter ( ) . filter_map ( |( local_id, fd) | {
163+ if fd. has_default {
164+ let field = FieldId { parent : variant_id, local_id, has_default : true } ;
165+ Some ( DefWithBodyId :: FieldId ( field) )
166+ } else {
167+ None
168+ }
169+ } ) ) ;
170+ }
171+ _ => { }
143172 } ) ;
144173 }
145174 defs. sort_by_key ( |def| match def {
@@ -160,12 +189,20 @@ fn check_impl(
160189 loc. source ( & db) . value . syntax ( ) . text_range ( ) . start ( )
161190 }
162191 DefWithBodyId :: InTypeConstId ( it) => it. source ( & db) . syntax ( ) . text_range ( ) . start ( ) ,
163- DefWithBodyId :: FieldId ( _) => unreachable ! ( ) ,
192+ DefWithBodyId :: FieldId ( it) => {
193+ let cs = it. parent . child_source ( & db) ;
194+ match cs. value . get ( it. local_id ) {
195+ Some ( Either :: Left ( it) ) => it. syntax ( ) . text_range ( ) . start ( ) ,
196+ Some ( Either :: Right ( it) ) => it. syntax ( ) . text_range ( ) . end ( ) ,
197+ None => TextSize :: new ( u32:: MAX ) ,
198+ }
199+ }
164200 } ) ;
165201 let mut unexpected_type_mismatches = String :: new ( ) ;
166202 for def in defs {
167203 let ( body, body_source_map) = db. body_with_source_map ( def) ;
168204 let inference_result = db. infer ( def) ;
205+ dbg ! ( & inference_result) ;
169206
170207 for ( pat, mut ty) in inference_result. type_of_pat . iter ( ) {
171208 if let Pat :: Bind { id, .. } = body. pats [ pat] {
@@ -389,14 +426,40 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
389426 let def_map = module. def_map ( & db) ;
390427
391428 let mut defs: Vec < DefWithBodyId > = Vec :: new ( ) ;
392- visit_module ( & db, & def_map, module. local_id , & mut |it| {
393- defs. push ( match it {
394- ModuleDefId :: FunctionId ( it) => it. into ( ) ,
395- ModuleDefId :: EnumVariantId ( it) => it. into ( ) ,
396- ModuleDefId :: ConstId ( it) => it. into ( ) ,
397- ModuleDefId :: StaticId ( it) => it. into ( ) ,
398- _ => return ,
399- } )
429+ visit_module ( & db, & def_map, module. local_id , & mut |it| match it {
430+ ModuleDefId :: FunctionId ( it) => defs. push ( it. into ( ) ) ,
431+ ModuleDefId :: EnumVariantId ( it) => {
432+ defs. push ( it. into ( ) ) ;
433+ let variant_id = it. into ( ) ;
434+ let vd = db. variant_data ( variant_id) ;
435+ defs. extend ( vd. fields ( ) . iter ( ) . filter_map ( |( local_id, fd) | {
436+ if fd. has_default {
437+ let field = FieldId { parent : variant_id, local_id, has_default : true } ;
438+ Some ( DefWithBodyId :: FieldId ( field) )
439+ } else {
440+ None
441+ }
442+ } ) ) ;
443+ }
444+ ModuleDefId :: ConstId ( it) => defs. push ( it. into ( ) ) ,
445+ ModuleDefId :: StaticId ( it) => defs. push ( it. into ( ) ) ,
446+ ModuleDefId :: AdtId ( it) => {
447+ let variant_id = match it {
448+ AdtId :: StructId ( it) => it. into ( ) ,
449+ AdtId :: UnionId ( it) => it. into ( ) ,
450+ AdtId :: EnumId ( _) => return ,
451+ } ;
452+ let vd = db. variant_data ( variant_id) ;
453+ defs. extend ( vd. fields ( ) . iter ( ) . filter_map ( |( local_id, fd) | {
454+ if fd. has_default {
455+ let field = FieldId { parent : variant_id, local_id, has_default : true } ;
456+ Some ( DefWithBodyId :: FieldId ( field) )
457+ } else {
458+ None
459+ }
460+ } ) ) ;
461+ }
462+ _ => { }
400463 } ) ;
401464 defs. sort_by_key ( |def| match def {
402465 DefWithBodyId :: FunctionId ( it) => {
@@ -416,7 +479,14 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
416479 loc. source ( & db) . value . syntax ( ) . text_range ( ) . start ( )
417480 }
418481 DefWithBodyId :: InTypeConstId ( it) => it. source ( & db) . syntax ( ) . text_range ( ) . start ( ) ,
419- DefWithBodyId :: FieldId ( _) => unreachable ! ( ) ,
482+ DefWithBodyId :: FieldId ( it) => {
483+ let cs = it. parent . child_source ( & db) ;
484+ match cs. value . get ( it. local_id ) {
485+ Some ( Either :: Left ( it) ) => it. syntax ( ) . text_range ( ) . start ( ) ,
486+ Some ( Either :: Right ( it) ) => it. syntax ( ) . text_range ( ) . end ( ) ,
487+ None => TextSize :: new ( u32:: MAX ) ,
488+ }
489+ }
420490 } ) ;
421491 for def in defs {
422492 let ( body, source_map) = db. body_with_source_map ( def) ;
@@ -477,7 +547,7 @@ pub(crate) fn visit_module(
477547 let body = db. body ( it. into ( ) ) ;
478548 visit_body ( db, & body, cb) ;
479549 }
480- ModuleDefId :: AdtId ( hir_def :: AdtId :: EnumId ( it) ) => {
550+ ModuleDefId :: AdtId ( AdtId :: EnumId ( it) ) => {
481551 db. enum_data ( it) . variants . iter ( ) . for_each ( |& ( it, _) | {
482552 let body = db. body ( it. into ( ) ) ;
483553 cb ( it. into ( ) ) ;
0 commit comments