@@ -26,15 +26,13 @@ internal const val NOTHING_TO_STEAL = -2L
  * that these two (current one and submitted) are communicating and sharing state thus making such communication extremely fast.
  * E.g. submitted jobs [1, 2, 3, 4] will be executed in [4, 1, 2, 3] order.
  *
- * ### Work offloading
- *
- * When the queue is full, half of existing tasks are offloaded to global queue which is regularly polled by other pool workers.
- * Offloading occurs in LIFO order for the sake of implementation simplicity: offloads should be extremely rare and occurs only in specific use-cases
- * (e.g. when coroutine starts heavy fork-join-like computation), so fairness is not important.
- * As an alternative, offloading directly to some [CoroutineScheduler.Worker] may be used, but then the strategy of selecting any idle worker
- * should be implemented and implementation should be aware multiple producers.
- *
- * @suppress **This is unstable API and it is subject to change.**
+ * ### Algorithm and implementation details
+ * This is a regular SPMC bounded queue with the additional property that tasks can be removed from the middle of the queue
+ * (scheduler workers without a CPU permit steal blocking tasks via this mechanism). This property forces us to use CAS in
+ * order to properly claim a value from the buffer.
+ * Moreover, [Task] objects are reusable, so it may seem that this queue is prone to the ABA problem.
+ * Indeed, it formally has the ABA problem, but the whole processing logic is written in such a way that this ABA is harmless.
+ * "I have discovered a truly marvelous proof of this, which this margin is too narrow to contain"
  */
 internal class WorkQueue {
 
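
The new class comment above describes the core claiming rule: any consumer, whether it is the owner polling from the head or a thief scanning the middle for blocking tasks, takes a slot by CAS-ing it from its current value to `null`, so each buffered task is handed out exactly once. Below is a minimal standalone sketch of that rule; `SlotRing` and its members are illustrative names only and are not part of the scheduler.

```kotlin
import java.util.concurrent.atomic.AtomicReferenceArray

// Illustrative sketch only: a tiny ring of slots where any consumer claims a slot
// by CAS-ing it to null, mirroring how each buffered task can be handed out once
// even when thieves remove entries from the middle.
class SlotRing(capacity: Int) {
    private val buffer = AtomicReferenceArray<String?>(capacity)

    // Single producer: a plain lazy write into a previously nulled slot is enough.
    fun publish(index: Int, value: String) = buffer.lazySet(index, value)

    // Owner or thief: CAS guarantees a single winner per slot.
    fun claim(index: Int): String? {
        val value = buffer.get(index) ?: return null
        return if (buffer.compareAndSet(index, value, null)) value else null
    }
}

fun main() {
    val ring = SlotRing(4)
    ring.publish(0, "task-A")
    println(ring.claim(0)) // task-A
    println(ring.claim(0)) // null: the slot was already claimed by the first call
}
```
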
@@ -58,18 +56,21 @@ internal class WorkQueue {
 
     private val producerIndex = atomic(0)
     private val consumerIndex = atomic(0)
+    // Shortcut to avoid scanning queue without blocking tasks
+    private val blockingTasksInBuffer = atomic(0)
 
     /**
      * Retrieves and removes task from the head of the queue
-     * Invariant: this method is called only by the owner of the queue ([stealBatch] is not)
+     * Invariant: this method is called only by the owner of the queue.
      */
     fun poll(): Task? = lastScheduledTask.getAndSet(null) ?: pollBuffer()
 
     /**
      * Invariant: Called only by the owner of the queue, returns
      * `null` if task was added, task that wasn't added otherwise.
      */
-    fun add(task: Task): Task? {
+    fun add(task: Task, fair: Boolean = false): Task? {
+        if (fair) return addLast(task)
         val previous = lastScheduledTask.getAndSet(task) ?: return null
         return addLast(previous)
     }
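
To make the LIFO fast path in `add` concrete: each submission swaps the new task into the `lastScheduledTask` slot and pushes the displaced task into the FIFO buffer, while `fair = true` bypasses the slot entirely. Here is a toy, single-threaded model of that behaviour (not the scheduler code; `ToyWorkQueue` is a made-up name) reproducing the [1, 2, 3, 4] -> [4, 1, 2, 3] ordering mentioned in the class documentation.

```kotlin
import java.util.ArrayDeque
import java.util.concurrent.atomic.AtomicReference

// Toy model of the "last scheduled" LIFO slot sitting in front of a FIFO buffer.
class ToyWorkQueue {
    private val lastScheduled = AtomicReference<Int?>(null)
    private val buffer = ArrayDeque<Int>()

    fun add(task: Int, fair: Boolean = false) {
        if (fair) { buffer.addLast(task); return }      // fair mode skips the LIFO slot
        val previous = lastScheduled.getAndSet(task) ?: return
        buffer.addLast(previous)                        // displaced task goes to the FIFO part
    }

    fun poll(): Int? = lastScheduled.getAndSet(null) ?: buffer.pollFirst()
}

fun main() {
    val queue = ToyWorkQueue()
    listOf(1, 2, 3, 4).forEach { queue.add(it) }
    println(generateSequence { queue.poll() }.toList()) // [4, 1, 2, 3]
}
```
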
@@ -78,18 +79,20 @@ internal class WorkQueue {
      * Invariant: Called only by the owner of the queue, returns
      * `null` if task was added, task that wasn't added otherwise.
      */
-    fun addLast(task: Task): Task? {
+    private fun addLast(task: Task): Task? {
+        if (task.isBlocking) blockingTasksInBuffer.incrementAndGet()
         if (bufferSize == BUFFER_CAPACITY - 1) return task
-        val headLocal = producerIndex.value
-        val nextIndex = headLocal and MASK
-
+        val nextIndex = producerIndex.value and MASK
         /*
-         * If current element is not null then we're racing with consumers for the tail. If we skip this check then
-         * the consumer can null out current element and it will be lost. If we're racing for tail then
-         * the queue is close to overflowing => return task
+         * If the current element is not null, then we're racing with a really slow consumer that committed its consumer index,
+         * but hasn't yet nulled out the slot, effectively preventing us from using it.
+         * Such situations are very rare in practice (although possible) and we decided to give up a progress guarantee
+         * to have a stronger invariant "add to a queue with bufferSize == 0 is always successful".
+         * This algorithm could still be wait-free for add, but only if tasks were not reusable; otherwise,
+         * nulling out the buffer slot wouldn't be possible.
         */
-        if (buffer[nextIndex] != null) {
-            return task
+        while (buffer[nextIndex] != null) {
+            Thread.yield()
         }
         buffer.lazySet(nextIndex, task)
         producerIndex.incrementAndGet()
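
The `and MASK` indexing in `addLast` (and in the stealing code below) relies on the buffer capacity being a power of two, so masking a monotonically growing producer/consumer counter is equivalent to, but cheaper than, taking it modulo the capacity. A quick standalone check, assuming a capacity of 128 (the actual constant values are not shown in this diff):

```kotlin
// Standalone check of the masking trick; 128 is an assumed power-of-two capacity.
fun main() {
    val capacity = 128
    val mask = capacity - 1 // plays the role of MASK
    for (counter in listOf(0, 1, 127, 128, 129, 1000)) {
        check(counter and mask == counter % capacity)
        println("counter $counter -> slot ${counter and mask}")
    }
}
```
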
@@ -103,18 +106,52 @@ internal class WorkQueue {
      * or positive value of how many nanoseconds should pass until the head of this queue will be available to steal.
      */
     fun tryStealFrom(victim: WorkQueue): Long {
-        if (victim.stealBatch { task -> add(task) }) {
+        assert { bufferSize == 0 }
+        val task = victim.pollBuffer()
+        if (task != null) {
+            val notAdded = add(task)
+            assert { notAdded == null }
             return TASK_STOLEN
         }
-        return tryStealLastScheduled(victim)
+        return tryStealLastScheduled(victim, blockingOnly = false)
+    }
+
+    fun tryStealBlockingFrom(victim: WorkQueue): Long {
+        assert { bufferSize == 0 }
+        var start = victim.consumerIndex.value
+        val end = victim.producerIndex.value
+        val buffer = victim.buffer
+
+        while (start != end) {
+            val index = start and MASK
+            if (victim.blockingTasksInBuffer.value == 0) break
+            val value = buffer[index]
+            if (value != null && value.isBlocking && buffer.compareAndSet(index, value, null)) {
+                victim.blockingTasksInBuffer.decrementAndGet()
+                add(value)
+                return TASK_STOLEN
+            } else {
+                ++start
+            }
+        }
+        return tryStealLastScheduled(victim, blockingOnly = true)
+    }
+
+    fun offloadAllWorkTo(globalQueue: GlobalQueue) {
+        lastScheduledTask.getAndSet(null)?.let { globalQueue.add(it) }
+        while (pollTo(globalQueue)) {
+            // Steal everything
+        }
     }
 
     /**
      * Contract on return value is the same as for [tryStealFrom]
      */
-    private fun tryStealLastScheduled(victim: WorkQueue): Long {
+    private fun tryStealLastScheduled(victim: WorkQueue, blockingOnly: Boolean): Long {
         while (true) {
             val lastScheduled = victim.lastScheduledTask.value ?: return NOTHING_TO_STEAL
+            if (blockingOnly && !lastScheduled.isBlocking) return NOTHING_TO_STEAL
+
             // TODO time wraparound ?
             val time = schedulerTimeSource.nanoTime()
             val staleness = time - lastScheduled.submissionTime
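
The stealing functions share a return contract: a sentinel for a successful steal, `NOTHING_TO_STEAL`, or a positive number of nanoseconds to wait until the victim's head becomes stealable. The sketch below shows how a hypothetical thief loop could consume that contract; the real queues are stubbed out, and `TASK_STOLEN = -1L` is assumed to be the success sentinel declared alongside `NOTHING_TO_STEAL`.

```kotlin
import java.util.concurrent.locks.LockSupport

// Hypothetical thief-side sketch of the stealing contract; queues are stubbed out.
const val NOTHING_TO_STEAL = -2L
const val TASK_STOLEN = -1L // assumed success sentinel next to NOTHING_TO_STEAL

// Scans the victims once; parks for the shortest reported delay if nothing was stolen.
fun stealOnce(stealResults: List<Long>): Boolean {
    var minDelayNs = Long.MAX_VALUE
    for (result in stealResults) {            // each entry stands for one tryStealFrom(victim) call
        when (result) {
            TASK_STOLEN -> return true        // got work, stop scanning
            NOTHING_TO_STEAL -> {}            // nothing at this victim, try the next one
            else -> minDelayNs = minOf(minDelayNs, result) // positive wait hint in nanoseconds
        }
    }
    if (minDelayNs != Long.MAX_VALUE) LockSupport.parkNanos(minDelayNs) // back off before retrying
    return false
}

fun main() {
    println(stealOnce(listOf(NOTHING_TO_STEAL, 3_000L, TASK_STOLEN))) // true
    println(stealOnce(listOf(NOTHING_TO_STEAL, 3_000L, 5_000L)))      // false after a ~3 microsecond park
}
```
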
@@ -134,49 +171,10 @@ internal class WorkQueue {
         }
     }
 
-    private fun GlobalQueue.add(task: Task) {
-        /*
-         * globalQueue is closed as the very last step in the shutdown sequence when all worker threads had
-         * been already shutdown (with the only exception of the last worker thread that might be performing
-         * shutdown procedure itself). As a consistency check we do a [cheap!] check that it is not closed here yet.
-         */
-        val added = addLast(task)
-        assert { added }
-    }
-
-    internal fun offloadAllWork(globalQueue: GlobalQueue) {
-        lastScheduledTask.getAndSet(null)?.let { globalQueue.add(it) }
-        while (stealBatchTo(globalQueue)) {
-            // Steal everything
-        }
-    }
-
-    /**
-     * Method that is invoked by external workers to steal work.
-     * Half of the buffer (at least 1) is stolen, returns `true` if at least one task was stolen.
-     */
-    private inline fun stealBatch(consumer: (Task) -> Unit): Boolean {
-        val size = bufferSize
-        if (size == 0) return false
-        var toSteal = (size / 2).coerceAtLeast(1)
-        var wasStolen = false
-        while (toSteal-- > 0) {
-            val tailLocal = consumerIndex.value
-            if (tailLocal - producerIndex.value == 0) return wasStolen
-            val index = tailLocal and MASK
-            val element = buffer[index] ?: continue
-            if (consumerIndex.compareAndSet(tailLocal, tailLocal + 1)) {
-                // 1) Help GC 2) Signal producer that this slot is consumed and may be used
-                consumer(element)
-                buffer[index] = null
-                wasStolen = true
-            }
-        }
-        return wasStolen
-    }
-
-    private fun stealBatchTo(queue: GlobalQueue): Boolean {
-        return stealBatch { queue.add(it) }
+    private fun pollTo(queue: GlobalQueue): Boolean {
+        val task = pollBuffer() ?: return false
+        queue.add(task)
+        return true
     }
 
     private fun pollBuffer(): Task? {
@@ -185,8 +183,28 @@ internal class WorkQueue {
             if (tailLocal - producerIndex.value == 0) return null
             val index = tailLocal and MASK
             if (consumerIndex.compareAndSet(tailLocal, tailLocal + 1)) {
-                return buffer.getAndSet(index, null)
+                // Nulls are allowed when blocking tasks are stolen from the middle of the queue.
+                val value = buffer.getAndSet(index, null) ?: continue
+                value.decrementIfBlocking()
+                return value
             }
         }
     }
+
+    private fun Task?.decrementIfBlocking() {
+        if (this != null && isBlocking) {
+            val value = blockingTasksInBuffer.decrementAndGet()
+            assert { value >= 0 }
+        }
+    }
 }
+
+private fun GlobalQueue.add(task: Task) {
+    /*
+     * globalQueue is closed as the very last step in the shutdown sequence when all worker threads had
+     * been already shutdown (with the only exception of the last worker thread that might be performing
+     * shutdown procedure itself). As a consistency check we do a [cheap!] check that it is not closed here yet.
+     */
+    val added = addLast(task)
+    assert { added }
+}
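
Finally, the `blockingTasksInBuffer` counter introduced in this change is purely a scan shortcut: `tryStealBlockingFrom` stops walking the victim's buffer as soon as the counter reports zero, while `addLast` and `decrementIfBlocking` keep it in sync as blocking tasks enter and leave the buffer. A toy, single-threaded model of that shortcut follows; none of these names exist in the scheduler.

```kotlin
import java.util.concurrent.atomic.AtomicInteger

// Toy model: a buffer of tasks flagged as blocking or not, plus a counter that
// lets a scan bail out early once every blocking task has been taken.
class ToyBuffer(private val blockingFlags: MutableList<Boolean?>) {
    private val blockingTasks = AtomicInteger(blockingFlags.count { it == true })

    // Returns the index of a claimed blocking task, or null if none remain.
    fun stealBlocking(): Int? {
        for (index in blockingFlags.indices) {
            if (blockingTasks.get() == 0) return null // shortcut: nothing left to find
            if (blockingFlags[index] == true) {
                blockingFlags[index] = null           // "claim" the slot
                blockingTasks.decrementAndGet()
                return index
            }
        }
        return null
    }
}

fun main() {
    val buffer = ToyBuffer(mutableListOf(false, true, false, true))
    println(buffer.stealBlocking()) // 1
    println(buffer.stealBlocking()) // 3
    println(buffer.stealBlocking()) // null, and the counter made the last scan cheap
}
```
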