1- package dotty .tools .dotc
1+ package dotty .tools
2+ package dotc
23package transform
34
45import core ._
@@ -7,6 +8,7 @@ import MegaPhase._
78import SymUtils ._
89import NullOpsDecorator ._
910import ast .Trees ._
11+ import ast .untpd
1012import reporting ._
1113import dotty .tools .dotc .util .Spans .Span
1214
@@ -103,78 +105,73 @@ class ExpandSAMs extends MiniPhase:
103105 * ```
104106 */
105107 private def toPartialFunction (tree : Block , tpe : Type )(using Context ): Tree = {
106- /** An extractor for match, either contained in a block or standalone. */
107- object PartialFunctionRHS {
108- def unapply (tree : Tree ): Option [Match ] = tree match {
109- case Block (Nil , expr) => unapply(expr)
110- case m : Match => Some (m)
111- case _ => None
112- }
113- }
114-
115108 val closureDef(anon @ DefDef (_, List (List (param)), _, _)) = tree
116- anon.rhs match {
117- case PartialFunctionRHS (pf) =>
118- val anonSym = anon.symbol
119- val anonTpe = anon.tpe.widen
120- val parents = List (
121- defn.AbstractPartialFunctionClass .typeRef.appliedTo(anonTpe.firstParamTypes.head, anonTpe.resultType),
122- defn.SerializableType )
123- val pfSym = newNormalizedClassSymbol(anonSym.owner, tpnme.ANON_CLASS , Synthetic | Final , parents, coord = tree.span)
124-
125- def overrideSym (sym : Symbol ) = sym.copy(
126- owner = pfSym,
127- flags = Synthetic | Method | Final | Override ,
128- info = tpe.memberInfo(sym),
129- coord = tree.span).asTerm.entered
130- val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt )
131- val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse )
132-
133- def translateMatch (tree : Match , pfParam : Symbol , cases : List [CaseDef ], defaultValue : Tree )(using Context ) = {
134- val selector = tree.selector
135- val selectorTpe = selector.tpe.widen
136- val defaultSym = newSymbol(pfParam.owner, nme.WILDCARD , Synthetic | Case , selectorTpe)
137- val defaultCase =
138- CaseDef (
139- Bind (defaultSym, Underscore (selectorTpe)),
140- EmptyTree ,
141- defaultValue)
142- val unchecked = selector.annotated(New (ref(defn.UncheckedAnnot .typeRef)))
143- cpy.Match (tree)(unchecked, cases :+ defaultCase)
144- .subst(param.symbol :: Nil , pfParam :: Nil )
145- // Needed because a partial function can be written as:
146- // param => param match { case "foo" if foo(param) => param }
147- // And we need to update all references to 'param'
148- }
149-
150- def isDefinedAtRhs (paramRefss : List [List [Tree ]])(using Context ) = {
151- val tru = Literal (Constant (true ))
152- def translateCase (cdef : CaseDef ) =
153- cpy.CaseDef (cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn)
154- val paramRef = paramRefss.head.head
155- val defaultValue = Literal (Constant (false ))
156- translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
157- }
158-
159- def applyOrElseRhs (paramRefss : List [List [Tree ]])(using Context ) = {
160- val List (paramRef, defaultRef) = paramRefss(1 )
161- def translateCase (cdef : CaseDef ) =
162- cdef.changeOwner(anonSym, applyOrElseFn)
163- val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
164- translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
165- }
166-
167- val constr = newConstructor(pfSym, Synthetic , Nil , Nil ).entered
168- val isDefinedAtDef = transformFollowingDeep(DefDef (isDefinedAtFn, isDefinedAtRhs(_)(using ctx.withOwner(isDefinedAtFn))))
169- val applyOrElseDef = transformFollowingDeep(DefDef (applyOrElseFn, applyOrElseRhs(_)(using ctx.withOwner(applyOrElseFn))))
170- val pfDef = ClassDef (pfSym, DefDef (constr), List (isDefinedAtDef, applyOrElseDef))
171- cpy.Block (tree)(pfDef :: Nil , New (pfSym.typeRef, Nil ))
172109
110+ // The right hand side from which to construct the partial function. This is always a Match.
111+ // If the original rhs is already a Match (possibly in braces), return that.
112+ // Otherwise construct a match `x match case _ => rhs` where `x` is the parameter of the closure.
113+ def partialFunRHS (tree : Tree ): Match = tree match
114+ case m : Match => m
115+ case Block (Nil , expr) => partialFunRHS(expr)
173116 case _ =>
174- val found = tpe.baseType(defn.Function1 )
175- report.error(TypeMismatch (found, tpe), tree.srcPos)
176- tree
117+ Match (ref(param.symbol),
118+ CaseDef (untpd.Ident (nme.WILDCARD ).withType(param.symbol.info), EmptyTree , tree) :: Nil )
119+
120+ val pfRHS = partialFunRHS(anon.rhs)
121+ val anonSym = anon.symbol
122+ val anonTpe = anon.tpe.widen
123+ val parents = List (
124+ defn.AbstractPartialFunctionClass .typeRef.appliedTo(anonTpe.firstParamTypes.head, anonTpe.resultType),
125+ defn.SerializableType )
126+ val pfSym = newNormalizedClassSymbol(anonSym.owner, tpnme.ANON_CLASS , Synthetic | Final , parents, coord = tree.span)
127+
128+ def overrideSym (sym : Symbol ) = sym.copy(
129+ owner = pfSym,
130+ flags = Synthetic | Method | Final | Override ,
131+ info = tpe.memberInfo(sym),
132+ coord = tree.span).asTerm.entered
133+ val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt )
134+ val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse )
135+
136+ def translateMatch (tree : Match , pfParam : Symbol , cases : List [CaseDef ], defaultValue : Tree )(using Context ) = {
137+ val selector = tree.selector
138+ val selectorTpe = selector.tpe.widen
139+ val defaultSym = newSymbol(pfParam.owner, nme.WILDCARD , Synthetic | Case , selectorTpe)
140+ val defaultCase =
141+ CaseDef (
142+ Bind (defaultSym, Underscore (selectorTpe)),
143+ EmptyTree ,
144+ defaultValue)
145+ val unchecked = selector.annotated(New (ref(defn.UncheckedAnnot .typeRef)))
146+ cpy.Match (tree)(unchecked, cases :+ defaultCase)
147+ .subst(param.symbol :: Nil , pfParam :: Nil )
148+ // Needed because a partial function can be written as:
149+ // param => param match { case "foo" if foo(param) => param }
150+ // And we need to update all references to 'param'
151+ }
152+
153+ def isDefinedAtRhs (paramRefss : List [List [Tree ]])(using Context ) = {
154+ val tru = Literal (Constant (true ))
155+ def translateCase (cdef : CaseDef ) =
156+ cpy.CaseDef (cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn)
157+ val paramRef = paramRefss.head.head
158+ val defaultValue = Literal (Constant (false ))
159+ translateMatch(pfRHS, paramRef.symbol, pfRHS.cases.map(translateCase), defaultValue)
160+ }
161+
162+ def applyOrElseRhs (paramRefss : List [List [Tree ]])(using Context ) = {
163+ val List (paramRef, defaultRef) = paramRefss(1 )
164+ def translateCase (cdef : CaseDef ) =
165+ cdef.changeOwner(anonSym, applyOrElseFn)
166+ val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
167+ translateMatch(pfRHS, paramRef.symbol, pfRHS.cases.map(translateCase), defaultValue)
177168 }
169+
170+ val constr = newConstructor(pfSym, Synthetic , Nil , Nil ).entered
171+ val isDefinedAtDef = transformFollowingDeep(DefDef (isDefinedAtFn, isDefinedAtRhs(_)(using ctx.withOwner(isDefinedAtFn))))
172+ val applyOrElseDef = transformFollowingDeep(DefDef (applyOrElseFn, applyOrElseRhs(_)(using ctx.withOwner(applyOrElseFn))))
173+ val pfDef = ClassDef (pfSym, DefDef (constr), List (isDefinedAtDef, applyOrElseDef))
174+ cpy.Block (tree)(pfDef :: Nil , New (pfSym.typeRef, Nil ))
178175 }
179176
180177 private def checkRefinements (tpe : Type , tree : Tree )(using Context ): Type = tpe.dealias match {
0 commit comments