@@ -1504,7 +1504,7 @@ object desugar {
15041504 .withSpan(original.span.withPoint(named.span.start))
15051505
15061506 /** Main desugaring method */
1507- def apply (tree : Tree )(using Context ): Tree = {
1507+ def apply (tree : Tree , pt : Type = NoType )(using Context ): Tree = {
15081508
15091509 /** Create tree for for-comprehension `<for (enums) do body>` or
15101510 * `<for (enums) yield body>` where mapName and flatMapName are chosen
@@ -1698,11 +1698,11 @@ object desugar {
16981698 }
16991699 }
17001700
1701- def makePolyFunction (targs : List [Tree ], body : Tree ): Tree = body match {
1701+ def makePolyFunction (targs : List [Tree ], body : Tree , pt : Type ): Tree = body match {
17021702 case Parens (body1) =>
1703- makePolyFunction(targs, body1)
1703+ makePolyFunction(targs, body1, pt )
17041704 case Block (Nil , body1) =>
1705- makePolyFunction(targs, body1)
1705+ makePolyFunction(targs, body1, pt )
17061706 case Function (vargs, res) =>
17071707 assert(targs.nonEmpty)
17081708 // TODO: Figure out if we need a `PolyFunctionWithMods` instead.
@@ -1726,12 +1726,26 @@ object desugar {
17261726 }
17271727 else {
17281728 // Desugar [T_1, ..., T_M] -> (x_1: P_1, ..., x_N: P_N) => body
1729- // Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N) = body }
1729+ // with pt [S_1, ..., S_M] -> (O_1, ..., O_N) => R
1730+ // Into new scala.PolyFunction { def apply[T_1, ..., T_M](x_1: P_1, ..., x_N: P_N): R2 = body }
1731+ // where R2 is R, with all references to S_1..S_M replaced with T1..T_M.
1732+
1733+ def typeTree (tp : Type ) = tp match
1734+ case RefinedType (parent, nme.apply, PolyType (_, mt)) if parent.typeSymbol eq defn.PolyFunctionClass =>
1735+ var bail = false
1736+ def mapper (tp : Type , topLevel : Boolean = false ): Tree = tp match
1737+ case tp : TypeRef => ref(tp)
1738+ case tp : TypeParamRef => Ident (applyTParams(tp.paramNum).name)
1739+ case AppliedType (tycon, args) => AppliedTypeTree (mapper(tycon), args.map(mapper(_)))
1740+ case _ => if topLevel then TypeTree () else { bail = true ; genericEmptyTree }
1741+ val mapped = mapper(mt.resultType, topLevel = true )
1742+ if bail then TypeTree () else mapped
1743+ case _ => TypeTree ()
17301744
17311745 val applyVParams = vargs.asInstanceOf [List [ValDef ]]
17321746 .map(varg => varg.withAddedFlags(mods.flags | Param ))
17331747 New (Template (emptyConstructor, List (polyFunctionTpt), Nil , EmptyValDef ,
1734- List (DefDef (nme.apply, applyTParams :: applyVParams :: Nil , TypeTree ( ), res))
1748+ List (DefDef (nme.apply, applyTParams :: applyVParams :: Nil , typeTree(pt ), res))
17351749 ))
17361750 }
17371751 case _ =>
@@ -1753,7 +1767,7 @@ object desugar {
17531767
17541768 val desugared = tree match {
17551769 case PolyFunction (targs, body) =>
1756- makePolyFunction(targs, body) orElse tree
1770+ makePolyFunction(targs, body, pt ) orElse tree
17571771 case SymbolLit (str) =>
17581772 Apply (
17591773 ref(defn.ScalaSymbolClass .companionModule.termRef),
0 commit comments