@@ -12,6 +12,7 @@ import config.Printers.capt
1212import StdNames .nme
1313import util .{SimpleIdentitySet , EqHashMap , SrcPos }
1414import tpd .*
15+ import reflect .ClassTag
1516
1617object SepChecker :
1718
@@ -39,14 +40,92 @@ object SepChecker:
3940 case _ => NoSymbol
4041 end TypeKind
4142
43+ /** A class for segmented sets of consumed references.
44+ * References are associated with the source positions where they first appeared.
45+ * References are compared with `eq`.
46+ */
47+ abstract class ConsumedSet :
48+ /** The references in the set. The array should be treated as immutable in client code */
49+ def refs : Array [CaptureRef ]
50+
51+ /** The associated source positoons. The array should be treated as immutable in client code */
52+ def locs : Array [SrcPos ]
53+
54+ /** The number of references in the set */
55+ def size : Int
56+
57+ def toMap : Map [CaptureRef , SrcPos ] = refs.take(size).zip(locs).toMap
58+
59+ def show (using Context ) =
60+ s " [ ${toMap.map((ref, loc) => i " $ref -> $loc" ).toList}] "
61+ end ConsumedSet
62+
63+ /** A fixed consumed set consisting of the given references `refs` and
64+ * associated source positions `locs`
65+ */
66+ class ConstConsumedSet (val refs : Array [CaptureRef ], val locs : Array [SrcPos ]) extends ConsumedSet :
67+ def size = refs.size
68+
69+ /** A mutable consumed set, which is initially empty */
70+ class MutConsumedSet extends ConsumedSet :
71+ var refs : Array [CaptureRef ] = new Array (4 )
72+ var locs : Array [SrcPos ] = new Array (4 )
73+ var size = 0
74+
75+ private def double [T <: AnyRef : ClassTag ](xs : Array [T ]): Array [T ] =
76+ val xs1 = new Array [T ](xs.length * 2 )
77+ xs.copyToArray(xs1)
78+ xs1
79+
80+ private def ensureCapacity (added : Int ): Unit =
81+ if size + added > refs.length then
82+ refs = double(refs)
83+ locs = double(locs)
84+
85+ /** If `ref` is in the set, its associated source position, otherwise `null` */
86+ def get (ref : CaptureRef ): SrcPos | Null =
87+ var i = 0
88+ while i < size && (refs(i) ne ref) do i += 1
89+ if i < size then locs(i) else null
90+
91+ /** If `ref` is not yet in the set, add it with given source position */
92+ def put (ref : CaptureRef , loc : SrcPos ): Unit =
93+ if get(ref) == null then
94+ ensureCapacity(1 )
95+ refs(size) = ref
96+ locs(size) = loc
97+ size += 1
98+
99+ /** Add all references with their associated positions from `that` which
100+ * are not yet in the set.
101+ */
102+ def ++= (that : ConsumedSet ): Unit =
103+ for i <- 0 until that.size do put(that.refs(i), that.locs(i))
104+
105+ /** Run `op` and return any new references it created in a separate `ConsumedSet`.
106+ * The current mutable set is reset to its state before `op` was run.
107+ */
108+ def segment (op : => Unit ): ConsumedSet =
109+ val start = size
110+ try
111+ op
112+ if size == start then EmptyConsumedSet
113+ else ConstConsumedSet (refs.slice(start, size), locs.slice(start, size))
114+ finally
115+ size = start
116+
117+ end MutConsumedSet
118+
119+ val EmptyConsumedSet = ConstConsumedSet (Array (), Array ())
120+
42121class SepChecker (checker : CheckCaptures .CheckerAPI ) extends tpd.TreeTraverser :
43122 import checker .*
44123 import SepChecker .*
45124
46125 /** The set of capabilities that are hidden by a polymorphic result type
47126 * of some previous definition.
48127 */
49- private var defsShadow : Refs = SimpleIdentitySet .empty
128+ private var defsShadow : Refs = emptySet
50129
51130 /** A map from definitions to their internal result types.
52131 * Populated during separation checking traversal.
@@ -58,6 +137,16 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
58137 */
59138 private var previousDefs : List [mutable.ListBuffer [ValOrDefDef ]] = Nil
60139
140+ private var consumed : MutConsumedSet = MutConsumedSet ()
141+
142+ private def withFreshConsumed (op : => Unit ): Unit =
143+ val saved = consumed
144+ consumed = MutConsumedSet ()
145+ op
146+ consumed = saved
147+
148+ private var openLabeled : List [(Name , mutable.ListBuffer [ConsumedSet ])] = Nil
149+
61150 extension (refs : Refs )
62151 private def footprint (using Context ): Refs =
63152 def recur (elems : Refs , newElems : List [CaptureRef ]): Refs = newElems match
@@ -198,6 +287,19 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
198287 tree.srcPos)
199288 end sepUseError
200289
290+ def consumeError (ref : CaptureRef , loc : SrcPos , pos : SrcPos )(using Context ): Unit =
291+ report.error(
292+ em """ Separation failure: Illegal access to $ref,
293+ |which was passed to a @consume parameter on line ${loc.line + 1 }
294+ |and therefore is no longer available. """ ,
295+ pos)
296+
297+ def consumeInLoopError (ref : CaptureRef , pos : SrcPos )(using Context ): Unit =
298+ report.error(
299+ em """ Separation failure: $ref appears in a loop,
300+ |therefore it cannot be passed to a @consume parameter. """ ,
301+ pos)
302+
201303 private def checkApply (fn : Tree , args : List [Tree ], deps : collection.Map [Tree , List [Tree ]])(using Context ): Unit =
202304 val fnCaptures = methPart(fn) match
203305 case Select (qual, _) => qual.nuType.captureSet
@@ -240,6 +342,9 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
240342 val overlap = defUseOverlap(defsShadow, usedFootprint, tree.symbol)
241343 if ! overlap.isEmpty then
242344 sepUseError(tree, usedFootprint, overlap)
345+ for ref <- used.elems do
346+ val pos = consumed.get(ref)
347+ if pos != null then consumeError(ref, pos, tree.srcPos)
243348
244349 def checkType (tpt : Tree , sym : Symbol )(using Context ): Unit =
245350 checkType(tpt.nuType, tpt.srcPos,
@@ -383,10 +488,11 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
383488 checkRefs(toCheck, i " $typeDescr type $tpe hides " )
384489 case TypeKind .Argument (arg) =>
385490 if tpe.hasAnnotation(defn.ConsumeAnnot ) then
386- val capts = captures(arg)
387- def descr (verb : String ) = i " argument to @consume parameter with type ${arg.nuType} $verb"
388- checkRefs(capts.footprint, descr(" refers to" ))
389- checkRefs(capts.hidden.footprint, descr(" hides" ))
491+ val capts = captures(arg).footprint
492+ checkRefs(capts, i " argument to @consume parameter with type ${arg.nuType} refers to " )
493+ for ref <- capts do
494+ if ! ref.derivesFrom(defn.Caps_SharedCapability ) then
495+ consumed.put(ref, arg.srcPos)
390496
391497 if ! tpe.hasAnnotation(defn.UntrackedCapturesAnnot ) then
392498 traverse(Captures .None , tpe)
@@ -435,35 +541,72 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
435541 case tree : Apply => tree.symbol == defn.Caps_unsafeAssumeSeparate
436542 case _ => false
437543
544+ def checkValOrDefDef (tree : ValOrDefDef )(using Context ): Unit =
545+ if ! tree.symbol.isOneOf(TermParamOrAccessor ) && ! isUnsafeAssumeSeparate(tree.rhs) then
546+ checkType(tree.tpt, tree.symbol)
547+ if previousDefs.nonEmpty then
548+ capt.println(i " sep check def ${tree.symbol}: ${tree.tpt} with ${captures(tree.tpt).hidden.footprint}" )
549+ defsShadow ++= captures(tree.tpt).hidden.footprint.deductSym(tree.symbol)
550+ resultType(tree.symbol) = tree.tpt.nuType
551+ previousDefs.head += tree
552+
438553 def traverse (tree : Tree )(using Context ): Unit =
439554 if isUnsafeAssumeSeparate(tree) then return
440555 checkUse(tree)
441556 tree match
442557 case tree : GenericApply =>
558+ traverseChildren(tree)
443559 tree.tpe match
444560 case _ : MethodOrPoly =>
445561 case _ => traverseApply(tree, Nil )
446- traverseChildren(tree)
447562 case tree : Block =>
448563 val saved = defsShadow
449564 previousDefs = mutable.ListBuffer () :: previousDefs
450565 try traverseChildren(tree)
451566 finally
452567 previousDefs = previousDefs.tail
453568 defsShadow = saved
454- case tree : ValOrDefDef =>
569+ case tree : ValDef =>
455570 traverseChildren(tree)
456- if ! tree.symbol.isOneOf(TermParamOrAccessor ) && ! isUnsafeAssumeSeparate(tree.rhs) then
457- checkType(tree.tpt, tree.symbol)
458- if previousDefs.nonEmpty then
459- capt.println(i " sep check def ${tree.symbol}: ${tree.tpt} with ${captures(tree.tpt).hidden.footprint}" )
460- defsShadow ++= captures(tree.tpt).hidden.footprint.deductSym(tree.symbol)
461- resultType(tree.symbol) = tree.tpt.nuType
462- previousDefs.head += tree
571+ checkValOrDefDef(tree)
572+ case tree : DefDef =>
573+ withFreshConsumed :
574+ traverseChildren(tree)
575+ checkValOrDefDef(tree)
576+ case If (cond, thenp, elsep) =>
577+ traverse(cond)
578+ val thenConsumed = consumed.segment(traverse(thenp))
579+ val elseConsumed = consumed.segment(traverse(elsep))
580+ consumed ++= thenConsumed
581+ consumed ++= elseConsumed
582+ case tree @ Labeled (bind, expr) =>
583+ val consumedBuf = mutable.ListBuffer [ConsumedSet ]()
584+ openLabeled = (bind.name, consumedBuf) :: openLabeled
585+ traverse(expr)
586+ for cs <- consumedBuf do consumed ++= cs
587+ openLabeled = openLabeled.tail
588+ case Return (expr, from) =>
589+ val retConsumed = consumed.segment(traverse(expr))
590+ from match
591+ case Ident (name) =>
592+ for (lbl, consumedBuf) <- openLabeled do
593+ if lbl == name then
594+ consumedBuf += retConsumed
595+ case _ =>
596+ case Match (sel, cases) =>
597+ // Matches without returns might still be kept after pattern matching to
598+ // encode table switches.
599+ traverse(sel)
600+ val caseConsumed = for cas <- cases yield consumed.segment(traverse(cas))
601+ caseConsumed.foreach(consumed ++= _)
602+ case tree : TypeDef if tree.symbol.isClass =>
603+ withFreshConsumed :
604+ traverseChildren(tree)
605+ case tree : WhileDo =>
606+ val loopConsumed = consumed.segment(traverseChildren(tree))
607+ if loopConsumed.size != 0 then
608+ val (ref, pos) = loopConsumed.toMap.head
609+ consumeInLoopError(ref, pos)
463610 case _ =>
464611 traverseChildren(tree)
465- end SepChecker
466-
467-
468-
469-
612+ end SepChecker
0 commit comments