@@ -7,12 +7,37 @@ package kotlinx.coroutines.flow
77import kotlinx.coroutines.*
88import kotlinx.coroutines.channels.*
99import kotlin.coroutines.*
10+ import kotlin.reflect.*
1011import kotlin.test.*
1112
1213class FlowInvariantsTest : TestBase () {
1314
15+ private fun <T > runParametrizedTest (
16+ expectedException : KClass <out Throwable >? = null,
17+ testBody : suspend (flowFactory: (suspend FlowCollector <T >.() -> Unit ) -> Flow <T >) -> Unit
18+ ) = runTest {
19+ val r1 = runCatching { testBody { flow(it) } }.exceptionOrNull()
20+ check(r1, expectedException)
21+ reset()
22+
23+ val r2 = runCatching { testBody { abstractFlow(it) } }.exceptionOrNull()
24+ check(r2, expectedException)
25+ }
26+
27+ private fun <T > abstractFlow (block : suspend FlowCollector <T >.() -> Unit ): Flow <T > = object : AbstractFlow <T >() {
28+ override suspend fun collectSafely (collector : FlowCollector <T >) {
29+ collector.block()
30+ }
31+ }
32+
33+ private fun check (exception : Throwable ? , expectedException : KClass <out Throwable >? ) {
34+ if (expectedException != null && exception == null ) fail(" Expected $expectedException , but test completed successfully" )
35+ if (expectedException != null && exception != null ) assertTrue(expectedException.isInstance(exception))
36+ if (expectedException == null && exception != null ) throw exception
37+ }
38+
1439 @Test
15- fun testWithContextContract () = runTest({ it is IllegalStateException } ) {
40+ fun testWithContextContract () = runParametrizedTest< Int >( IllegalStateException :: class ) { flow ->
1641 flow {
1742 kotlinx.coroutines.withContext(NonCancellable ) {
1843 emit(1 )
@@ -23,7 +48,7 @@ class FlowInvariantsTest : TestBase() {
2348 }
2449
2550 @Test
26- fun testWithDispatcherContractViolated () = runTest({ it is IllegalStateException } ) {
51+ fun testWithDispatcherContractViolated () = runParametrizedTest< Int >( IllegalStateException :: class ) { flow ->
2752 flow {
2853 kotlinx.coroutines.withContext(NamedDispatchers (" foo" )) {
2954 emit(1 )
@@ -34,7 +59,7 @@ class FlowInvariantsTest : TestBase() {
3459 }
3560
3661 @Test
37- fun testCachedInvariantCheckResult () = runTest {
62+ fun testCachedInvariantCheckResult () = runParametrizedTest< Int > { flow ->
3863 flow {
3964 emit(1 )
4065
@@ -55,7 +80,7 @@ class FlowInvariantsTest : TestBase() {
5580 }
5681
5782 @Test
58- fun testWithNameContractViolated () = runTest({ it is IllegalStateException } ) {
83+ fun testWithNameContractViolated () = runParametrizedTest< Int >( IllegalStateException :: class ) { flow ->
5984 flow {
6085 kotlinx.coroutines.withContext(CoroutineName (" foo" )) {
6186 emit(1 )
@@ -86,25 +111,25 @@ class FlowInvariantsTest : TestBase() {
86111 }
87112
88113 @Test
89- fun testScopedJob () = runTest({ it is IllegalStateException } ) {
90- flow { emit(1 ) }.buffer(EmptyCoroutineContext ).collect {
114+ fun testScopedJob () = runParametrizedTest< Int >( IllegalStateException :: class ) { flow ->
115+ flow { emit(1 ) }.buffer(EmptyCoroutineContext , flow ).collect {
91116 expect(1 )
92117 }
93118
94119 finish(2 )
95120 }
96121
97122 @Test
98- fun testScopedJobWithViolation () = runTest({ it is IllegalStateException } ) {
99- flow { emit(1 ) }.buffer(Dispatchers .Unconfined ).collect {
123+ fun testScopedJobWithViolation () = runParametrizedTest< Int >( IllegalStateException :: class ) { flow ->
124+ flow { emit(1 ) }.buffer(Dispatchers .Unconfined , flow ).collect {
100125 expect(1 )
101126 }
102127
103128 finish(2 )
104129 }
105130
106131 @Test
107- fun testMergeViolation () = runTest {
132+ fun testMergeViolation () = runParametrizedTest< Int > { flow ->
108133 fun Flow<Int>.merge (other : Flow <Int >): Flow <Int > = flow {
109134 coroutineScope {
110135 launch {
@@ -130,17 +155,6 @@ class FlowInvariantsTest : TestBase() {
130155 assertFailsWith<IllegalStateException > { flow.trickyMerge(flow).toList() }
131156 }
132157
133- // TODO merge artifact
134- private fun <T > channelFlow (bufferSize : Int = 16, @BuilderInference block : suspend ProducerScope <T >.() -> Unit ): Flow <T > =
135- flow {
136- coroutineScope {
137- val channel = produce(capacity = bufferSize, block = block)
138- channel.consumeEach { value ->
139- emit(value)
140- }
141- }
142- }
143-
144158 @Test
145159 fun testNoMergeViolation () = runTest {
146160 fun Flow<Int>.merge (other : Flow <Int >): Flow <Int > = channelFlow {
@@ -167,7 +181,7 @@ class FlowInvariantsTest : TestBase() {
167181 }
168182
169183 @Test
170- fun testScopedCoroutineNoViolation () = runTest {
184+ fun testScopedCoroutineNoViolation () = runParametrizedTest< Int > { flow ->
171185 fun Flow<Int>.buffer (): Flow <Int > = flow {
172186 coroutineScope {
173187 val channel = produce {
@@ -180,11 +194,10 @@ class FlowInvariantsTest : TestBase() {
180194 }
181195 }
182196 }
183-
184197 assertEquals(listOf (1 , 1 ), flowOf(1 , 1 ).buffer().toList())
185198 }
186199
187- private fun Flow<Int>.buffer (coroutineContext : CoroutineContext ): Flow <Int > = flow {
200+ private fun Flow<Int>.buffer (coroutineContext : CoroutineContext , flow : (suspend FlowCollector < Int >.() -> Unit ) -> Flow < Int > ): Flow <Int > = flow {
188201 coroutineScope {
189202 val channel = Channel <Int >()
190203 launch {
0 commit comments