@@ -10,8 +10,10 @@ import Flags._
1010import config .Config
1111import config .Printers .typr
1212import reporting .trace
13- import typer .ProtoTypes .newTypeVar
13+ import typer .ProtoTypes .{ newTypeVar , representedParamRef }
1414import StdNames .tpnme
15+ import UnificationDirection .*
16+ import NameKinds .AvoidNameKind
1517
1618/** Methods for adding constraints and solving them.
1719 *
@@ -56,20 +58,68 @@ trait ConstraintHandling {
5658 */
5759 protected var comparedTypeLambdas : Set [TypeLambda ] = Set .empty
5860
61+ protected var myNecessaryConstraintsOnly = false
62+ /** When collecting the constraints needed for a particular subtyping
63+ * judgment to be true, we sometimes need to approximate the constraint
64+ * set (see `TypeComparer#either` for example).
65+ *
66+ * Normally, this means adding extra constraints which may not be necessary
67+ * for the subtyping judgment to be true, but if this variable is set to true
68+ * we will instead under-approximate and keep only the constraints that must
69+ * always be present for the subtyping judgment to hold.
70+ *
71+ * This is needed for GADT bounds inference to be sound, but it is also used
72+ * when constraining a method call based on its expected type to avoid adding
73+ * constraints that would later prevent us from typechecking method
74+ * arguments, see or-inf.scala and and-inf.scala for examples.
75+ */
76+ protected def necessaryConstraintsOnly (using Context ): Boolean =
77+ ctx.mode.is(Mode .GadtConstraintInference ) || myNecessaryConstraintsOnly
78+
5979 def checkReset () =
6080 assert(addConstraintInvocations == 0 )
6181 assert(frozenConstraint == false )
6282 assert(caseLambda == NoType )
6383 assert(homogenizeArgs == false )
6484 assert(comparedTypeLambdas == Set .empty)
6585
86+ def nestingLevel (param : TypeParamRef ) = constraint.typeVarOfParam(param) match
87+ case tv : TypeVar => tv.nestingLevel
88+ case _ => Int .MaxValue
89+
90+ /** If `param` is nested deeper than `maxLevel`, try to instantiate it to a
91+ * fresh type variable of level `maxLevel` and return the new variable.
92+ * If this isn't possible, throw a TypeError.
93+ */
94+ def atLevel (maxLevel : Int , param : TypeParamRef )(using Context ): TypeParamRef =
95+ if nestingLevel(param) <= maxLevel then return param
96+ LevelAvoidMap (0 , maxLevel)(param) match
97+ case freshVar : TypeVar => freshVar.origin
98+ case _ => throw new TypeError (
99+ i " Could not decrease the nesting level of ${param} from ${nestingLevel(param)} to $maxLevel in $constraint" )
100+
66101 def nonParamBounds (param : TypeParamRef )(using Context ): TypeBounds = constraint.nonParamBounds(param)
67102
103+ /** The full lower bound of `param` includes both the `nonParamBounds` and the
104+ * params in the constraint known to be `<: param`, except that
105+ * params with a `nestingLevel` higher than `param` will be instantiated
106+ * to a fresh param at a legal level. See the documentation of `TypeVar`
107+ * for details.
108+ */
68109 def fullLowerBound (param : TypeParamRef )(using Context ): Type =
69- constraint.minLower(param).foldLeft(nonParamBounds(param).lo)(_ | _)
110+ val maxLevel = nestingLevel(param)
111+ var loParams = constraint.minLower(param)
112+ if maxLevel != Int .MaxValue then
113+ loParams = loParams.mapConserve(atLevel(maxLevel, _))
114+ loParams.foldLeft(nonParamBounds(param).lo)(_ | _)
70115
116+ /** The full upper bound of `param`, see the documentation of `fullLowerBounds` above. */
71117 def fullUpperBound (param : TypeParamRef )(using Context ): Type =
72- constraint.minUpper(param).foldLeft(nonParamBounds(param).hi)(_ & _)
118+ val maxLevel = nestingLevel(param)
119+ var hiParams = constraint.minUpper(param)
120+ if maxLevel != Int .MaxValue then
121+ hiParams = hiParams.mapConserve(atLevel(maxLevel, _))
122+ hiParams.foldLeft(nonParamBounds(param).hi)(_ & _)
73123
74124 /** Full bounds of `param`, including other lower/upper params.
75125 *
@@ -79,10 +129,111 @@ trait ConstraintHandling {
79129 def fullBounds (param : TypeParamRef )(using Context ): TypeBounds =
80130 nonParamBounds(param).derivedTypeBounds(fullLowerBound(param), fullUpperBound(param))
81131
82- /** If true, eliminate wildcards in bounds by avoidance, otherwise replace
83- * them by fresh variables.
132+ /** An approximating map that prevents types nested deeper than maxLevel as
133+ * well as WildcardTypes from leaking into the constraint.
134+ * Note that level-checking is turned off after typer and in uncommitable
135+ * TyperState since these leaks should be safe.
84136 */
85- protected def approximateWildcards : Boolean = true
137+ class LevelAvoidMap (topLevelVariance : Int , maxLevel : Int )(using Context ) extends TypeOps .AvoidMap :
138+ variance = topLevelVariance
139+
140+ /** Are we allowed to refer to types of the given `level`? */
141+ private def levelOK (level : Int ): Boolean =
142+ level <= maxLevel || ctx.isAfterTyper || ! ctx.typerState.isCommittable
143+
144+ def toAvoid (tp : NamedType ): Boolean =
145+ tp.prefix == NoPrefix && ! tp.symbol.isStatic && ! levelOK(tp.symbol.nestingLevel)
146+
147+ /** Return a (possibly fresh) type variable of a level no greater than `maxLevel` which is:
148+ * - lower-bounded by `tp` if variance >= 0
149+ * - upper-bounded by `tp` if variance <= 0
150+ * If this isn't possible, return the empty range.
151+ */
152+ def legalVar (tp : TypeVar ): Type =
153+ val oldParam = tp.origin
154+ val nameKind =
155+ if variance > 0 then AvoidNameKind .UpperBound
156+ else if variance < 0 then AvoidNameKind .LowerBound
157+ else AvoidNameKind .BothBounds
158+
159+ /** If it exists, return the first param in the list created in a previous call to `legalVar(tp)`
160+ * with the appropriate level and variance.
161+ */
162+ def findParam (params : List [TypeParamRef ]): Option [TypeParamRef ] =
163+ params.find(p =>
164+ nestingLevel(p) <= maxLevel && representedParamRef(p) == oldParam &&
165+ (p.paramName.is(AvoidNameKind .BothBounds ) ||
166+ variance != 0 && p.paramName.is(nameKind)))
167+
168+ // First, check if we can reuse an existing parameter, this is more than an optimization
169+ // since it avoids an infinite loop in tests/pos/i8900-cycle.scala
170+ findParam(constraint.lower(oldParam)).orElse(findParam(constraint.upper(oldParam))) match
171+ case Some (param) =>
172+ constraint.typeVarOfParam(param)
173+ case _ =>
174+ // Otherwise, try to return a fresh type variable at `maxLevel` with
175+ // the appropriate constraints.
176+ val name = nameKind(oldParam.paramName.toTermName).toTypeName
177+ val freshVar = newTypeVar(TypeBounds .upper(tp.topType), name,
178+ nestingLevel = maxLevel, represents = oldParam)
179+ val ok =
180+ if variance < 0 then
181+ addLess(freshVar.origin, oldParam)
182+ else if variance > 0 then
183+ addLess(oldParam, freshVar.origin)
184+ else
185+ unify(freshVar.origin, oldParam)
186+ if ok then freshVar else emptyRange
187+ end legalVar
188+
189+ override def apply (tp : Type ): Type = tp match
190+ case tp : TypeVar if ! tp.isInstantiated && ! levelOK(tp.nestingLevel) =>
191+ legalVar(tp)
192+ // TypeParamRef can occur in tl bounds
193+ case tp : TypeParamRef =>
194+ constraint.typeVarOfParam(tp) match
195+ case tvar : TypeVar =>
196+ apply(tvar)
197+ case _ => super .apply(tp)
198+ case _ =>
199+ super .apply(tp)
200+
201+ override def mapWild (t : WildcardType ) =
202+ if ctx.mode.is(Mode .TypevarsMissContext ) then super .mapWild(t)
203+ else
204+ val tvar = newTypeVar(apply(t.effectiveBounds).toBounds, nestingLevel = maxLevel)
205+ tvar
206+ end LevelAvoidMap
207+
208+ /** Approximate `rawBound` if needed to make it a legal bound of `param` by
209+ * avoiding wildcards and types with a level strictly greater than its
210+ * `nestingLevel`.
211+ *
212+ * Note that level-checking must be performed here and cannot be delayed
213+ * until instantiation because if we allow level-incorrect bounds, then we
214+ * might end up reasoning with bad bounds outside of the scope where they are
215+ * defined. This can lead to level-correct but unsound instantiations as
216+ * demonstrated by tests/neg/i8900.scala.
217+ */
218+ protected def legalBound (param : TypeParamRef , rawBound : Type , isUpper : Boolean )(using Context ): Type =
219+ // Over-approximate for soundness.
220+ var variance = if isUpper then - 1 else 1
221+ // ...unless we can only infer necessary constraints, in which case we
222+ // flip the variance to under-approximate.
223+ if necessaryConstraintsOnly then variance = - variance
224+
225+ val approx = new LevelAvoidMap (variance, nestingLevel(param)):
226+ override def legalVar (tp : TypeVar ): Type =
227+ // `legalVar` will create a type variable whose bounds depend on
228+ // `variance`, but whether the variance is positive or negative,
229+ // we can still infer necessary constraints since just creating a
230+ // type variable doesn't reduce the set of possible solutions.
231+ // Therefore, we can safely "unflip" the variance flipped above.
232+ // This is necessary for i8900-unflip.scala to typecheck.
233+ val v = if necessaryConstraintsOnly then - this .variance else this .variance
234+ atVariance(v)(super .legalVar(tp))
235+ approx(rawBound)
236+ end legalBound
86237
87238 protected def addOneBound (param : TypeParamRef , rawBound : Type , isUpper : Boolean )(using Context ): Boolean =
88239 if ! constraint.contains(param) then true
@@ -91,12 +242,7 @@ trait ConstraintHandling {
91242 // so we shouldn't allow them as constraints either.
92243 false
93244 else
94- val dropWildcards = new AvoidWildcardsMap :
95- if ! isUpper then variance = - 1
96- override def mapWild (t : WildcardType ) =
97- if approximateWildcards then super .mapWild(t)
98- else newTypeVar(apply(t.effectiveBounds).toBounds)
99- val bound = dropWildcards(rawBound)
245+ val bound = legalBound(param, rawBound, isUpper)
100246 val oldBounds @ TypeBounds (lo, hi) = constraint.nonParamBounds(param)
101247 val equalBounds = (if isUpper then lo else hi) eq bound
102248 if equalBounds && ! bound.existsPart(_ eq param, StopAt .Static ) then
@@ -191,19 +337,50 @@ trait ConstraintHandling {
191337
192338 def location (using Context ) = " " // i"in ${ctx.typerState.stateChainStr}" // use for debugging
193339
194- /** Make p2 = p1, transfer all bounds of p2 to p1
195- * @pre less(p1)(p2)
340+ /** Unify p1 with p2: one parameter will be kept in the constraint, the
341+ * other will be removed and its bounds transferred to the remaining one.
342+ *
343+ * If p1 and p2 have different `nestingLevel`, the parameter with the lowest
344+ * level will be kept and the transferred bounds from the other parameter
345+ * will be adjusted for level-correctness.
196346 */
197347 private def unify (p1 : TypeParamRef , p2 : TypeParamRef )(using Context ): Boolean = {
198348 constr.println(s " unifying $p1 $p2" )
199- assert(constraint.isLess(p1, p2))
200- constraint = constraint.addLess(p2, p1)
349+ if ! constraint.isLess(p1, p2) then
350+ constraint = constraint.addLess(p1, p2)
351+
352+ val level1 = nestingLevel(p1)
353+ val level2 = nestingLevel(p2)
354+ val pKept = if level1 <= level2 then p1 else p2
355+ val pRemoved = if level1 <= level2 then p2 else p1
356+
357+ constraint = constraint.addLess(p2, p1, direction = if pKept eq p1 then KeepParam2 else KeepParam1 )
358+
359+ val boundKept = constraint.nonParamBounds(pKept).substParam(pRemoved, pKept)
360+ var boundRemoved = constraint.nonParamBounds(pRemoved).substParam(pRemoved, pKept)
361+
362+ if level1 != level2 then
363+ boundRemoved = LevelAvoidMap (- 1 , math.min(level1, level2))(boundRemoved)
364+ val TypeBounds (lo, hi) = boundRemoved
365+ // After avoidance, the interval might be empty, e.g. in
366+ // tests/pos/i8900-promote.scala:
367+ // >: x.type <: Singleton
368+ // becomes:
369+ // >: Int <: Singleton
370+ // In that case, we can still get a legal constraint
371+ // by replacing the lower-bound to get:
372+ // >: Int & Singleton <: Singleton
373+ if ! isSub(lo, hi) then
374+ boundRemoved = TypeBounds (lo & hi, hi)
375+
201376 val down = constraint.exclusiveLower(p2, p1)
202377 val up = constraint.exclusiveUpper(p1, p2)
203- constraint = constraint.unify(p1, p2)
204- val bounds = constraint.nonParamBounds(p1)
205- val lo = bounds.lo
206- val hi = bounds.hi
378+
379+ val newBounds = (boundKept & boundRemoved).bounds
380+ constraint = constraint.updateEntry(pKept, newBounds).replace(pRemoved, pKept)
381+
382+ val lo = newBounds.lo
383+ val hi = newBounds.hi
207384 isSub(lo, hi) &&
208385 down.forall(addOneBound(_, hi, isUpper = true )) &&
209386 up.forall(addOneBound(_, lo, isUpper = false ))
@@ -256,6 +433,7 @@ trait ConstraintHandling {
256433 final def approximation (param : TypeParamRef , fromBelow : Boolean )(using Context ): Type =
257434 constraint.entry(param) match
258435 case entry : TypeBounds =>
436+ val maxLevel = nestingLevel(param)
259437 val useLowerBound = fromBelow || param.occursIn(entry.hi)
260438 val inst = if useLowerBound then fullLowerBound(param) else fullUpperBound(param)
261439 typr.println(s " approx ${param.show}, from below = $fromBelow, inst = ${inst.show}" )
0 commit comments