@@ -8,8 +8,8 @@ use rustc_hash::FxHashMap;
88use syntax:: SmolStr ;
99
1010use crate :: {
11- db:: DefDatabase , AdtId , AttrDefId , CrateId , EnumId , FunctionId , ImplId , ModuleDefId , StaticId ,
12- StructId , TraitId ,
11+ db:: DefDatabase , AdtId , AttrDefId , CrateId , EnumId , EnumVariantId , FunctionId , ImplId ,
12+ ModuleDefId , StaticId , StructId , TraitId ,
1313} ;
1414
1515#[ derive( Debug , Clone , Copy , PartialEq , Eq , Hash ) ]
@@ -20,6 +20,7 @@ pub enum LangItemTarget {
2020 StaticId ( StaticId ) ,
2121 StructId ( StructId ) ,
2222 TraitId ( TraitId ) ,
23+ EnumVariantId ( EnumVariantId ) ,
2324}
2425
2526impl LangItemTarget {
@@ -64,6 +65,13 @@ impl LangItemTarget {
6465 _ => None ,
6566 }
6667 }
68+
69+ pub fn as_enum_variant ( self ) -> Option < EnumVariantId > {
70+ match self {
71+ LangItemTarget :: EnumVariantId ( id) => Some ( id) ,
72+ _ => None ,
73+ }
74+ }
6775}
6876
6977#[ derive( Default , Debug , Clone , PartialEq , Eq ) ]
@@ -92,19 +100,31 @@ impl LangItems {
92100 for def in module_data. scope . declarations ( ) {
93101 match def {
94102 ModuleDefId :: TraitId ( trait_) => {
95- lang_items. collect_lang_item ( db, trait_, LangItemTarget :: TraitId )
103+ lang_items. collect_lang_item ( db, trait_, LangItemTarget :: TraitId ) ;
104+ db. trait_data ( trait_) . items . iter ( ) . for_each ( |& ( _, assoc_id) | {
105+ if let crate :: AssocItemId :: FunctionId ( f) = assoc_id {
106+ lang_items. collect_lang_item ( db, f, LangItemTarget :: FunctionId ) ;
107+ }
108+ } ) ;
96109 }
97110 ModuleDefId :: AdtId ( AdtId :: EnumId ( e) ) => {
98- lang_items. collect_lang_item ( db, e, LangItemTarget :: EnumId )
111+ lang_items. collect_lang_item ( db, e, LangItemTarget :: EnumId ) ;
112+ db. enum_data ( e) . variants . iter ( ) . for_each ( |( local_id, _) | {
113+ lang_items. collect_lang_item (
114+ db,
115+ EnumVariantId { parent : e, local_id } ,
116+ LangItemTarget :: EnumVariantId ,
117+ ) ;
118+ } ) ;
99119 }
100120 ModuleDefId :: AdtId ( AdtId :: StructId ( s) ) => {
101- lang_items. collect_lang_item ( db, s, LangItemTarget :: StructId )
121+ lang_items. collect_lang_item ( db, s, LangItemTarget :: StructId ) ;
102122 }
103123 ModuleDefId :: FunctionId ( f) => {
104- lang_items. collect_lang_item ( db, f, LangItemTarget :: FunctionId )
124+ lang_items. collect_lang_item ( db, f, LangItemTarget :: FunctionId ) ;
105125 }
106126 ModuleDefId :: StaticId ( s) => {
107- lang_items. collect_lang_item ( db, s, LangItemTarget :: StaticId )
127+ lang_items. collect_lang_item ( db, s, LangItemTarget :: StaticId ) ;
108128 }
109129 _ => { }
110130 }
0 commit comments