@@ -20,12 +20,11 @@ object DesugarEnums {
2020 val Simple, Object, Class : Value = Value
2121 }
2222
23- final class EnumConstraints (minKind : CaseKind .Value , maxKind : CaseKind .Value , cases : List [(Int , TermName )]):
24- require(minKind <= maxKind && ! (cached && cachedValues .isEmpty))
23+ final case class EnumConstraints (minKind : CaseKind .Value , maxKind : CaseKind .Value , enumCases : List [(Int , RefTree )]):
24+ require(minKind <= maxKind && ! (cached && enumCases .isEmpty))
2525 def requiresCreator = minKind == CaseKind .Simple
2626 def isEnumeration = maxKind < CaseKind .Class
2727 def cached = minKind < CaseKind .Class
28- def cachedValues = cases
2928 end EnumConstraints
3029
3130 /** Attachment containing the number of enum cases, the smallest kind that was seen so far,
@@ -47,6 +46,11 @@ object DesugarEnums {
4746 if (cls.is(Module )) cls.linkedClass else cls
4847 }
4948
49+ def enumCompanion (using Context ): Symbol = {
50+ val cls = ctx.owner
51+ if (cls.is(Module )) cls.sourceModule else cls.linkedClass.sourceModule
52+ }
53+
5054 /** Is `tree` an (untyped) enum case? */
5155 def isEnumCase (tree : Tree )(using Context ): Boolean = tree match {
5256 case tree : MemberDef => tree.mods.isEnumCase
@@ -109,14 +113,12 @@ object DesugarEnums {
109113 * case _ => throw new IllegalArgumentException("case not found: " + $name)
110114 * }
111115 */
112- private def enumScaffolding (enumCases : List [( Int , TermName ) ])(using Context ): List [Tree ] = {
116+ private def enumScaffolding (enumValues : List [RefTree ])(using Context ): List [Tree ] = {
113117 val rawEnumClassRef = rawRef(enumClass.typeRef)
114118 extension (tpe : NamedType ) def ofRawEnum = AppliedTypeTree (ref(tpe), rawEnumClassRef)
115119
116- val privateValuesDef =
117- ValDef (nme.DOLLAR_VALUES , TypeTree (),
118- ArrayLiteral (enumCases.map((_, name) => Ident (name)), rawEnumClassRef))
119- .withFlags(Private | Synthetic )
120+ val privateValuesDef = ValDef (nme.DOLLAR_VALUES , TypeTree (), ArrayLiteral (enumValues, rawEnumClassRef))
121+ .withFlags(Private | Synthetic )
120122
121123 val valuesDef =
122124 DefDef (nme.values, Nil , Nil , defn.ArrayType .ofRawEnum, valuesDot(nme.clone_))
@@ -127,8 +129,8 @@ object DesugarEnums {
127129 val msg = Apply (Select (Literal (Constant (" enum case not found: " )), nme.PLUS ), Ident (nme.nameDollar))
128130 CaseDef (Ident (nme.WILDCARD ), EmptyTree ,
129131 Throw (New (TypeTree (defn.IllegalArgumentExceptionType ), List (msg :: Nil ))))
130- val stringCases = enumCases .map((_, name) =>
131- CaseDef (Literal (Constant (name.toString)), EmptyTree , Ident (name) )
132+ val stringCases = enumValues .map(enumValue =>
133+ CaseDef (Literal (Constant (enumValue. name.toString)), EmptyTree , enumValue )
132134 ) ::: defaultCase :: Nil
133135 Match (Ident (nme.nameDollar), stringCases)
134136 val valueOfDef = DefDef (nme.valueOf, Nil , List (param(nme.nameDollar, defn.StringType ) :: Nil ),
@@ -141,7 +143,7 @@ object DesugarEnums {
141143 }
142144
143145 private def enumLookupMethods (constraints : EnumConstraints )(using Context ): List [Tree ] =
144- def scaffolding : List [Tree ] = if constraints.cached then enumScaffolding(constraints.cachedValues ) else Nil
146+ def scaffolding : List [Tree ] = if constraints.cached then enumScaffolding(constraints.enumCases.map(_._2) ) else Nil
145147 def valueCtor : List [Tree ] = if constraints.requiresCreator then enumValueCreator :: Nil else Nil
146148 def byOrdinal : List [Tree ] =
147149 if isJavaEnum || ! constraints.cached then Nil
@@ -150,8 +152,8 @@ object DesugarEnums {
150152 val ord = Ident (nme.ordinal)
151153 val err = Throw (New (TypeTree (defn.IndexOutOfBoundsException .typeRef), List (Select (ord, nme.toString_) :: Nil )))
152154 CaseDef (ord, EmptyTree , err)
153- val valueCases = constraints.cachedValues .map((i, name ) =>
154- CaseDef (Literal (Constant (i)), EmptyTree , Ident (name) )
155+ val valueCases = constraints.enumCases .map((i, enumValue ) =>
156+ CaseDef (Literal (Constant (i)), EmptyTree , enumValue )
155157 ) ::: defaultCase :: Nil
156158 val fromOrdinalDef = DefDef (nme.fromOrdinalDollar, Nil , List (param(nme.ordinalDollar_, defn.IntType ) :: Nil ),
157159 rawRef(enumClass.typeRef), Match (Ident (nme.ordinalDollar_), valueCases))
@@ -304,7 +306,9 @@ object DesugarEnums {
304306 case name : TermName => (ordinal, name) :: seenCases
305307 case _ => seenCases
306308 if definesLookups then
307- (ordinal, enumLookupMethods(EnumConstraints (minKind, maxKind, cases.reverse)))
309+ val companionRef = ref(enumCompanion.termRef)
310+ val cachedValues = cases.reverse.map((i, name) => (i, Select (companionRef, name)))
311+ (ordinal, enumLookupMethods(EnumConstraints (minKind, maxKind, cachedValues)))
308312 else
309313 ctx.tree.pushAttachment(EnumCaseCount , (ordinal + 1 , minKind, maxKind, cases))
310314 (ordinal, Nil )
@@ -313,7 +317,7 @@ object DesugarEnums {
313317 def param (name : TermName , typ : Type )(using Context ): ValDef = param(name, TypeTree (typ))
314318 def param (name : TermName , tpt : Tree )(using Context ): ValDef = ValDef (name, tpt, EmptyTree ).withFlags(Param )
315319
316- private def isJavaEnum (using Context ): Boolean = ctx.owner.linkedClass .derivesFrom(defn.JavaEnumClass )
320+ private def isJavaEnum (using Context ): Boolean = enumClass .derivesFrom(defn.JavaEnumClass )
317321
318322 def ordinalMeth (body : Tree )(using Context ): DefDef =
319323 DefDef (nme.ordinal, Nil , Nil , TypeTree (defn.IntType ), body)
0 commit comments