@@ -33,11 +33,15 @@ object desugar {
3333// ----- DerivedTypeTrees -----------------------------------
3434
3535 class SetterParamTree extends DerivedTypeTree {
36- def derivedType (sym : Symbol )(implicit ctx : Context ) = sym.info.resultType
36+ def derivedTree (sym : Symbol )(implicit ctx : Context ) = tpd. TypeTree ( sym.info.resultType)
3737 }
3838
3939 class TypeRefTree extends DerivedTypeTree {
40- def derivedType (sym : Symbol )(implicit ctx : Context ) = sym.typeRef
40+ def derivedTree (sym : Symbol )(implicit ctx : Context ) = tpd.TypeTree (sym.typeRef)
41+ }
42+
43+ class TermRefTree extends DerivedTypeTree {
44+ def derivedTree (sym : Symbol )(implicit ctx : Context ) = tpd.ref(sym)
4145 }
4246
4347 /** A type tree that computes its type from an existing parameter.
@@ -73,7 +77,7 @@ object desugar {
7377 *
7478 * parameter name == reference name ++ suffix
7579 */
76- def derivedType (sym : Symbol )(implicit ctx : Context ) = {
80+ def derivedTree (sym : Symbol )(implicit ctx : Context ) = {
7781 val relocate = new TypeMap {
7882 val originalOwner = sym.owner
7983 def apply (tp : Type ) = tp match {
@@ -91,7 +95,7 @@ object desugar {
9195 mapOver(tp)
9296 }
9397 }
94- relocate(sym.info)
98+ tpd. TypeTree ( relocate(sym.info) )
9599 }
96100 }
97101
@@ -301,34 +305,56 @@ object desugar {
301305 val isCaseObject = mods.is(Case ) && mods.is(Module )
302306 val isImplicit = mods.is(Implicit )
303307 val isEnum = mods.hasMod[Mod .Enum ] && ! mods.is(Module )
304- val isEnumCase = isLegalEnumCase(cdef)
308+ val isEnumCase = mods.hasMod[ Mod . EnumCase ]
305309 val isValueClass = parents.nonEmpty && isAnyVal(parents.head)
306- // This is not watertight, but `extends AnyVal` will be replaced by `inline` later.
307-
310+ // This is not watertight, but `extends AnyVal` will be replaced by `inline` later.
308311
309312 val originalTparams = constr1.tparams
310313 val originalVparamss = constr1.vparamss
311- val constrTparams = originalTparams.map(toDefParam)
314+ lazy val derivedEnumParams = enumClass.typeParams.map(derivedTypeParam)
315+ val impliedTparams =
316+ if (isEnumCase && originalTparams.isEmpty)
317+ derivedEnumParams.map(tdef => tdef.withFlags(tdef.mods.flags | PrivateLocal ))
318+ else
319+ originalTparams
320+ val constrTparams = impliedTparams.map(toDefParam)
312321 val constrVparamss =
313322 if (originalVparamss.isEmpty) { // ensure parameter list is non-empty
314- if (isCaseClass) ctx.error(CaseClassMissingParamList (cdef), cdef.namePos)
323+ if (isCaseClass && originalTparams.isEmpty)
324+ ctx.error(CaseClassMissingParamList (cdef), cdef.namePos)
315325 ListOfNil
316326 }
317327 else originalVparamss.nestedMap(toDefParam)
318328 val constr = cpy.DefDef (constr1)(tparams = constrTparams, vparamss = constrVparamss)
319329
320- // Add constructor type parameters and evidence implicit parameters
321- // to auxiliary constructors
322- val normalizedBody = impl.body map {
323- case ddef : DefDef if ddef.name.isConstructorName =>
324- decompose(
325- defDef(
326- addEvidenceParams(
327- cpy.DefDef (ddef)(tparams = constrTparams),
328- evidenceParams(constr1).map(toDefParam))))
329- case stat =>
330- stat
330+ val (normalizedBody, enumCases, enumCompanionRef) = {
331+ // Add constructor type parameters and evidence implicit parameters
332+ // to auxiliary constructors; set defaultGetters as a side effect.
333+ def expandConstructor (tree : Tree ) = tree match {
334+ case ddef : DefDef if ddef.name.isConstructorName =>
335+ decompose(
336+ defDef(
337+ addEvidenceParams(
338+ cpy.DefDef (ddef)(tparams = constrTparams),
339+ evidenceParams(constr1).map(toDefParam))))
340+ case stat =>
341+ stat
342+ }
343+ // The Identifiers defined by a case
344+ def caseIds (tree : Tree ) = tree match {
345+ case tree : MemberDef => Ident (tree.name.toTermName) :: Nil
346+ case PatDef (_, ids, _, _) => ids
347+ }
348+ val stats = impl.body.map(expandConstructor)
349+ if (isEnum) {
350+ val (enumCases, enumStats) = stats.partition(DesugarEnums .isEnumCase)
351+ val enumCompanionRef = new TermRefTree ()
352+ val enumImport = Import (enumCompanionRef, enumCases.flatMap(caseIds))
353+ (enumImport :: enumStats, enumCases, enumCompanionRef)
354+ }
355+ else (stats, Nil , EmptyTree )
331356 }
357+
332358 def anyRef = ref(defn.AnyRefAlias .typeRef)
333359
334360 val derivedTparams = constrTparams.map(derivedTypeParam(_))
@@ -361,20 +387,16 @@ object desugar {
361387 val classTypeRef = appliedRef(classTycon)
362388
363389 // a reference to `enumClass`, with type parameters coming from the case constructor
364- lazy val enumClassTypeRef = enumClass.primaryConstructor.info match {
365- case info : PolyType =>
366- if (constrTparams.isEmpty)
367- interpolatedEnumParent(cdef.pos.startPos)
368- else if ((constrTparams.corresponds(info.paramNames))((param, name) => param.name == name))
369- appliedRef(enumClassRef)
370- else {
371- ctx.error(i " explicit extends clause needed because type parameters of case and enum class differ "
372- , cdef.pos.startPos)
373- appliedTypeTree(enumClassRef, constrTparams map (_ => anyRef))
374- }
375- case _ =>
390+ lazy val enumClassTypeRef =
391+ if (enumClass.typeParams.isEmpty)
376392 enumClassRef
377- }
393+ else if (originalTparams.isEmpty)
394+ appliedRef(enumClassRef)
395+ else {
396+ ctx.error(i " explicit extends clause needed because both enum case and enum class have type parameters "
397+ , cdef.pos.startPos)
398+ appliedTypeTree(enumClassRef, constrTparams map (_ => anyRef))
399+ }
378400
379401 // new C[Ts](paramss)
380402 lazy val creatorExpr = New (classTypeRef, constrVparamss nestedMap refOfDef)
@@ -428,6 +450,7 @@ object desugar {
428450 }
429451
430452 // Case classes and case objects get Product parents
453+ // Enum cases get an inferred parent if no parents are given
431454 var parents1 = parents
432455 if (isEnumCase && parents.isEmpty)
433456 parents1 = enumClassTypeRef :: Nil
@@ -473,7 +496,7 @@ object desugar {
473496 .withMods(companionMods | Synthetic ))
474497 .withPos(cdef.pos).toList
475498
476- val companionMeths = defaultGetters ::: eqInstances
499+ val companionMembers = defaultGetters ::: eqInstances ::: enumCases
477500
478501 // The companion object definitions, if a companion is needed, Nil otherwise.
479502 // companion definitions include:
@@ -486,18 +509,17 @@ object desugar {
486509 // For all other classes, the parent is AnyRef.
487510 val companions =
488511 if (isCaseClass) {
489- // The return type of the `apply` method
512+ // The return type of the `apply` method, and an (empty or singleton) list
513+ // of widening coercions
490514 val (applyResultTpt, widenDefs) =
491515 if (! isEnumCase)
492516 (TypeTree (), Nil )
493517 else if (parents.isEmpty || enumClass.typeParams.isEmpty)
494518 (enumClassTypeRef, Nil )
495- else {
496- val tparams = enumClass.typeParams.map(derivedTypeParam)
497- enumApplyResult(cdef, parents, tparams, appliedRef(enumClassRef, tparams))
498- }
519+ else
520+ enumApplyResult(cdef, parents, derivedEnumParams, appliedRef(enumClassRef, derivedEnumParams))
499521
500- val parent =
522+ val companionParent =
501523 if (constrTparams.nonEmpty ||
502524 constrVparamss.length > 1 ||
503525 mods.is(Abstract ) ||
@@ -519,10 +541,10 @@ object desugar {
519541 DefDef (nme.unapply, derivedTparams, (unapplyParam :: Nil ) :: Nil , TypeTree (), unapplyRHS)
520542 .withMods(synthetic)
521543 }
522- companionDefs(parent , applyMeths ::: unapplyMeth :: companionMeths )
544+ companionDefs(companionParent , applyMeths ::: unapplyMeth :: companionMembers )
523545 }
524- else if (companionMeths .nonEmpty)
525- companionDefs(anyRef, companionMeths )
546+ else if (companionMembers .nonEmpty)
547+ companionDefs(anyRef, companionMembers )
526548 else if (isValueClass) {
527549 constr0.vparamss match {
528550 case (_ :: Nil ) :: _ => companionDefs(anyRef, Nil )
@@ -531,6 +553,13 @@ object desugar {
531553 }
532554 else Nil
533555
556+ enumCompanionRef match {
557+ case ref : TermRefTree => // have the enum import watch the companion object
558+ val (modVal : ValDef ) :: _ = companions
559+ ref.watching(modVal)
560+ case _ =>
561+ }
562+
534563 // For an implicit class C[Ts](p11: T11, ..., p1N: T1N) ... (pM1: TM1, .., pMN: TMN), the method
535564 // synthetic implicit C[Ts](p11: T11, ..., p1N: T1N) ... (pM1: TM1, ..., pMN: TMN): C[Ts] =
536565 // new C[Ts](p11, ..., p1N) ... (pM1, ..., pMN) =
@@ -563,7 +592,7 @@ object desugar {
563592 }
564593
565594 val cdef1 = addEnumFlags {
566- val originalTparamsIt = originalTparams .toIterator
595+ val originalTparamsIt = impliedTparams .toIterator
567596 val originalVparamsIt = originalVparamss.toIterator.flatten
568597 val tparamAccessors = derivedTparams.map(_.withMods(originalTparamsIt.next().mods))
569598 val caseAccessor = if (isCaseClass) CaseAccessor else EmptyFlags
@@ -603,7 +632,7 @@ object desugar {
603632 val moduleName = checkNotReservedName(mdef).asTermName
604633 val impl = mdef.impl
605634 val mods = mdef.mods
606- lazy val isEnumCase = isLegalEnumCase(mdef)
635+ lazy val isEnumCase = mods.hasMod[ Mod . EnumCase ]
607636 if (mods is Package )
608637 PackageDef (Ident (moduleName), cpy.ModuleDef (mdef)(nme.PACKAGE , impl).withMods(mods &~ Package ) :: Nil )
609638 else if (isEnumCase)
@@ -650,7 +679,7 @@ object desugar {
650679 */
651680 def patDef (pdef : PatDef )(implicit ctx : Context ): Tree = flatTree {
652681 val PatDef (mods, pats, tpt, rhs) = pdef
653- if (mods.hasMod[Mod .EnumCase ] && enumCaseIsLegal(pdef) )
682+ if (mods.hasMod[Mod .EnumCase ])
654683 pats map {
655684 case id : Ident =>
656685 expandSimpleEnumCase(id.name.asTermName, mods,
0 commit comments