Skip to content

Commit 6b100a7

Browse files
authored
Changes in CC around classes, constructors, this (#23874)
This PR started with a shocking discovery. We actually did not record this-references in use sets. Once this was fixed, some tests failed and the stdlib build broke with 20 errors. 18 of these were oversights that were not caught before and were easily fixed. 2 of them pointed to further capture checking problems. This caused a sequence of successive fixes until everything compiled again and all the tests made sense. The order in the commits here is a bit different: We first apply the fixes to the stdlib and the compiler and then start recording uses of this. This is so that we don't have a failing build in intermediate steps.
2 parents f24da5c + 4e8f7d9 commit 6b100a7

27 files changed

+254
-82
lines changed

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 86 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -407,25 +407,22 @@ class CheckCaptures extends Recheck, SymTransformer:
407407
else i"references $cs1$cs1description are not all",
408408
cs1, cs2, pos, provenance)
409409

410-
/** If `sym` is a class or method nested inside a term, a capture set variable representing
411-
* the captured variables of the environment associated with `sym`.
410+
/** If `sym` is a method or a non-static inner class, a capture set variable
411+
* representing the captured variables of the environment associated with `sym`.
412412
*/
413413
def capturedVars(sym: Symbol)(using Context): CaptureSet =
414414
myCapturedVars.getOrElseUpdate(sym,
415-
if sym.ownersIterator.exists(_.isTerm)
415+
if sym.isTerm || !sym.owner.isStaticOwner
416416
then CaptureSet.Var(sym.owner, level = ccState.symLevel(sym))
417417
else CaptureSet.empty)
418418

419419
// ---- Record Uses with MarkFree ----------------------------------------------------
420420

421421
/** The next environment enclosing `env` that needs to be charged
422422
* with free references.
423-
* @param included Whether an environment is included in the range of
424-
* environments to charge. Once `included` is false, no
425-
* more environments need to be charged.
426423
*/
427-
def nextEnvToCharge(env: Env, included: Env => Boolean)(using Context): Env =
428-
if env.owner.isConstructor && included(env.outer) then env.outer.outer
424+
def nextEnvToCharge(env: Env)(using Context): Env | Null =
425+
if env.owner.isConstructor then env.outer.outer0
429426
else env.outer
430427

431428
/** A description where this environment comes from */
@@ -458,21 +455,27 @@ class CheckCaptures extends Recheck, SymTransformer:
458455
markFree(sym, sym.termRef, tree)
459456

460457
def markFree(sym: Symbol, ref: Capability, tree: Tree)(using Context): Unit =
461-
if sym.exists && ref.isTracked then markFree(ref.singletonCaptureSet, tree)
458+
if sym.exists then markFree(ref, tree)
459+
460+
def markFree(ref: Capability, tree: Tree)(using Context): Unit =
461+
if ref.isTracked then markFree(ref.singletonCaptureSet, tree)
462462

463463
/** Make sure the (projected) `cs` is a subset of the capture sets of all enclosing
464464
* environments. At each stage, only include references from `cs` that are outside
465465
* the environment's owner
466466
*/
467-
def markFree(cs: CaptureSet, tree: Tree)(using Context): Unit =
467+
def markFree(cs: CaptureSet, tree: Tree, addUseInfo: Boolean = true)(using Context): Unit =
468468
// A captured reference with the symbol `sym` is visible from the environment
469469
// if `sym` is not defined inside the owner of the environment.
470470
inline def isVisibleFromEnv(sym: Symbol, env: Env) =
471471
sym.exists && {
472+
val effectiveOwner =
473+
if env.owner.isConstructor then env.owner.owner
474+
else env.owner
472475
if env.kind == EnvKind.NestedInOwner then
473-
!sym.isProperlyContainedIn(env.owner)
476+
!sym.isProperlyContainedIn(effectiveOwner)
474477
else
475-
!sym.isContainedIn(env.owner)
478+
!sym.isContainedIn(effectiveOwner)
476479
}
477480

478481
/** Avoid locally defined capability by charging the underlying type
@@ -535,13 +538,15 @@ class CheckCaptures extends Recheck, SymTransformer:
535538
checkSubset(included, env.captured, tree.srcPos, provenance(env))
536539
capt.println(i"Include call or box capture $included from $cs in ${env.owner} --> ${env.captured}")
537540
if !isOfNestedMethod(env) then
538-
recur(included, nextEnvToCharge(env, !_.owner.isStaticOwner), env)
541+
val nextEnv = nextEnvToCharge(env)
542+
if nextEnv != null && !nextEnv.owner.isStaticOwner then
543+
recur(included, nextEnv, env)
539544
// Under deferredReaches, don't propagate out of methods inside terms.
540545
// The use set of these methods will be charged when that method is called.
541546

542547
if !cs.isAlwaysEmpty then
543548
recur(cs, curEnv, null)
544-
useInfos += ((tree, cs, curEnv))
549+
if addUseInfo then useInfos += ((tree, cs, curEnv))
545550
end markFree
546551

547552
/** If capability `c` refers to a parameter that is not implicitly or explicitly
@@ -626,25 +631,33 @@ class CheckCaptures extends Recheck, SymTransformer:
626631
// If ident refers to a parameterless method, charge its cv to the environment
627632
includeCallCaptures(sym, sym.info, tree)
628633
else if !sym.isStatic then
629-
// Otherwise charge its symbol, but add all selections and also any `.rd`
630-
// modifier implied by the expected type `pt`.
631-
// Example: If we have `x` and the expected type says we select that with `.a.b`
632-
// where `b` is a read-only method, we charge `x.a.b.rd` instead of `x`.
633-
def addSelects(ref: TermRef, pt: Type): Capability = pt match
634-
case pt: PathSelectionProto if ref.isTracked =>
635-
if pt.sym.isReadOnlyMethod then
636-
ref.readOnly
637-
else
638-
// if `ref` is not tracked then the selection could not give anything new
639-
// class SerializationProxy in stdlib-cc/../LazyListIterable.scala has an example where this matters.
640-
addSelects(ref.select(pt.sym).asInstanceOf[TermRef], pt.pt)
641-
case _ => ref
642-
var pathRef: Capability = addSelects(sym.termRef, pt)
643-
if pathRef.derivesFromMutable && pt.isValueType && !pt.isMutableType then
644-
pathRef = pathRef.readOnly
645-
markFree(sym, pathRef, tree)
634+
markFree(sym, pathRef(sym.termRef, pt), tree)
646635
mapResultRoots(super.recheckIdent(tree, pt), tree.symbol)
647636

637+
override def recheckThis(tree: This, pt: Type)(using Context): Type =
638+
markFree(pathRef(tree.tpe.asInstanceOf[ThisType], pt), tree)
639+
super.recheckThis(tree, pt)
640+
641+
/** Add all selections and also any `.rd modifier implied by the expected
642+
* type `pt` to `base`. Example:
643+
* If we have `x` and the expected type says we select that with `.a.b`
644+
* where `b` is a read-only method, we charge `x.a.b.rd` instead of `x`.
645+
*/
646+
private def pathRef(base: TermRef | ThisType, pt: Type)(using Context): Capability =
647+
def addSelects(ref: TermRef | ThisType, pt: Type): Capability = pt match
648+
case pt: PathSelectionProto if ref.isTracked =>
649+
if pt.sym.isReadOnlyMethod then
650+
ref.readOnly
651+
else
652+
// if `ref` is not tracked then the selection could not give anything new
653+
// class SerializationProxy in stdlib-cc/../LazyListIterable.scala has an example where this matters.
654+
addSelects(ref.select(pt.sym).asInstanceOf[TermRef], pt.pt)
655+
case _ => ref
656+
val ref: Capability = addSelects(base, pt)
657+
if ref.derivesFromMutable && pt.isValueType && !pt.isMutableType
658+
then ref.readOnly
659+
else ref
660+
648661
/** The expected type for the qualifier of a selection. If the selection
649662
* could be part of a capability path or is a a read-only method, we return
650663
* a PathSelectionProto.
@@ -866,7 +879,7 @@ class CheckCaptures extends Recheck, SymTransformer:
866879
val (refined, cs) = addParamArgRefinements(core, initCs)
867880
refined.capturing(cs)
868881

869-
augmentConstructorType(resType, capturedVars(cls) ++ capturedVars(constr))
882+
augmentConstructorType(resType, capturedVars(cls))
870883
.showing(i"constr type $mt with $argTypes%, % in $constr = $result", capt)
871884
end refineConstructorInstance
872885

@@ -975,6 +988,8 @@ class CheckCaptures extends Recheck, SymTransformer:
975988
* - Interpolate contravariant capture set variables in result type.
976989
*/
977990
override def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Type =
991+
val savedEnv = curEnv
992+
val runInConstructor = !sym.isOneOf(Param | ParamAccessor | Lazy | NonMember)
978993
try
979994
if sym.is(Module) then sym.info // Modules are checked by checking the module class
980995
else
@@ -993,6 +1008,8 @@ class CheckCaptures extends Recheck, SymTransformer:
9931008
""
9941009
disallowBadRootsIn(
9951010
tree.tpt.nuType, NoSymbol, i"Mutable $sym", "have type", addendum, sym.srcPos)
1011+
if runInConstructor then
1012+
pushConstructorEnv()
9961013
checkInferredResult(super.recheckValDef(tree, sym), tree)
9971014
finally
9981015
if !sym.is(Param) then
@@ -1002,6 +1019,22 @@ class CheckCaptures extends Recheck, SymTransformer:
10021019
// function is compiled since we do not propagate expected types into blocks.
10031020
interpolateIfInferred(tree.tpt, sym)
10041021

1022+
def declaredCaptures = tree.tpt.nuType.captureSet
1023+
if runInConstructor && savedEnv.owner.isClass then
1024+
curEnv = savedEnv
1025+
markFree(declaredCaptures, tree, addUseInfo = false)
1026+
1027+
if sym.owner.isStaticOwner && !declaredCaptures.elems.isEmpty && sym != defn.captureRoot then
1028+
def where =
1029+
if sym.effectiveOwner.is(Package) then "top-level definition"
1030+
else i"member of static ${sym.owner}"
1031+
report.warning(
1032+
em"""$sym has a non-empty capture set but will not be added as
1033+
|a capability to computed capture sets since it is globally accessible
1034+
|as a $where. Global values cannot be capabilities.""",
1035+
tree.namePos)
1036+
end recheckValDef
1037+
10051038
/** Recheck method definitions:
10061039
* - check body in a nested environment that tracks uses, in a nested level,
10071040
* and in a nested context that knows abaout Contains parameters so that we
@@ -1228,6 +1261,24 @@ class CheckCaptures extends Recheck, SymTransformer:
12281261
recheckFinish(result, arg, pt)
12291262
*/
12301263

1264+
/** If environment is owned by a class, run in a new environment owned by
1265+
* its primary constructor instead.
1266+
*/
1267+
def pushConstructorEnv()(using Context): Unit =
1268+
if curEnv.owner.isClass then
1269+
val constr = curEnv.owner.primaryConstructor
1270+
if constr.exists then
1271+
val constrSet = capturedVars(constr)
1272+
if capturedVars(constr) ne CaptureSet.empty then
1273+
curEnv = Env(constr, EnvKind.Regular, constrSet, curEnv)
1274+
1275+
override def recheckStat(stat: Tree)(using Context): Unit =
1276+
val saved = curEnv
1277+
if !stat.isInstanceOf[MemberDef] then
1278+
pushConstructorEnv()
1279+
try recheck(stat)
1280+
finally curEnv = saved
1281+
12311282
/** The main recheck method does some box adapation for all nodes:
12321283
* - If expected type `pt` is boxed and the tree is a lambda or a reference,
12331284
* don't propagate free variables.
@@ -2021,7 +2072,9 @@ class CheckCaptures extends Recheck, SymTransformer:
20212072
if env.kind == EnvKind.Boxed then env.owner
20222073
else if isOfNestedMethod(env) then env.owner.owner
20232074
else if env.owner.isStaticOwner then NoSymbol
2024-
else boxedOwner(nextEnvToCharge(env, alwaysTrue))
2075+
else
2076+
val nextEnv = nextEnvToCharge(env)
2077+
if nextEnv == null then NoSymbol else boxedOwner(nextEnv)
20252078

20262079
def checkUseUnlessBoxed(c: Capability, croot: NamedType) =
20272080
if !boxedOwner(env).isContainedIn(croot.symbol.owner) then

compiler/src/dotty/tools/dotc/transform/Recheck.scala

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,12 @@ abstract class Recheck extends Phase, SymTransformer:
248248
def recheckSelection(tree: Select, qualType: Type, name: Name, pt: Type)(using Context): Type =
249249
recheckSelection(tree, qualType, name, sharpen = identity[Denotation])
250250

251+
def recheckThis(tree: This, pt: Type)(using Context): Type =
252+
tree.tpe
253+
254+
def recheckSuper(tree: Super, pt: Type)(using Context): Type =
255+
tree.tpe
256+
251257
def recheckBind(tree: Bind, pt: Type)(using Context): Type = tree match
252258
case Bind(name, body) =>
253259
recheck(body, pt)
@@ -487,12 +493,15 @@ abstract class Recheck extends Phase, SymTransformer:
487493
recheckStats(tree.stats)
488494
NoType
489495

496+
def recheckStat(stat: Tree)(using Context): Unit =
497+
recheck(stat)
498+
490499
def recheckStats(stats: List[Tree])(using Context): Unit =
491500
@tailrec def traverse(stats: List[Tree])(using Context): Unit = stats match
492501
case (imp: Import) :: rest =>
493502
traverse(rest)(using ctx.importContext(imp, imp.symbol))
494503
case stat :: rest =>
495-
recheck(stat)
504+
recheckStat(stat)
496505
traverse(rest)
497506
case _ =>
498507
traverse(stats)
@@ -540,7 +549,9 @@ abstract class Recheck extends Phase, SymTransformer:
540549
def recheckUnnamed(tree: Tree, pt: Type): Type = tree match
541550
case tree: Apply => recheckApply(tree, pt)
542551
case tree: TypeApply => recheckTypeApply(tree, pt)
543-
case _: New | _: This | _: Super | _: Literal => tree.tpe
552+
case tree: This => recheckThis(tree, pt)
553+
case tree: Super => recheckSuper(tree, pt)
554+
case _: New | _: Literal => tree.tpe
544555
case tree: Typed => recheckTyped(tree)
545556
case tree: Assign => recheckAssign(tree)
546557
case tree: Block => recheckBlock(tree, pt)

library/src/scala/collection/Iterator.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ trait Iterator[+A] extends IterableOnce[A] with IterableOnceOps[A, Iterator, Ite
418418

419419
@deprecated("Call scanRight on an Iterable instead.", "2.13.0")
420420
def scanRight[B](z: B)(op: (A, B) => B): Iterator[B]^{this, op} = ArrayBuffer.from(this).scanRight(z)(op).iterator
421-
421+
422422
/** Finds index of the first element satisfying some predicate after or at some start index.
423423
*
424424
* $mayNotTerminateInf
@@ -494,9 +494,9 @@ trait Iterator[+A] extends IterableOnce[A] with IterableOnceOps[A, Iterator, Ite
494494
while (p(hd) == isFlipped) {
495495
if (!self.hasNext) return false
496496
hd = self.next()
497-
}
497+
}
498498
hdDefined = true
499-
true
499+
true
500500
}
501501

502502
def next() =
@@ -874,7 +874,7 @@ trait Iterator[+A] extends IterableOnce[A] with IterableOnceOps[A, Iterator, Ite
874874
*/
875875
def duplicate: (Iterator[A]^{this}, Iterator[A]^{this}) = {
876876
val gap = new scala.collection.mutable.Queue[A]
877-
var ahead: Iterator[A] = null
877+
var ahead: Iterator[A]^ = null
878878
class Partner extends AbstractIterator[A] {
879879
override def knownSize: Int = self.synchronized {
880880
val thisSize = self.knownSize
@@ -890,7 +890,7 @@ trait Iterator[+A] extends IterableOnce[A] with IterableOnceOps[A, Iterator, Ite
890890
if (gap.isEmpty) ahead = this
891891
if (this eq ahead) {
892892
val e = self.next()
893-
gap enqueue e
893+
gap.enqueue(e)
894894
e
895895
} else gap.dequeue()
896896
}
@@ -918,7 +918,7 @@ trait Iterator[+A] extends IterableOnce[A] with IterableOnceOps[A, Iterator, Ite
918918
*/
919919
def patch[B >: A](from: Int, patchElems: Iterator[B]^, replaced: Int): Iterator[B]^{this, patchElems} =
920920
new AbstractIterator[B] {
921-
private[this] var origElems = self
921+
private[this] var origElems: Iterator[B]^ = self
922922
// > 0 => that many more elems from `origElems` before switching to `patchElems`
923923
// 0 => need to drop elems from `origElems` and start using `patchElems`
924924
// -1 => have dropped elems from `origElems`, will be using `patchElems` until it's empty

library/src/scala/collection/LazyZipOps.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ final class LazyZip4[+El1, +El2, +El3, +El4, C1] private[collection](src: C1,
389389
}
390390

391391
private def toIterable: View[(El1, El2, El3, El4)]^{this} = new AbstractView[(El1, El2, El3, El4)] {
392-
def iterator: AbstractIterator[(El1, El2, El3, El4)] = new AbstractIterator[(El1, El2, El3, El4)] {
392+
def iterator: AbstractIterator[(El1, El2, El3, El4)]^{this} = new AbstractIterator[(El1, El2, El3, El4)] {
393393
private[this] val elems1 = coll1.iterator
394394
private[this] val elems2 = coll2.iterator
395395
private[this] val elems3 = coll3.iterator

library/src/scala/collection/SeqView.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,10 +202,10 @@ object SeqView {
202202
override def knownSize: Int = len
203203
override def isEmpty: Boolean = len == 0
204204
override def to[C1](factory: Factory[A, C1]): C1 = _sorted.to(factory)
205-
override def reverse: SeqView[A] = new ReverseSorted
205+
override def reverse: SeqView[A]^{this} = new ReverseSorted
206206
// we know `_sorted` is either tiny or has efficient random access,
207207
// so this is acceptable for `reversed`
208-
override protected def reversed: Iterable[A] = new ReverseSorted
208+
override protected def reversed: Iterable[A]^{this} = new ReverseSorted
209209

210210
override def sorted[B1 >: A](implicit ord1: Ordering[B1]): SeqView[A]^{this} =
211211
if (ord1 == this.ord) this

library/src/scala/collection/Stepper.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ trait IntStepper extends Stepper[Int] {
260260

261261
def spliterator[B >: Int]: Spliterator.OfInt^{this} = new IntStepper.IntStepperSpliterator(this)
262262

263-
def javaIterator[B >: Int]: PrimitiveIterator.OfInt = new PrimitiveIterator.OfInt {
263+
def javaIterator[B >: Int]: PrimitiveIterator.OfInt^{this} = new PrimitiveIterator.OfInt {
264264
def hasNext: Boolean = hasStep
265265
def nextInt(): Int = nextStep()
266266
}
@@ -298,7 +298,7 @@ trait DoubleStepper extends Stepper[Double] {
298298

299299
def spliterator[B >: Double]: Spliterator.OfDouble^{this} = new DoubleStepper.DoubleStepperSpliterator(this)
300300

301-
def javaIterator[B >: Double]: PrimitiveIterator.OfDouble = new PrimitiveIterator.OfDouble {
301+
def javaIterator[B >: Double]: PrimitiveIterator.OfDouble^{this} = new PrimitiveIterator.OfDouble {
302302
def hasNext: Boolean = hasStep
303303
def nextDouble(): Double = nextStep()
304304
}
@@ -337,7 +337,7 @@ trait LongStepper extends Stepper[Long] {
337337

338338
def spliterator[B >: Long]: Spliterator.OfLong^{this} = new LongStepper.LongStepperSpliterator(this)
339339

340-
def javaIterator[B >: Long]: PrimitiveIterator.OfLong = new PrimitiveIterator.OfLong {
340+
def javaIterator[B >: Long]: PrimitiveIterator.OfLong^{this} = new PrimitiveIterator.OfLong {
341341
def hasNext: Boolean = hasStep
342342
def nextLong(): Long = nextStep()
343343
}

library/src/scala/collection/View.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ object View extends IterableFactory[View] {
172172

173173
@SerialVersionUID(3L)
174174
class LeftPartitionMapped[A, A1, A2](underlying: SomeIterableOps[A]^, f: A => Either[A1, A2]) extends AbstractView[A1] {
175-
def iterator: AbstractIterator[A1] = new AbstractIterator[A1] {
175+
def iterator: AbstractIterator[A1]^{this} = new AbstractIterator[A1] {
176176
private[this] val self = underlying.iterator
177177
private[this] var hd: A1 = _
178178
private[this] var hdDefined: Boolean = false
@@ -197,7 +197,7 @@ object View extends IterableFactory[View] {
197197

198198
@SerialVersionUID(3L)
199199
class RightPartitionMapped[A, A1, A2](underlying: SomeIterableOps[A]^, f: A => Either[A1, A2]) extends AbstractView[A2] {
200-
def iterator: AbstractIterator[A2] = new AbstractIterator[A2] {
200+
def iterator: AbstractIterator[A2]^{this} = new AbstractIterator[A2] {
201201
private[this] val self = underlying.iterator
202202
private[this] var hd: A2 = _
203203
private[this] var hdDefined: Boolean = false

library/src/scala/collection/convert/StreamExtensions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import scala.jdk._
3030
* [[scala.jdk.javaapi.StreamConverters]].
3131
*/
3232
trait StreamExtensions {
33+
this: StreamExtensions =>
3334
// collections
3435

3536
implicit class IterableHasSeqStream[A](cc: IterableOnce[A]) {

library/src/scala/collection/mutable/HashTable.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ private[collection] trait HashTable[A, B, Entry >: Null <: HashEntry[A, Entry]]
211211

212212
/** An iterator returning all entries.
213213
*/
214-
def entriesIterator: Iterator[Entry] = new AbstractIterator[Entry] {
214+
def entriesIterator: Iterator[Entry]^{this} = new AbstractIterator[Entry] {
215215
val iterTable = table
216216
var idx = lastPopulatedIndex
217217
var es = iterTable(idx)

0 commit comments

Comments
 (0)