1- package dotty .tools .dotc
1+ package dotty .tools
2+ package dotc
23package transform
34
45import core ._
@@ -7,6 +8,7 @@ import MegaPhase._
78import SymUtils ._
89import NullOpsDecorator ._
910import ast .Trees ._
11+ import ast .untpd
1012import reporting ._
1113import dotty .tools .dotc .util .Spans .Span
1214
@@ -113,68 +115,72 @@ class ExpandSAMs extends MiniPhase:
113115 }
114116
115117 val closureDef(anon @ DefDef (_, List (List (param)), _, _)) = tree
116- anon.rhs match {
117- case PartialFunctionRHS (pf) =>
118- val anonSym = anon.symbol
119- val anonTpe = anon.tpe.widen
120- val parents = List (
121- defn.AbstractPartialFunctionClass .typeRef.appliedTo(anonTpe.firstParamTypes.head, anonTpe.resultType),
122- defn.SerializableType )
123- val pfSym = newNormalizedClassSymbol(anonSym.owner, tpnme.ANON_CLASS , Synthetic | Final , parents, coord = tree.span)
124-
125- def overrideSym (sym : Symbol ) = sym.copy(
126- owner = pfSym,
127- flags = Synthetic | Method | Final | Override ,
128- info = tpe.memberInfo(sym),
129- coord = tree.span).asTerm.entered
130- val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt )
131- val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse )
132-
133- def translateMatch (tree : Match , pfParam : Symbol , cases : List [CaseDef ], defaultValue : Tree )(using Context ) = {
134- val selector = tree.selector
135- val selectorTpe = selector.tpe.widen
136- val defaultSym = newSymbol(pfParam.owner, nme.WILDCARD , Synthetic | Case , selectorTpe)
137- val defaultCase =
138- CaseDef (
139- Bind (defaultSym, Underscore (selectorTpe)),
140- EmptyTree ,
141- defaultValue)
142- val unchecked = selector.annotated(New (ref(defn.UncheckedAnnot .typeRef)))
143- cpy.Match (tree)(unchecked, cases :+ defaultCase)
144- .subst(param.symbol :: Nil , pfParam :: Nil )
145- // Needed because a partial function can be written as:
146- // param => param match { case "foo" if foo(param) => param }
147- // And we need to update all references to 'param'
148- }
149-
150- def isDefinedAtRhs (paramRefss : List [List [Tree ]])(using Context ) = {
151- val tru = Literal (Constant (true ))
152- def translateCase (cdef : CaseDef ) =
153- cpy.CaseDef (cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn)
154- val paramRef = paramRefss.head.head
155- val defaultValue = Literal (Constant (false ))
156- translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
157- }
158-
159- def applyOrElseRhs (paramRefss : List [List [Tree ]])(using Context ) = {
160- val List (paramRef, defaultRef) = paramRefss(1 )
161- def translateCase (cdef : CaseDef ) =
162- cdef.changeOwner(anonSym, applyOrElseFn)
163- val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
164- translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
165- }
166-
167- val constr = newConstructor(pfSym, Synthetic , Nil , Nil ).entered
168- val isDefinedAtDef = transformFollowingDeep(DefDef (isDefinedAtFn, isDefinedAtRhs(_)(using ctx.withOwner(isDefinedAtFn))))
169- val applyOrElseDef = transformFollowingDeep(DefDef (applyOrElseFn, applyOrElseRhs(_)(using ctx.withOwner(applyOrElseFn))))
170- val pfDef = ClassDef (pfSym, DefDef (constr), List (isDefinedAtDef, applyOrElseDef))
171- cpy.Block (tree)(pfDef :: Nil , New (pfSym.typeRef, Nil ))
172-
118+
119+ // The right hand side from which to construct the partial function. This is always a Match.
120+ // If the original rhs is already a Match (possibly in braces), return that.
121+ // Otherwise construct a match `x match case _ => rhs` where `x` is the parameter of the closure.
122+ def partialFunRHS (tree : Tree ): Match = tree match
123+ case m : Match => m
124+ case Block (Nil , expr) => partialFunRHS(expr)
173125 case _ =>
174- val found = tpe.baseType(defn.Function1 )
175- report.error(TypeMismatch (found, tpe), tree.srcPos)
176- tree
126+ Match (ref(param.symbol),
127+ CaseDef (untpd.Ident (nme.WILDCARD ).withType(param.symbol.info), EmptyTree , tree) :: Nil )
128+
129+ val pfRHS = partialFunRHS(anon.rhs)
130+ val anonSym = anon.symbol
131+ val anonTpe = anon.tpe.widen
132+ val parents = List (
133+ defn.AbstractPartialFunctionClass .typeRef.appliedTo(anonTpe.firstParamTypes.head, anonTpe.resultType),
134+ defn.SerializableType )
135+ val pfSym = newNormalizedClassSymbol(anonSym.owner, tpnme.ANON_CLASS , Synthetic | Final , parents, coord = tree.span)
136+
137+ def overrideSym (sym : Symbol ) = sym.copy(
138+ owner = pfSym,
139+ flags = Synthetic | Method | Final | Override ,
140+ info = tpe.memberInfo(sym),
141+ coord = tree.span).asTerm.entered
142+ val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt )
143+ val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse )
144+
145+ def translateMatch (tree : Match , pfParam : Symbol , cases : List [CaseDef ], defaultValue : Tree )(using Context ) = {
146+ val selector = tree.selector
147+ val selectorTpe = selector.tpe.widen
148+ val defaultSym = newSymbol(pfParam.owner, nme.WILDCARD , Synthetic | Case , selectorTpe)
149+ val defaultCase =
150+ CaseDef (
151+ Bind (defaultSym, Underscore (selectorTpe)),
152+ EmptyTree ,
153+ defaultValue)
154+ val unchecked = selector.annotated(New (ref(defn.UncheckedAnnot .typeRef)))
155+ cpy.Match (tree)(unchecked, cases :+ defaultCase)
156+ .subst(param.symbol :: Nil , pfParam :: Nil )
157+ // Needed because a partial function can be written as:
158+ // param => param match { case "foo" if foo(param) => param }
159+ // And we need to update all references to 'param'
160+ }
161+
162+ def isDefinedAtRhs (paramRefss : List [List [Tree ]])(using Context ) = {
163+ val tru = Literal (Constant (true ))
164+ def translateCase (cdef : CaseDef ) =
165+ cpy.CaseDef (cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn)
166+ val paramRef = paramRefss.head.head
167+ val defaultValue = Literal (Constant (false ))
168+ translateMatch(pfRHS, paramRef.symbol, pfRHS.cases.map(translateCase), defaultValue)
177169 }
170+
171+ def applyOrElseRhs (paramRefss : List [List [Tree ]])(using Context ) = {
172+ val List (paramRef, defaultRef) = paramRefss(1 )
173+ def translateCase (cdef : CaseDef ) =
174+ cdef.changeOwner(anonSym, applyOrElseFn)
175+ val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
176+ translateMatch(pfRHS, paramRef.symbol, pfRHS.cases.map(translateCase), defaultValue)
177+ }
178+
179+ val constr = newConstructor(pfSym, Synthetic , Nil , Nil ).entered
180+ val isDefinedAtDef = transformFollowingDeep(DefDef (isDefinedAtFn, isDefinedAtRhs(_)(using ctx.withOwner(isDefinedAtFn))))
181+ val applyOrElseDef = transformFollowingDeep(DefDef (applyOrElseFn, applyOrElseRhs(_)(using ctx.withOwner(applyOrElseFn))))
182+ val pfDef = ClassDef (pfSym, DefDef (constr), List (isDefinedAtDef, applyOrElseDef))
183+ cpy.Block (tree)(pfDef :: Nil , New (pfSym.typeRef, Nil ))
178184 }
179185
180186 private def checkRefinements (tpe : Type , tree : Tree )(using Context ): Type = tpe.dealias match {
0 commit comments