@@ -946,7 +946,7 @@ class Typer extends Namer
946946 * def double(x: Char): String = s"$x$x"
947947 * "abc" flatMap double
948948 */
949- private def decomposeProtoFunction (pt : Type , defaultArity : Int )(using Context ): (List [Type ], untpd.Tree ) = {
949+ private def decomposeProtoFunction (pt : Type , defaultArity : Int , tree : untpd. Tree )(using Context ): (List [Type ], untpd.Tree ) = {
950950 def typeTree (tp : Type ) = tp match {
951951 case _ : WildcardType => untpd.TypeTree ()
952952 case _ => untpd.TypeTree (tp)
@@ -957,7 +957,15 @@ class Typer extends Namer
957957 newTypeVar(apply(bounds.orElse(TypeBounds .empty)).bounds)
958958 case _ => mapOver(t)
959959 }
960- pt.stripTypeVar.dealias match {
960+ val pt1 = pt.stripTypeVar.dealias
961+ if (pt1 ne pt1.dropDependentRefinement)
962+ && defn.isContextFunctionType(pt1.nonPrivateMember(nme.apply).info.finalResultType)
963+ then
964+ ctx.error(
965+ i """ Implementation restriction: Expected result type $pt1
966+ |is a curried dependent context function type. Such types are not yet supported. """ ,
967+ tree.sourcePos)
968+ pt1 match {
961969 case pt1 if defn.isNonRefinedFunction(pt1) =>
962970 // if expected parameter type(s) are wildcards, approximate from below.
963971 // if expected result type is a wildcard, approximate from above.
@@ -970,7 +978,7 @@ class Typer extends Namer
970978 else
971979 typeTree(restpe))
972980 case tp : TypeParamRef =>
973- decomposeProtoFunction(ctx.typerState.constraint.entry(tp).bounds.hi, defaultArity)
981+ decomposeProtoFunction(ctx.typerState.constraint.entry(tp).bounds.hi, defaultArity, tree )
974982 case _ =>
975983 (List .tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree ())
976984 }
@@ -1131,7 +1139,7 @@ class Typer extends Namer
11311139 case _ =>
11321140 }
11331141
1134- val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length)
1142+ val (protoFormals, resultTpt) = decomposeProtoFunction(pt, params.length, tree )
11351143
11361144 /** The inferred parameter type for a parameter in a lambda that does
11371145 * not have an explicit type given.
@@ -1261,7 +1269,7 @@ class Typer extends Namer
12611269 typedMatchFinish(tree, tpd.EmptyTree , defn.ImplicitScrutineeTypeRef , cases1, pt)
12621270 }
12631271 else {
1264- val (protoFormals, _) = decomposeProtoFunction(pt, 1 )
1272+ val (protoFormals, _) = decomposeProtoFunction(pt, 1 , tree )
12651273 val checkMode =
12661274 if (pt.isRef(defn.PartialFunctionClass )) desugar.MatchCheck .None
12671275 else desugar.MatchCheck .Exhaustive
@@ -1447,17 +1455,40 @@ class Typer extends Namer
14471455 }
14481456
14491457 def typedReturn (tree : untpd.Return )(using Context ): Return = {
1458+
1459+ /** If `pt` is a context function type, its return type. If the CFT
1460+ * is dependent, instantiate with the parameters of the associated
1461+ * anonymous function.
1462+ * @param paramss the parameters of the anonymous functions
1463+ * enclosing the return expression
1464+ */
1465+ def instantiateCFT (pt : Type , paramss : => List [List [Symbol ]]): Type =
1466+ val ift = defn.asContextFunctionType(pt)
1467+ if ift.exists then
1468+ ift.nonPrivateMember(nme.apply).info match
1469+ case appType : MethodType =>
1470+ instantiateCFT(appType.instantiate(paramss.head.map(_.termRef)), paramss.tail)
1471+ else pt
1472+
14501473 def returnProto (owner : Symbol , locals : Scope ): Type =
14511474 if (owner.isConstructor) defn.UnitType
1452- else owner.info match {
1453- case info : PolyType =>
1454- val tparams = locals.toList.takeWhile(_ is TypeParam )
1455- assert(info.paramNames.length == tparams.length,
1456- i " return mismatch from $owner, tparams = $tparams, locals = ${locals.toList}%, % " )
1457- info.instantiate(tparams.map(_.typeRef)).finalResultType
1458- case info =>
1459- info.finalResultType
1460- }
1475+ else
1476+ val rt = owner.info match
1477+ case info : PolyType =>
1478+ val tparams = locals.toList.takeWhile(_ is TypeParam )
1479+ assert(info.paramNames.length == tparams.length,
1480+ i " return mismatch from $owner, tparams = $tparams, locals = ${locals.toList}%, % " )
1481+ info.instantiate(tparams.map(_.typeRef)).finalResultType
1482+ case info =>
1483+ info.finalResultType
1484+ def iftParamss = ctx.owner.ownersIterator
1485+ .filter(_.is(Method , butNot = Accessor ))
1486+ .takeWhile(_.isAnonymousFunction)
1487+ .toList
1488+ .reverse
1489+ .map(_.paramSymss.head)
1490+ instantiateCFT(rt, iftParamss)
1491+
14611492 def enclMethInfo (cx : Context ): (Tree , Type ) = {
14621493 val owner = cx.owner
14631494 if (owner.isType) {
@@ -3155,7 +3186,7 @@ class Typer extends Namer
31553186
31563187 def isContextFunctionRef (wtp : Type ): Boolean = wtp match {
31573188 case RefinedType (parent, nme.apply, _) =>
3158- isContextFunctionRef(parent) // apply refinements indicate a dependent IFT
3189+ isContextFunctionRef(parent) // apply refinements indicate a dependent CFT
31593190 case _ =>
31603191 val underlying = wtp.underlyingClassRef(refinementOK = false ) // other refinements are not OK
31613192 defn.isContextFunctionClass(underlying.classSymbol)
0 commit comments