@@ -785,42 +785,55 @@ trait Checking {
785785 * @param enumCtx the context immediately enclosing the corresponding enum
786786 */
787787 private def checkEnumCaseRefsLegal (cdef : TypeDef , enumCtx : Context )(implicit ctx : Context ): Unit = {
788- def check (tree : Tree ) = {
789- // allow access to `sym` if a typedIdent just outside the enclosing enum
790- // would have produced the same symbol without errors
791- def allowAccess (name : Name , sym : Symbol ): Boolean = {
792- val testCtx = enumCtx.fresh.setNewTyperState()
793- val ref = ctx.typer.typedIdent(untpd.Ident (name), WildcardType )(testCtx)
794- ref.symbol == sym && ! testCtx.reporter.hasErrors
788+
789+ def checkCaseOrDefault (stat : Tree , caseCtx : Context ) = {
790+
791+ def check (tree : Tree ) = {
792+ // allow access to `sym` if a typedIdent just outside the enclosing enum
793+ // would have produced the same symbol without errors
794+ def allowAccess (name : Name , sym : Symbol ): Boolean = {
795+ val testCtx = caseCtx.fresh.setNewTyperState()
796+ val ref = ctx.typer.typedIdent(untpd.Ident (name), WildcardType )(testCtx)
797+ ref.symbol == sym && ! testCtx.reporter.hasErrors
798+ }
799+ checkRefsLegal(tree, cdef.symbol, allowAccess, " enum case" )
795800 }
796- checkRefsLegal(tree, cdef.symbol, allowAccess, " enum case" )
801+
802+ if (stat.symbol.is(Case ))
803+ stat match {
804+ case TypeDef (_, Template (DefDef (_, tparams, vparamss, _, _), parents, _, _)) =>
805+ tparams.foreach(check)
806+ vparamss.foreach(_.foreach(check))
807+ parents.foreach(check)
808+ case vdef : ValDef =>
809+ vdef.rhs match {
810+ case Block ((clsDef @ TypeDef (_, impl : Template )) :: Nil , _)
811+ if clsDef.symbol.isAnonymousClass =>
812+ impl.parents.foreach(check)
813+ case _ =>
814+ }
815+ case _ =>
816+ }
817+ else if (stat.symbol.is(Module ) && stat.symbol.linkedClass.is(Case ))
818+ stat match {
819+ case TypeDef (_, impl : Template ) =>
820+ for ((defaultGetter @
821+ DefDef (DefaultGetterName (nme.CONSTRUCTOR , _), _, _, _, _)) <- impl.body)
822+ check(defaultGetter.rhs)
823+ case _ =>
824+ }
797825 }
826+
798827 cdef.rhs match {
799828 case impl : Template =>
800- for (stat <- impl.body)
801- if (stat.symbol.is(Case ))
802- stat match {
803- case TypeDef (_, Template (DefDef (_, tparams, vparamss, _, _), parents, _, _)) =>
804- tparams.foreach(check)
805- vparamss.foreach(_.foreach(check))
806- parents.foreach(check)
807- case vdef : ValDef =>
808- vdef.rhs match {
809- case Block ((clsDef @ TypeDef (_, impl : Template )) :: Nil , _)
810- if clsDef.symbol.isAnonymousClass =>
811- impl.parents.foreach(check)
812- case _ =>
813- }
814- case _ =>
815- }
816- else if (stat.symbol.is(Module ) && stat.symbol.linkedClass.is(Case ))
817- stat match {
818- case TypeDef (_, impl : Template ) =>
819- for ((defaultGetter @
820- DefDef (DefaultGetterName (nme.CONSTRUCTOR , _), _, _, _, _)) <- impl.body)
821- check(defaultGetter.rhs)
822- case _ =>
823- }
829+ def isCase (stat : Tree ) = stat match {
830+ case _ : ValDef | _ : TypeDef => stat.symbol.is(Case )
831+ case _ => false
832+ }
833+ val cases = for (stat <- impl.body if isCase(stat)) yield untpd.Ident (stat.symbol.name)
834+ val caseImport : Import = Import (ref(cdef.symbol), cases)
835+ val caseCtx = enumCtx.importContext(caseImport, caseImport.symbol)
836+ for (stat <- impl.body) checkCaseOrDefault(stat, caseCtx)
824837 case _ =>
825838 }
826839 }
0 commit comments