33use std:: fmt;
44use tracing:: debug;
55
6- use hir_def:: { DefWithBodyId , EnumVariantId , HasModule , LocalFieldId , ModuleId , VariantId } ;
6+ use hir_def:: { DefWithBodyId , EnumId , EnumVariantId , HasModule , LocalFieldId , ModuleId , VariantId } ;
77use rustc_hash:: FxHashMap ;
88use rustc_pattern_analysis:: {
99 constructor:: { Constructor , ConstructorSet , VariantVisibility } ,
@@ -36,6 +36,24 @@ pub(crate) type WitnessPat<'p> = rustc_pattern_analysis::pat::WitnessPat<MatchCh
3636#[ derive( Copy , Clone , Debug , PartialEq , Eq ) ]
3737pub ( crate ) enum Void { }
3838
39+ /// An index type for enum variants. This ranges from 0 to `variants.len()`, whereas `EnumVariantId`
40+ /// can take arbitrary large values (and hence mustn't be used with `IndexVec`/`BitSet`).
41+ #[ derive( Copy , Clone , Debug , PartialEq , Eq , Hash ) ]
42+ pub ( crate ) struct EnumVariantContiguousIndex ( usize ) ;
43+
44+ impl EnumVariantContiguousIndex {
45+ fn from_enum_variant_id ( db : & dyn HirDatabase , target_evid : EnumVariantId ) -> Self {
46+ // Find the index of this variant in the list of variants.
47+ use hir_def:: Lookup ;
48+ let i = target_evid. lookup ( db. upcast ( ) ) . index as usize ;
49+ EnumVariantContiguousIndex ( i)
50+ }
51+
52+ fn to_enum_variant_id ( self , db : & dyn HirDatabase , eid : EnumId ) -> EnumVariantId {
53+ db. enum_data ( eid) . variants [ self . 0 ] . 0
54+ }
55+ }
56+
3957#[ derive( Clone ) ]
4058pub ( crate ) struct MatchCheckCtx < ' p > {
4159 module : ModuleId ,
@@ -89,9 +107,18 @@ impl<'p> MatchCheckCtx<'p> {
89107 }
90108 }
91109
92- fn variant_id_for_adt ( ctor : & Constructor < Self > , adt : hir_def:: AdtId ) -> Option < VariantId > {
110+ fn variant_id_for_adt (
111+ db : & ' p dyn HirDatabase ,
112+ ctor : & Constructor < Self > ,
113+ adt : hir_def:: AdtId ,
114+ ) -> Option < VariantId > {
93115 match ctor {
94- & Variant ( id) => Some ( id. into ( ) ) ,
116+ Variant ( id) => {
117+ let hir_def:: AdtId :: EnumId ( eid) = adt else {
118+ panic ! ( "bad constructor {ctor:?} for adt {adt:?}" )
119+ } ;
120+ Some ( id. to_enum_variant_id ( db, eid) . into ( ) )
121+ }
95122 Struct | UnionField => match adt {
96123 hir_def:: AdtId :: EnumId ( _) => None ,
97124 hir_def:: AdtId :: StructId ( id) => Some ( id. into ( ) ) ,
@@ -175,19 +202,24 @@ impl<'p> MatchCheckCtx<'p> {
175202 ctor = Struct ;
176203 arity = 1 ;
177204 }
178- & TyKind :: Adt ( adt, _) => {
205+ & TyKind :: Adt ( AdtId ( adt) , _) => {
179206 ctor = match pat. kind . as_ref ( ) {
180- PatKind :: Leaf { .. } if matches ! ( adt. 0 , hir_def:: AdtId :: UnionId ( _) ) => {
207+ PatKind :: Leaf { .. } if matches ! ( adt, hir_def:: AdtId :: UnionId ( _) ) => {
181208 UnionField
182209 }
183210 PatKind :: Leaf { .. } => Struct ,
184- PatKind :: Variant { enum_variant, .. } => Variant ( * enum_variant) ,
211+ PatKind :: Variant { enum_variant, .. } => {
212+ Variant ( EnumVariantContiguousIndex :: from_enum_variant_id (
213+ self . db ,
214+ * enum_variant,
215+ ) )
216+ }
185217 _ => {
186218 never ! ( ) ;
187219 Wildcard
188220 }
189221 } ;
190- let variant = Self :: variant_id_for_adt ( & ctor, adt. 0 ) . unwrap ( ) ;
222+ let variant = Self :: variant_id_for_adt ( self . db , & ctor, adt) . unwrap ( ) ;
191223 arity = variant. variant_data ( self . db . upcast ( ) ) . fields ( ) . len ( ) ;
192224 }
193225 _ => {
@@ -239,7 +271,7 @@ impl<'p> MatchCheckCtx<'p> {
239271 PatKind :: Deref { subpattern : subpatterns. next ( ) . unwrap ( ) }
240272 }
241273 TyKind :: Adt ( adt, substs) => {
242- let variant = Self :: variant_id_for_adt ( pat. ctor ( ) , adt. 0 ) . unwrap ( ) ;
274+ let variant = Self :: variant_id_for_adt ( self . db , pat. ctor ( ) , adt. 0 ) . unwrap ( ) ;
243275 let subpatterns = self
244276 . list_variant_fields ( pat. ty ( ) , variant)
245277 . zip ( subpatterns)
@@ -277,7 +309,7 @@ impl<'p> MatchCheckCtx<'p> {
277309impl < ' p > PatCx for MatchCheckCtx < ' p > {
278310 type Error = ( ) ;
279311 type Ty = Ty ;
280- type VariantIdx = EnumVariantId ;
312+ type VariantIdx = EnumVariantContiguousIndex ;
281313 type StrLit = Void ;
282314 type ArmData = ( ) ;
283315 type PatData = PatData < ' p > ;
@@ -303,7 +335,7 @@ impl<'p> PatCx for MatchCheckCtx<'p> {
303335 // patterns. If we're here we can assume this is a box pattern.
304336 1
305337 } else {
306- let variant = Self :: variant_id_for_adt ( ctor, adt) . unwrap ( ) ;
338+ let variant = Self :: variant_id_for_adt ( self . db , ctor, adt) . unwrap ( ) ;
307339 variant. variant_data ( self . db . upcast ( ) ) . fields ( ) . len ( )
308340 }
309341 }
@@ -343,7 +375,7 @@ impl<'p> PatCx for MatchCheckCtx<'p> {
343375 let subst_ty = substs. at ( Interner , 0 ) . assert_ty_ref ( Interner ) . clone ( ) ;
344376 single ( subst_ty)
345377 } else {
346- let variant = Self :: variant_id_for_adt ( ctor, adt) . unwrap ( ) ;
378+ let variant = Self :: variant_id_for_adt ( self . db , ctor, adt) . unwrap ( ) ;
347379 let ( adt, _) = ty. as_adt ( ) . unwrap ( ) ;
348380
349381 let adt_is_local =
@@ -421,15 +453,15 @@ impl<'p> PatCx for MatchCheckCtx<'p> {
421453 ConstructorSet :: NoConstructors
422454 } else {
423455 let mut variants = FxHashMap :: default ( ) ;
424- for & ( variant, _) in enum_data. variants . iter ( ) {
456+ for ( i , & ( variant, _) ) in enum_data. variants . iter ( ) . enumerate ( ) {
425457 let is_uninhabited =
426458 is_enum_variant_uninhabited_from ( variant, subst, cx. module , cx. db ) ;
427459 let visibility = if is_uninhabited {
428460 VariantVisibility :: Empty
429461 } else {
430462 VariantVisibility :: Visible
431463 } ;
432- variants. insert ( variant , visibility) ;
464+ variants. insert ( EnumVariantContiguousIndex ( i ) , visibility) ;
433465 }
434466
435467 ConstructorSet :: Variants {
@@ -453,10 +485,10 @@ impl<'p> PatCx for MatchCheckCtx<'p> {
453485 f : & mut fmt:: Formatter < ' _ > ,
454486 pat : & rustc_pattern_analysis:: pat:: DeconstructedPat < Self > ,
455487 ) -> fmt:: Result {
488+ let db = pat. data ( ) . db ;
456489 let variant =
457- pat. ty ( ) . as_adt ( ) . and_then ( |( adt, _) | Self :: variant_id_for_adt ( pat. ctor ( ) , adt) ) ;
490+ pat. ty ( ) . as_adt ( ) . and_then ( |( adt, _) | Self :: variant_id_for_adt ( db , pat. ctor ( ) , adt) ) ;
458491
459- let db = pat. data ( ) . db ;
460492 if let Some ( variant) = variant {
461493 match variant {
462494 VariantId :: EnumVariantId ( v) => {
0 commit comments