@@ -16,6 +16,14 @@ internal const val MASK = BUFFER_CAPACITY - 1 // 128 by default
1616internal const val TASK_STOLEN = - 1L
1717internal const val NOTHING_TO_STEAL = - 2L
1818
19+ internal typealias StealingMode = Int
20+ internal const val STEAL_ANY : StealingMode = 3
21+ internal const val STEAL_CPU_ONLY : StealingMode = 2
22+ internal const val STEAL_BLOCKING_ONLY : StealingMode = 1
23+
24+ internal inline val Task .maskForStealingMode: Int
25+ get() = if (isBlocking) STEAL_BLOCKING_ONLY else STEAL_CPU_ONLY
26+
1927/* *
2028 * Tightly coupled with [CoroutineScheduler] queue of pending tasks, but extracted to separate file for simplicity.
2129 * At any moment queue is used only by [CoroutineScheduler.Worker] threads, has only one producer (worker owning this queue)
@@ -107,54 +115,75 @@ internal class WorkQueue {
107115 *
108116 * Returns [NOTHING_TO_STEAL] if queue has nothing to steal, [TASK_STOLEN] if at least task was stolen
109117 * or positive value of how many nanoseconds should pass until the head of this queue will be available to steal.
118+ *
119+ * [StealingMode] controls what tasks to steal:
120+ * * [STEAL_ANY] is default mode for scheduler, task from the head (in FIFO order) is stolen
121+ * * [STEAL_BLOCKING_ONLY] is mode for stealing *an arbitrary* blocking task, which is used by the scheduler when helping in Dispatchers.IO mode
122+ * * [STEAL_CPU_ONLY] is a kludge for `runSingleTaskFromCurrentSystemDispatcher`
110123 */
111- fun trySteal (stolenTaskRef : ObjectRef <Task ?>): Long {
112- val task = pollBuffer()
124+ fun trySteal (stealingMode : StealingMode , stolenTaskRef : ObjectRef <Task ?>): Long {
125+ val task = when (stealingMode) {
126+ STEAL_ANY -> pollBuffer()
127+ else -> stealWithExclusiveMode(stealingMode)
128+ }
129+
113130 if (task != null ) {
114131 stolenTaskRef.element = task
115132 return TASK_STOLEN
116133 }
117- return tryStealLastScheduled(stolenTaskRef, blockingOnly = false )
134+ return tryStealLastScheduled(stealingMode, stolenTaskRef )
118135 }
119136
120- fun tryStealBlocking (stolenTaskRef : ObjectRef <Task ?>): Long {
137+ // Steal only tasks of a particular kind, potentially invoking full queue scan
138+ private fun stealWithExclusiveMode (stealingMode : StealingMode ): Task ? {
121139 var start = consumerIndex.value
122140 val end = producerIndex.value
123-
124- while (start != end && blockingTasksInBuffer.value > 0 ) {
125- stolenTaskRef.element = tryExtractBlockingTask(start++ ) ? : continue
126- return TASK_STOLEN
141+ val onlyBlocking = stealingMode == STEAL_BLOCKING_ONLY
142+ // Bail out if there is no blocking work for us
143+ while (start != end) {
144+ if (onlyBlocking && blockingTasksInBuffer.value == 0 ) return null
145+ return tryExtractFromTheMiddle(start++ , onlyBlocking) ? : continue
127146 }
128- return tryStealLastScheduled(stolenTaskRef, blockingOnly = true )
147+
148+ return null
129149 }
130150
131151 // Polls for blocking task, invoked only by the owner
132- fun pollBlocking (): Task ? {
152+ // NB: ONLY for runSingleTask method
153+ fun pollBlocking (): Task ? = pollWithExclusiveMode(onlyBlocking = true /* only blocking */ )
154+
155+ // Polls for CPU task, invoked only by the owner
156+ // NB: ONLY for runSingleTask method
157+ fun pollCpu (): Task ? = pollWithExclusiveMode(onlyBlocking = false /* only cpu */ )
158+
159+ private fun pollWithExclusiveMode (/* Only blocking OR only CPU */ onlyBlocking : Boolean ): Task ? {
133160 while (true ) { // Poll the slot
134161 val lastScheduled = lastScheduledTask.value ? : break
135- if (! lastScheduled.isBlocking) break
162+ if (lastScheduled.isBlocking != onlyBlocking ) break
136163 if (lastScheduledTask.compareAndSet(lastScheduled, null )) {
137164 return lastScheduled
138165 } // Failed -> someone else stole it
139166 }
140167
168+ // Failed to poll the slot, scan the queue
141169 val start = consumerIndex.value
142170 var end = producerIndex.value
143-
144- while (start != end && blockingTasksInBuffer.value > 0 ) {
145- val task = tryExtractBlockingTask(-- end)
171+ // Bail out if there is no blocking work for us
172+ while (start != end) {
173+ if (onlyBlocking && blockingTasksInBuffer.value == 0 ) return null
174+ val task = tryExtractFromTheMiddle(-- end, onlyBlocking)
146175 if (task != null ) {
147176 return task
148177 }
149178 }
150179 return null
151180 }
152181
153- private fun tryExtractBlockingTask (index : Int ): Task ? {
182+ private fun tryExtractFromTheMiddle (index : Int , onlyBlocking : Boolean ): Task ? {
154183 val arrayIndex = index and MASK
155184 val value = buffer[arrayIndex]
156- if (value != null && value.isBlocking && buffer.compareAndSet(arrayIndex, value, null )) {
157- blockingTasksInBuffer.decrementAndGet()
185+ if (value != null && value.isBlocking == onlyBlocking && buffer.compareAndSet(arrayIndex, value, null )) {
186+ if (onlyBlocking) blockingTasksInBuffer.decrementAndGet()
158187 return value
159188 }
160189 return null
@@ -170,10 +199,12 @@ internal class WorkQueue {
170199 /* *
171200 * Contract on return value is the same as for [trySteal]
172201 */
173- private fun tryStealLastScheduled (stolenTaskRef : ObjectRef <Task ?>, blockingOnly : Boolean ): Long {
202+ private fun tryStealLastScheduled (stealingMode : StealingMode , stolenTaskRef : ObjectRef <Task ?>): Long {
174203 while (true ) {
175204 val lastScheduled = lastScheduledTask.value ? : return NOTHING_TO_STEAL
176- if (blockingOnly && ! lastScheduled.isBlocking) return NOTHING_TO_STEAL
205+ if ((lastScheduled.maskForStealingMode and stealingMode) == 0 ) {
206+ return NOTHING_TO_STEAL
207+ }
177208
178209 // TODO time wraparound ?
179210 val time = schedulerTimeSource.nanoTime()
0 commit comments