@@ -3595,14 +3595,17 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
35953595
35963596 private def pushDownDeferredEvidenceParams (tpe : Type , params : List [untpd.ValDef ], span : Span )(using Context ): Type = tpe.dealias match {
35973597 case tpe : MethodType =>
3598- MethodType (tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3598+ tpe.derivedLambdaType (tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span))
35993599 case tpe : PolyType =>
3600- PolyType (tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3600+ tpe.derivedLambdaType (tpe.paramNames, tpe.paramInfos, pushDownDeferredEvidenceParams(tpe.resultType, params, span))
36013601 case tpe : RefinedType =>
3602- // TODO(kπ): Doesn't seem right, but the PolyFunction ends up being a refinement
3603- RefinedType (pushDownDeferredEvidenceParams(tpe.parent, params, span), tpe.refinedName, pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span))
3602+ tpe.derivedRefinedType(
3603+ pushDownDeferredEvidenceParams(tpe.parent, params, span),
3604+ tpe.refinedName,
3605+ pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span)
3606+ )
36043607 case tpe @ AppliedType (tycon, args) if defn.isFunctionType(tpe) && args.size > 1 =>
3605- AppliedType ( tpe.tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
3608+ tpe.derivedAppliedType( tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
36063609 case tpe =>
36073610 val paramNames = params.map(_.name)
36083611 val paramTpts = params.map(_.tpt)
@@ -3611,18 +3614,52 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
36113614 typed(ctxFunction).tpe
36123615 }
36133616
3614- private def addDownDeferredEvidenceParams (tree : Tree , pt : Type )(using Context ): (Tree , Type ) = {
3617+ private def extractTopMethodTermParams (tpe : Type )(using Context ): (List [TermName ], List [Type ]) = tpe match {
3618+ case tpe : MethodType =>
3619+ tpe.paramNames -> tpe.paramInfos
3620+ case tpe : RefinedType if defn.isFunctionType(tpe.parent) =>
3621+ extractTopMethodTermParams(tpe.refinedInfo)
3622+ case _ =>
3623+ Nil -> Nil
3624+ }
3625+
3626+ private def removeTopMethodTermParams (tpe : Type )(using Context ): Type = tpe match {
3627+ case tpe : MethodType =>
3628+ tpe.resultType
3629+ case tpe : RefinedType if defn.isFunctionType(tpe.parent) =>
3630+ tpe.derivedRefinedType(tpe.parent, tpe.refinedName, removeTopMethodTermParams(tpe.refinedInfo))
3631+ case tpe : AppliedType if defn.isFunctionType(tpe) =>
3632+ tpe.args.last
3633+ case _ =>
3634+ tpe
3635+ }
3636+
3637+ private def healToPolyFunctionType (tree : Tree )(using Context ): Tree = tree match {
3638+ case defdef : DefDef if defdef.name == nme.apply && defdef.paramss.forall(_.forall(_.symbol.flags.is(TypeParam ))) && defdef.paramss.size == 1 =>
3639+ val (names, types) = extractTopMethodTermParams(defdef.tpt.tpe)
3640+ val newTpe = removeTopMethodTermParams(defdef.tpt.tpe)
3641+ val newParams = names.lazyZip(types).map((name, tpe) => SyntheticValDef (name, TypeTree (tpe), flags = SyntheticTermParam ))
3642+ val newDefDef = cpy.DefDef (defdef)(paramss = defdef.paramss ++ List (newParams), tpt = untpd.TypeTree (newTpe))
3643+ val nestedCtx = ctx.fresh.setNewTyperState()
3644+ typed(newDefDef)(using nestedCtx)
3645+ case _ => tree
3646+ }
3647+
3648+ private def addDeferredEvidenceParams (tree : Tree , pt : Type )(using Context ): (Tree , Type ) = {
36153649 tree.getAttachment(desugar.PolyFunctionApply ) match
36163650 case Some (params) if params.nonEmpty =>
36173651 tree.removeAttachment(desugar.PolyFunctionApply )
36183652 val tpe = pushDownDeferredEvidenceParams(tree.tpe, params, tree.span)
36193653 TypeTree (tpe).withSpan(tree.span) -> tpe
3654+ // case Some(params) if params.isEmpty =>
3655+ // println(s"tree: $tree")
3656+ // healToPolyFunctionType(tree) -> pt
36203657 case _ => tree -> pt
36213658 }
36223659
36233660 /** Interpolate and simplify the type of the given tree. */
36243661 protected def simplify (tree : Tree , pt : Type , locked : TypeVars )(using Context ): Tree =
3625- val (tree1, pt1) = addDownDeferredEvidenceParams (tree, pt)
3662+ val (tree1, pt1) = addDeferredEvidenceParams (tree, pt)
36263663 if ! tree1.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying
36273664 if ! tree1.tpe.widen.isInstanceOf [MethodOrPoly ] // wait with simplifying until method is fully applied
36283665 || tree1.isDef // ... unless tree is a definition
0 commit comments