@@ -15,6 +15,8 @@ import DenotTransformers._
1515import dotty .tools .dotc .ast .Trees ._
1616import SymUtils ._
1717
18+ import annotation .threadUnsafe
19+
1820object CompleteJavaEnums {
1921 val name : String = " completeJavaEnums"
2022
@@ -62,9 +64,10 @@ class CompleteJavaEnums extends MiniPhase with InfoTransformer { thisPhase =>
6264 /** The list of parameter definitions `$name: String, $ordinal: Int`, in given `owner`
6365 * with given flags (either `Param` or `ParamAccessor`)
6466 */
65- private def addedParams (owner : Symbol , flag : FlagSet )(using Context ): List [ValDef ] = {
66- val nameParam = newSymbol(owner, nameParamName, flag | Synthetic , defn.StringType , coord = owner.span)
67- val ordinalParam = newSymbol(owner, ordinalParamName, flag | Synthetic , defn.IntType , coord = owner.span)
67+ private def addedParams (owner : Symbol , isLocal : Boolean , flag : FlagSet )(using Context ): List [ValDef ] = {
68+ val flags = flag | Synthetic | (if isLocal then Private | Deferred else EmptyFlags )
69+ val nameParam = newSymbol(owner, nameParamName, flags, defn.StringType , coord = owner.span)
70+ val ordinalParam = newSymbol(owner, ordinalParamName, flags, defn.IntType , coord = owner.span)
6871 List (ValDef (nameParam), ValDef (ordinalParam))
6972 }
7073
@@ -85,7 +88,7 @@ class CompleteJavaEnums extends MiniPhase with InfoTransformer { thisPhase =>
8588 val sym = tree.symbol
8689 if (sym.isConstructor && sym.owner.derivesFromJavaEnum)
8790 val tree1 = cpy.DefDef (tree)(
88- vparamss = tree.vparamss.init :+ (tree.vparamss.last ++ addedParams(sym, Param )))
91+ vparamss = tree.vparamss.init :+ (tree.vparamss.last ++ addedParams(sym, isLocal = false , Param )))
8992 sym.setParamssFromDefs(tree1.tparams, tree1.vparamss)
9093 tree1
9194 else tree
@@ -107,47 +110,68 @@ class CompleteJavaEnums extends MiniPhase with InfoTransformer { thisPhase =>
107110 }
108111 }
109112
113+ private def isJavaEnumValueImpl (cls : Symbol )(using Context ): Boolean =
114+ cls.isAnonymousClass
115+ && (((cls.owner.name eq nme.DOLLAR_NEW ) && cls.owner.isAllOf(Private | Synthetic )) || cls.owner.isAllOf(EnumCase ))
116+ && cls.owner.owner.linkedClass.derivesFromJavaEnum
117+
118+ private val enumCaseOrdinals : MutableSymbolMap [Int ] = newMutableSymbolMap
119+
120+ private def registerEnumClass (cls : Symbol )(using Context ): Unit =
121+ cls.children.zipWithIndex.foreach(enumCaseOrdinals.put)
122+
123+ private def ordinalFor (enumCase : Symbol ): Int =
124+ enumCaseOrdinals.remove(enumCase).get
125+
110126 /** 1. If this is an enum class, add $name and $ordinal parameters to its
111127 * parameter accessors and pass them on to the java.lang.Enum constructor.
112128 *
113- * 2. If this is an anonymous class that implement a value enum case,
129+ * 2. If this is an anonymous class that implement a singleton enum case,
114130 * pass $name and $ordinal parameters to the enum superclass. The class
115131 * looks like this:
116132 *
117133 * class $anon extends E(...) {
118134 * ...
119- * def ordinal = N
120- * def toString = S
121- * ...
122135 * }
123136 *
124137 * After the transform it is expanded to
125138 *
126- * class $anon extends E(..., N, S) {
127- * "same as before"
139+ * class $anon extends E(..., $name, _$ordinal) { // if class implements a simple enum case
140+ * "same as before"
141+ * }
142+ *
143+ * class $anon extends E(..., "A", 0) { // if class implements a value enum case `A` with ordinal 0
144+ * "same as before"
128145 * }
129146 */
130- override def transformTemplate (templ : Template )(using Context ): Template = {
147+ override def transformTemplate (templ : Template )(using Context ): Tree = {
131148 val cls = templ.symbol.owner
132- if (cls.derivesFromJavaEnum) {
149+ if cls.derivesFromJavaEnum then
150+ registerEnumClass(cls) // invariant: class is visited before cases: see tests/pos/enum-companion-first.scala
133151 val (params, rest) = decomposeTemplateBody(templ.body)
134- val addedDefs = addedParams(cls, ParamAccessor )
152+ val addedDefs = addedParams(cls, isLocal = true , ParamAccessor )
135153 val addedSyms = addedDefs.map(_.symbol.entered)
136154 val addedForwarders = addedEnumForwarders(cls)
137155 cpy.Template (templ)(
138156 parents = addEnumConstrArgs(defn.JavaEnumClass , templ.parents, addedSyms.map(ref)),
139157 body = params ++ addedDefs ++ addedForwarders ++ rest)
140- }
141- else if (cls.isAnonymousClass && ((cls.owner. name eq nme. DOLLAR_NEW ) || cls.owner.isAllOf( EnumCase )) &&
142- cls.owner.owner.linkedClass.derivesFromJavaEnum) {
143- def rhsOf ( name : TermName ) =
144- templ.body.collect {
145- case mdef : DefDef if mdef. name == name => mdef.rhs
146- }.head
147- val args = List (rhsOf (nme.toString_ ), rhsOf (nme.ordinalDollar ))
158+ else if isJavaEnumValueImpl(cls) then
159+ def creatorParamRef ( name : TermName ) =
160+ ref( cls.owner.paramSymss.head.find(_.name == name).get)
161+ val args =
162+ if cls.owner.isAllOf( EnumCase ) then
163+ List ( Literal ( Constant (cls.owner. name.toString)), Literal ( Constant (ordinalFor(cls.owner))))
164+ else
165+ List (creatorParamRef (nme.nameDollar ), creatorParamRef (nme.ordinalDollar_ ))
148166 cpy.Template (templ)(
149- parents = addEnumConstrArgs(cls.owner.owner.linkedClass, templ.parents, args))
150- }
167+ parents = addEnumConstrArgs(cls.owner.owner.linkedClass, templ.parents, args),
168+ )
169+ else if cls.linkedClass.derivesFromJavaEnum then
170+ enumCaseOrdinals.clear() // remove simple cases // invariant: companion is visited after cases
171+ templ
151172 else templ
152173 }
174+
175+ override def checkPostCondition (tree : Tree )(using Context ): Unit =
176+ assert(enumCaseOrdinals.isEmpty, " Java based enum ordinal cache was not cleared" )
153177}
0 commit comments