@@ -317,6 +317,104 @@ final class PSQLRowStreamTests: XCTestCase {
317317 XCTAssertNoThrow ( try future. wait ( ) )
318318 }
319319
320+ func testDrainBeforeStreamHasFinishedWhenThereIsAlreadyAConsumer( ) {
321+ let dataSource = CountingDataSource ( )
322+ let stream = PSQLRowStream (
323+ source: . stream(
324+ [ self . makeColumnDescription ( name: " foo " , dataType: . text, format: . binary) ] ,
325+ dataSource
326+ ) ,
327+ eventLoop: self . eventLoop,
328+ logger: self . logger
329+ )
330+ XCTAssertEqual ( dataSource. hitDemand, 0 )
331+ XCTAssertEqual ( dataSource. hitCancel, 0 )
332+
333+ stream. receive ( [
334+ [ ByteBuffer ( string: " 0 " ) ] ,
335+ [ ByteBuffer ( string: " 1 " ) ]
336+ ] )
337+
338+ XCTAssertEqual ( dataSource. hitDemand, 0 , " Before we have a consumer demand is not signaled " )
339+
340+ // attach consumers
341+ let allFuture = stream. all ( )
342+ XCTAssertEqual ( dataSource. hitDemand, 1 )
343+ let drainFuture = stream. drain ( )
344+ XCTAssertEqual ( dataSource. hitDemand, 2 )
345+
346+ stream. receive ( [
347+ [ ByteBuffer ( string: " 2 " ) ] ,
348+ [ ByteBuffer ( string: " 3 " ) ]
349+ ] )
350+ XCTAssertEqual ( dataSource. hitDemand, 3 )
351+
352+ stream. receive ( [
353+ [ ByteBuffer ( string: " 4 " ) ] ,
354+ [ ByteBuffer ( string: " 5 " ) ]
355+ ] )
356+ XCTAssertEqual ( dataSource. hitDemand, 4 )
357+
358+ stream. receive ( completion: . success( " SELECT 2 " ) )
359+
360+ XCTAssertNoThrow ( try drainFuture. wait ( ) )
361+
362+ var rows : [ PostgresRow ] ?
363+ XCTAssertNoThrow ( rows = try allFuture. wait ( ) )
364+ XCTAssertEqual ( rows? . count, 6 )
365+ }
366+
367+ func testDrainBeforeStreamHasFinishedWhenThereIsAlreadyAnAsyncConsumer( ) {
368+ let dataSource = CountingDataSource ( )
369+ let stream = PSQLRowStream (
370+ source: . stream(
371+ [ self . makeColumnDescription ( name: " foo " , dataType: . text, format: . binary) ] ,
372+ dataSource
373+ ) ,
374+ eventLoop: self . eventLoop,
375+ logger: self . logger
376+ )
377+ XCTAssertEqual ( dataSource. hitDemand, 0 )
378+ XCTAssertEqual ( dataSource. hitCancel, 0 )
379+
380+ stream. receive ( [
381+ [ ByteBuffer ( string: " 0 " ) ] ,
382+ [ ByteBuffer ( string: " 1 " ) ]
383+ ] )
384+
385+ XCTAssertEqual ( dataSource. hitDemand, 0 , " Before we have a consumer demand is not signaled " )
386+
387+ // attach consumers
388+ let rowSequence = stream. asyncSequence ( )
389+ XCTAssertEqual ( dataSource. hitDemand, 0 )
390+ let drainFuture = stream. drain ( )
391+ XCTAssertEqual ( dataSource. hitDemand, 1 )
392+
393+ stream. receive ( [
394+ [ ByteBuffer ( string: " 2 " ) ] ,
395+ [ ByteBuffer ( string: " 3 " ) ]
396+ ] )
397+ XCTAssertEqual ( dataSource. hitDemand, 2 )
398+
399+ stream. receive ( [
400+ [ ByteBuffer ( string: " 4 " ) ] ,
401+ [ ByteBuffer ( string: " 5 " ) ]
402+ ] )
403+ XCTAssertEqual ( dataSource. hitDemand, 3 )
404+
405+ stream. receive ( completion: . success( " SELECT 2 " ) )
406+
407+ XCTAssertNoThrow ( try drainFuture. wait ( ) )
408+
409+ XCTAssertNoThrow {
410+ let rows = try stream. eventLoop. makeFutureWithTask {
411+ try ? await rowSequence. collect ( )
412+ } . wait ( )
413+ XCTAssertEqual ( dataSource. hitDemand, 4 )
414+ XCTAssertEqual ( rows? . count, 6 )
415+ }
416+ }
417+
320418 func makeColumnDescription( name: String , dataType: PostgresDataType , format: PostgresFormat ) -> RowDescription . Column {
321419 RowDescription . Column (
322420 name: " test " ,
0 commit comments