Skip to content

Commit 34827b1

Browse files
KacperFKorbanWojciechMazur
authored andcommitted
Correctly-ish desugar poly function context bounds in function types
[Cherry-picked 3ac5cec]
1 parent d97ddd6 commit 34827b1

File tree

3 files changed

+30
-21
lines changed

3 files changed

+30
-21
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1226,31 +1226,36 @@ object desugar {
12261226
*/
12271227
def expandPolyFunctionContextBounds(tree: PolyFunction)(using Context): PolyFunction =
12281228
val PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun @ Function(vparamTypes, res)) = tree: @unchecked
1229-
val newTParams = tparams.map {
1229+
val newTParams = tparams.mapConserve {
12301230
case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) =>
12311231
TypeDef(name, ContextBounds(bounds, List.empty))
1232+
case t => t
12321233
}
12331234
var idx = 0
1234-
val collecedContextBounds = tparams.collect {
1235+
val collectedContextBounds = tparams.collect {
12351236
case td @ TypeDef(name, cb @ ContextBounds(bounds, ctxBounds)) if ctxBounds.nonEmpty =>
1236-
// TOOD(kπ) Should we handle non empty normal bounds here?
12371237
name -> ctxBounds
12381238
}.flatMap { case (name, ctxBounds) =>
12391239
ctxBounds.map { ctxBound =>
12401240
idx = idx + 1
12411241
ctxBound match
1242-
case ContextBoundTypeTree(_, _, ownName) =>
1243-
ValDef(ownName, ctxBound, EmptyTree).withFlags(TermParam | Given)
1242+
case ctxBound @ ContextBoundTypeTree(tycon, paramName, ownName) =>
1243+
if tree.isTerm then
1244+
ValDef(ownName, ctxBound, EmptyTree).withFlags(TermParam | Given)
1245+
else
1246+
ContextBoundTypeTree(tycon, paramName, EmptyTermName) // this has to be handled in Typer#typedFunctionType
12441247
case _ =>
12451248
makeSyntheticParameter(idx, ctxBound).withAddedFlags(Given)
12461249
}
12471250
}
12481251
val contextFunctionResult =
1249-
if collecedContextBounds.isEmpty then
1250-
fun
1252+
if collectedContextBounds.isEmpty then fun
12511253
else
1252-
Function(vparamTypes, Function(collecedContextBounds, res)).withSpan(fun.span)
1253-
PolyFunction(newTParams, contextFunctionResult).withSpan(tree.span)
1254+
val mods = EmptyModifiers.withFlags(Given)
1255+
val erasedParams = collectedContextBounds.map(_ => false)
1256+
Function(vparamTypes, FunctionWithMods(collectedContextBounds, res, mods, erasedParams)).withSpan(fun.span)
1257+
if collectedContextBounds.isEmpty then tree
1258+
else PolyFunction(newTParams, contextFunctionResult).withSpan(tree.span)
12541259

12551260
/** Desugar [T_1, ..., T_M] => (P_1, ..., P_N) => R
12561261
* Into scala.PolyFunction { def apply[T_1, ..., T_M](x$1: P_1, ..., x$N: P_N): R }

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ import annotation.tailrec
4040
import Implicits.*
4141
import util.Stats.record
4242
import config.Printers.{gadts, typr}
43-
import config.Feature, Feature.{migrateTo3, modularity, sourceVersion, warnOnMigration}
43+
import config.Feature, Feature.{migrateTo3, sourceVersion, warnOnMigration}
4444
import config.SourceVersion.*
4545
import rewrites.Rewrites, Rewrites.patch
4646
import staging.StagingLevel
@@ -1142,7 +1142,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
11421142
if templ1.parents.isEmpty
11431143
&& isFullyDefined(pt, ForceDegree.flipBottom)
11441144
&& isSkolemFree(pt)
1145-
&& isEligible(pt.underlyingClassRef(refinementOK = Feature.enabled(modularity)))
1145+
&& isEligible(pt.underlyingClassRef(refinementOK = Feature.enabled(Feature.modularity)))
11461146
then
11471147
templ1 = cpy.Template(templ)(parents = untpd.TypeTree(pt) :: Nil)
11481148
for case parent: RefTree <- templ1.parents do
@@ -1717,7 +1717,11 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
17171717
typedFunctionType(desugar.makeFunctionWithValDefs(tree, pt), pt)
17181718
else
17191719
val funSym = defn.FunctionSymbol(numArgs, isContextual, isImpure)
1720-
val result = typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funSym.typeRef), args :+ body), pt)
1720+
val args1 = args.mapConserve {
1721+
case cb: untpd.ContextBoundTypeTree => typed(cb)
1722+
case t => t
1723+
}
1724+
val result = typed(cpy.AppliedTypeTree(tree)(untpd.TypeTree(funSym.typeRef), args1 :+ body), pt)
17211725
// if there are any erased classes, we need to re-do the typecheck.
17221726
result match
17231727
case r: AppliedTypeTree if r.args.exists(_.tpe.isErasedClass) =>
@@ -2448,12 +2452,12 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
24482452
if tycon.tpe.typeParams.nonEmpty then
24492453
val tycon0 = tycon.withType(tycon.tpe.etaCollapse)
24502454
typed(untpd.AppliedTypeTree(spliced(tycon0), tparam :: Nil))
2451-
else if Feature.enabled(modularity) && tycon.tpe.member(tpnme.Self).symbol.isAbstractOrParamType then
2455+
else if Feature.enabled(Feature.modularity) && tycon.tpe.member(tpnme.Self).symbol.isAbstractOrParamType then
24522456
val tparamSplice = untpd.TypedSplice(typedExpr(tparam))
24532457
typed(untpd.RefinedTypeTree(spliced(tycon), List(untpd.TypeDef(tpnme.Self, tparamSplice))))
24542458
else
24552459
def selfNote =
2456-
if Feature.enabled(modularity) then
2460+
if Feature.enabled(Feature.modularity) then
24572461
" and\ndoes not have an abstract type member named `Self` either"
24582462
else ""
24592463
errorTree(tree,
@@ -2472,7 +2476,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
24722476
val TypeDef(_, impl: Template) = typed(refineClsDef): @unchecked
24732477
val refinements1 = impl.body
24742478
val seen = mutable.Set[Symbol]()
2475-
for (refinement <- refinements1) { // TODO: get clarity whether we want to enforce these conditions
2479+
for refinement <- refinements1 do // TODO: get clarity whether we want to enforce these conditions
24762480
typr.println(s"adding refinement $refinement")
24772481
checkRefinementNonCyclic(refinement, refineCls, seen)
24782482
val rsym = refinement.symbol
@@ -2486,7 +2490,6 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
24862490
val member = refineCls.info.member(rsym.name)
24872491
if (member.isOverloaded)
24882492
report.error(OverloadInRefinement(rsym), refinement.srcPos)
2489-
}
24902493
assignType(cpy.RefinedTypeTree(tree)(tpt1, refinements1), tpt1, refinements1, refineCls)
24912494
}
24922495

@@ -4701,7 +4704,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
47014704
cpy.Ident(qual)(qual.symbol.name.sourceModuleName.toTypeName)
47024705
case _ =>
47034706
errorTree(tree, em"cannot convert from $tree to an instance creation expression")
4704-
val tycon = ctorResultType.underlyingClassRef(refinementOK = Feature.enabled(modularity))
4707+
val tycon = ctorResultType.underlyingClassRef(refinementOK = Feature.enabled(Feature.modularity))
47054708
typed(
47064709
untpd.Select(
47074710
untpd.New(untpd.TypedSplice(tpt.withType(tycon))),

tests/pos/contextbounds-for-poly-functions.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@ import scala.language.future
55
trait Ord[X]:
66
def compare(x: X, y: X): Int
77

8-
val less1 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0
8+
// val less1 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0
99

10-
val less2 = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0
10+
// val less2 = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0
1111

12-
// type Comparer = [X: Ord] => (x: X, y: X) => Boolean
13-
// val less3: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0
12+
type ComparerRef = [X] => (x: X, y: X) => Ord[X] ?=> Boolean
13+
type Comparer = [X: Ord] => (x: X, y: X) => Boolean
14+
val less3: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0
1415

1516
// type Cmp[X] = (x: X, y: X) => Boolean
1617
// type Comparer2 = [X: Ord] => Cmp[X]

0 commit comments

Comments
 (0)