@@ -1049,15 +1049,35 @@ class Typer extends Namer
10491049 */
10501050 var paramIndex = Map [Name , Int ]()
10511051
1052- /** If function is of the form
1052+ /** Infer parameter type from the body of the function
1053+ *
1054+ * 1. If function is of the form
1055+ *
10531056 * (x1, ..., xN) => f(... x1, ..., XN, ...)
1057+ *
10541058 * where each `xi` occurs exactly once in the argument list of `f` (in
10551059 * any order), the type of `f`, otherwise NoType.
1060+ *
1061+ * 2. If the function is of the form
1062+ *
1063+ * (using x1, ..., xN) => f
1064+ *
1065+ * where `f` is a contextual function type of the form `(T1, ..., TN) ?=> T`,
1066+ * then `xi` takes the type `Ti`.
1067+ *
10561068 * Updates `fnBody` and `paramIndex` as a side effect.
10571069 * @post: If result exists, `paramIndex` is defined for the name of
10581070 * every parameter in `params`.
10591071 */
1060- lazy val calleeType : Type = fnBody match {
1072+ lazy val calleeType : Type = untpd.stripAnnotated(fnBody) match {
1073+ case ident : untpd.Ident if isContextual =>
1074+ val ident1 = typedIdent(ident, WildcardType )
1075+ val tp = ident1.tpe.widen
1076+ if defn.isContextFunctionType(tp) && params.size == defn.functionArity(tp) then
1077+ paramIndex = params.map(_.name).zipWithIndex.toMap
1078+ fnBody = untpd.TypedSplice (ident1)
1079+ tp.select(nme.apply)
1080+ else NoType
10611081 case app @ Apply (expr, args) =>
10621082 paramIndex = {
10631083 for (param <- params; idx <- paramIndices(param, args))
@@ -2450,7 +2470,34 @@ class Typer extends Namer
24502470
24512471 protected def makeContextualFunction (tree : untpd.Tree , pt : Type )(using Context ): Tree = {
24522472 val defn .FunctionOf (formals, _, true , _) = pt.dropDependentRefinement
2453- val ifun = desugar.makeContextualFunction(formals, tree, defn.isErasedFunctionType(pt))
2473+
2474+ // The getter of default parameters may reach here.
2475+ // Given the code below
2476+ //
2477+ // class Foo[A](run: A ?=> Int) {
2478+ // def foo[T](f: T ?=> Int = run) = ()
2479+ // }
2480+ //
2481+ // it desugars to
2482+ //
2483+ // class Foo[A](run: A ?=> Int) {
2484+ // def foo$default$1[T] = run
2485+ // def foo[T](f: T ?=> Int = run) = ()
2486+ // }
2487+ //
2488+ // The expected type for checking `run` in `foo$default$1` is
2489+ //
2490+ // <?> ?=> Int
2491+ //
2492+ // see tests/pos/i7778b.scala
2493+
2494+ val paramTypes = {
2495+ val hasWildcard = formals.exists(_.isInstanceOf [WildcardType ])
2496+ if hasWildcard then formals.map(_ => untpd.TypeTree ())
2497+ else formals.map(untpd.TypeTree )
2498+ }
2499+
2500+ val ifun = desugar.makeContextualFunction(paramTypes, tree, defn.isErasedFunctionType(pt))
24542501 typr.println(i " make contextual function $tree / $pt ---> $ifun" )
24552502 typed(ifun, pt)
24562503 }
0 commit comments