@@ -12,36 +12,42 @@ import kotlin.coroutines.experimental.*
1212 * every time the coroutine with this element in the context is resumed on a thread.
1313 *
1414 * Implementations of this interface define a type [S] of the thread-local state that they need to store on
15- * resume of a coroutine and restore later on suspend and the infrastructure provides the corresponding storage.
15+ * resume of a coroutine and restore later on suspend. The infrastructure provides the corresponding storage.
1616 *
1717 * Example usage looks like this:
1818 *
1919 * ```
20- * // declare thread local variable holding MyData
21- * private val myThreadLocal = ThreadLocal<MyData?>()
22- *
23- * // declare context element holding MyData
24- * class MyElement(val data: MyData) : ThreadContextElement<MyData?> {
20+ * // Appends "name" of a coroutine to a current thread name when coroutine is executed
21+ * class CoroutineName(val name: String) : ThreadContextElement<String> {
2522 * // declare companion object for a key of this element in coroutine context
26- * companion object Key : CoroutineContext.Key<MyElement >
23+ * companion object Key : CoroutineContext.Key<CoroutineName >
2724 *
2825 * // provide the key of the corresponding context element
29- * override val key: CoroutineContext.Key<MyElement >
26+ * override val key: CoroutineContext.Key<CoroutineName >
3027 * get() = Key
3128 *
3229 * // this is invoked before coroutine is resumed on current thread
33- * override fun updateThreadContext(context: CoroutineContext): MyData? {
34- * val oldState = myThreadLocal.get()
35- * myThreadLocal.set(data)
36- * return oldState
30+ * override fun updateThreadContext(context: CoroutineContext): String {
31+ * val previousName = Thread.currentThread().name
32+ * Thread.currentThread().name = "$previousName # $name"
33+ * return previousName
3734 * }
3835 *
3936 * // this is invoked after coroutine has suspended on current thread
40- * override fun restoreThreadContext(context: CoroutineContext, oldState: MyData? ) {
41- * myThreadLocal.set(oldState)
37+ * override fun restoreThreadContext(context: CoroutineContext, oldState: String ) {
38+ * Thread.currentThread().name = oldState
4239 * }
4340 * }
41+ *
42+ * // Usage
43+ * launch(UI + CoroutineName("Progress bar coroutine")) { ... }
4444 * ```
45+ *
46+ * Every time this coroutine is resumed on a thread, UI thread name is updated to
47+ * "UI thread original name # Progress bar coroutine" and the thread name is restored to the original one when
48+ * this coroutine suspends.
49+ *
50+ * To use [ThreadLocal] variable within the coroutine use [ThreadLocal.asContextElement][asContextElement] function.
4551 */
4652public interface ThreadContextElement <S > : CoroutineContext .Element {
4753 /* *
@@ -67,87 +73,44 @@ public interface ThreadContextElement<S> : CoroutineContext.Element {
6773 public fun restoreThreadContext (context : CoroutineContext , oldState : S )
6874}
6975
70- private val ZERO = Symbol (" ZERO" )
71-
72- // Used when there are >= 2 active elements in the context
73- private class ThreadState (val context : CoroutineContext , n : Int ) {
74- private var a = arrayOfNulls<Any >(n)
75- private var i = 0
76-
77- fun append (value : Any? ) { a[i++ ] = value }
78- fun take () = a[i++ ]
79- fun start () { i = 0 }
80- }
81-
82- // Counts ThreadContextElements in the context
83- // Any? here is Int | ThreadContextElement (when count is one)
84- private val countAll =
85- fun (countOrElement : Any? , element : CoroutineContext .Element ): Any? {
86- if (element is ThreadContextElement <* >) {
87- val inCount = countOrElement as ? Int ? : 1
88- return if (inCount == 0 ) element else inCount + 1
89- }
90- return countOrElement
91- }
92-
93- // Find one (first) ThreadContextElement in the context, it is used when we know there is exactly one
94- private val findOne =
95- fun (found : ThreadContextElement <* >? , element : CoroutineContext .Element ): ThreadContextElement <* >? {
96- if (found != null ) return found
97- return element as ? ThreadContextElement <* >
98- }
99-
100- // Updates state for ThreadContextElements in the context using the given ThreadState
101- private val updateState =
102- fun (state : ThreadState , element : CoroutineContext .Element ): ThreadState {
103- if (element is ThreadContextElement <* >) {
104- state.append(element.updateThreadContext(state.context))
105- }
106- return state
107- }
108-
109- // Restores state for all ThreadContextElements in the context from the given ThreadState
110- private val restoreState =
111- fun (state : ThreadState , element : CoroutineContext .Element ): ThreadState {
112- @Suppress(" UNCHECKED_CAST" )
113- if (element is ThreadContextElement <* >) {
114- (element as ThreadContextElement <Any ?>).restoreThreadContext(state.context, state.take())
115- }
116- return state
117- }
118-
119- internal fun updateThreadContext (context : CoroutineContext ): Any? {
120- val count = context.fold(0 , countAll)
121- @Suppress(" IMPLICIT_BOXING_IN_IDENTITY_EQUALS" )
122- return when {
123- count == = 0 -> ZERO // very fast path when there are no active ThreadContextElements
124- // ^^^ identity comparison for speed, we know zero always has the same identity
125- count is Int -> {
126- // slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values
127- context.fold(ThreadState (context, count), updateState)
128- }
129- else -> {
130- // fast path for one ThreadContextElement (no allocations, no additional context scan)
131- @Suppress(" UNCHECKED_CAST" )
132- val element = count as ThreadContextElement <Any ?>
133- element.updateThreadContext(context)
134- }
135- }
136- }
137-
138- internal fun restoreThreadContext (context : CoroutineContext , oldState : Any? ) {
139- when {
140- oldState == = ZERO -> return // very fast path when there are no ThreadContextElements
141- oldState is ThreadState -> {
142- // slow path with multiple stored ThreadContextElements
143- oldState.start()
144- context.fold(oldState, restoreState)
145- }
146- else -> {
147- // fast path for one ThreadContextElement, but need to find it
148- @Suppress(" UNCHECKED_CAST" )
149- val element = context.fold(null , findOne) as ThreadContextElement <Any ?>
150- element.restoreThreadContext(context, oldState)
151- }
152- }
153- }
76+ /* *
77+ * Wraps [ThreadLocal] into [ThreadContextElement]. The resulting [ThreadContextElement]
78+ * maintains the given [value] of the given [ThreadLocal] for coroutine regardless of the actual thread its is resumed on.
79+ * By default [ThreadLocal.get] is used as a value for the thread-local variable, but it can be overridden with [value] parameter.
80+ *
81+ * Example usage looks like this:
82+ *
83+ * ```
84+ * val myThreadLocal = ThreadLocal<String?>()
85+ * ...
86+ * println(myThreadLocal.get()) // Prints "null"
87+ * launch(CommonPool + myThreadLocal.asContextElement(initialValue = "foo")) {
88+ * println(myThreadLocal.get()) // Prints "foo"
89+ * withContext(UI) {
90+ * println(myThreadLocal.get()) // Prints "foo", but it's on UI thread
91+ * }
92+ * }
93+ * println(myThreadLocal.get()) // Prints "null"
94+ * ```
95+ *
96+ * Note that the context element does not track modifications of the thread-local variable, for example:
97+ *
98+ * ```
99+ * myThreadLocal.set("main")
100+ * withContext(UI) {
101+ * println(myThreadLocal.get()) // Prints "main"
102+ * myThreadLocal.set("UI")
103+ * }
104+ * println(myThreadLocal.get()) // Prints "main", not "UI"
105+ * ```
106+ *
107+ * Use `withContext` to update the corresponding thread-local variable to a different value, for example:
108+ *
109+ * ```
110+ * withContext(myThreadLocal.asContextElement("foo")) {
111+ * println(myThreadLocal.get()) // Prints "foo"
112+ * }
113+ * ```
114+ */
115+ public fun <T > ThreadLocal<T>.asContextElement (value : T = get()): ThreadContextElement <T > =
116+ ThreadLocalElement (value, this )
0 commit comments