@@ -224,6 +224,63 @@ class PostgresConnectionTests: XCTestCase {
224224 }
225225 }
226226
227+ func testSimpleListenFailsIfConnectionIsClosed( ) async throws {
228+ let ( connection, channel) = try await self . makeTestConnectionWithAsyncTestingChannel ( )
229+
230+ try await connection. closeGracefully ( )
231+
232+ XCTAssertEqual ( channel. isActive, false )
233+
234+ do {
235+ _ = try await connection. listen ( " test_channel " )
236+ XCTFail ( " Expected to fail " )
237+ } catch let error as ChannelError {
238+ XCTAssertEqual ( error, . ioOnClosedChannel)
239+ }
240+ }
241+
242+ func testSimpleListenFailsIfConnectionIsClosedWhileListening( ) async throws {
243+ let ( connection, channel) = try await self . makeTestConnectionWithAsyncTestingChannel ( )
244+
245+ try await withThrowingTaskGroup ( of: Void . self) { taskGroup in
246+ taskGroup. addTask {
247+ let events = try await connection. listen ( " foo " )
248+ var iterator = events. makeAsyncIterator ( )
249+ let first = try await iterator. next ( )
250+ XCTAssertEqual ( first? . payload, " wooohooo " )
251+ do {
252+ _ = try await iterator. next ( )
253+ XCTFail ( " Did not expect to not throw " )
254+ } catch let error as PSQLError {
255+ XCTAssertEqual ( error. code, . clientClosedConnection)
256+ }
257+ }
258+
259+ let listenMessage = try await channel. waitForUnpreparedRequest ( )
260+ XCTAssertEqual ( listenMessage. parse. query, #"LISTEN "foo";"# )
261+
262+ try await channel. writeInbound ( PostgresBackendMessage . parseComplete)
263+ try await channel. writeInbound ( PostgresBackendMessage . parameterDescription ( . init( dataTypes: [ ] ) ) )
264+ try await channel. writeInbound ( PostgresBackendMessage . noData)
265+ try await channel. writeInbound ( PostgresBackendMessage . bindComplete)
266+ try await channel. writeInbound ( PostgresBackendMessage . commandComplete ( " LISTEN " ) )
267+ try await channel. writeInbound ( PostgresBackendMessage . readyForQuery ( . idle) )
268+
269+ try await channel. writeInbound ( PostgresBackendMessage . notification ( . init( backendPID: 12 , channel: " foo " , payload: " wooohooo " ) ) )
270+
271+ try await connection. close ( )
272+
273+ XCTAssertEqual ( channel. isActive, false )
274+
275+ switch await taskGroup. nextResult ( ) ! {
276+ case . success:
277+ break
278+ case . failure( let failure) :
279+ XCTFail ( " Unexpected error: \( failure) " )
280+ }
281+ }
282+ }
283+
227284 func testCloseGracefullyClosesWhenInternalQueueIsEmpty( ) async throws {
228285 let ( connection, channel) = try await self . makeTestConnectionWithAsyncTestingChannel ( )
229286 try await withThrowingTaskGroup ( of: Void . self) { [ logger] taskGroup async throws -> ( ) in
@@ -638,6 +695,118 @@ class PostgresConnectionTests: XCTestCase {
638695 }
639696 }
640697
698+ func testQueryFailsIfConnectionIsClosed( ) async throws {
699+ let ( connection, channel) = try await self . makeTestConnectionWithAsyncTestingChannel ( )
700+
701+ try await connection. closeGracefully ( )
702+
703+ XCTAssertEqual ( channel. isActive, false )
704+
705+ do {
706+ _ = try await connection. query ( " SELECT version; " , logger: self . logger)
707+ XCTFail ( " Expected to fail " )
708+ } catch let error as ChannelError {
709+ XCTAssertEqual ( error, . ioOnClosedChannel)
710+ }
711+ }
712+
713+ func testPrepareStatementFailsIfConnectionIsClosed( ) async throws {
714+ let ( connection, channel) = try await self . makeTestConnectionWithAsyncTestingChannel ( )
715+
716+ try await connection. closeGracefully ( )
717+
718+ XCTAssertEqual ( channel. isActive, false )
719+
720+ do {
721+ _ = try await connection. prepareStatement ( " SELECT version; " , with: " test_query " , logger: . psqlTest) . get ( )
722+ XCTFail ( " Expected to fail " )
723+ } catch let error as ChannelError {
724+ XCTAssertEqual ( error, . ioOnClosedChannel)
725+ }
726+ }
727+
728+ func testExecuteFailsIfConnectionIsClosed( ) async throws {
729+ let ( connection, channel) = try await self . makeTestConnectionWithAsyncTestingChannel ( )
730+
731+ try await connection. closeGracefully ( )
732+
733+ XCTAssertEqual ( channel. isActive, false )
734+
735+ do {
736+ let statement = PSQLExecuteStatement ( name: " SELECT version; " , binds: . init( ) , rowDescription: nil )
737+ _ = try await connection. execute ( statement, logger: . psqlTest) . get ( )
738+ XCTFail ( " Expected to fail " )
739+ } catch let error as ChannelError {
740+ XCTAssertEqual ( error, . ioOnClosedChannel)
741+ }
742+ }
743+
744+ func testExecutePreparedStatementFailsIfConnectionIsClosed( ) async throws {
745+ let ( connection, channel) = try await self . makeTestConnectionWithAsyncTestingChannel ( )
746+
747+ try await connection. closeGracefully ( )
748+
749+ XCTAssertEqual ( channel. isActive, false )
750+
751+ struct TestPreparedStatement : PostgresPreparedStatement {
752+ static let sql = " SELECT pid, datname FROM pg_stat_activity WHERE state = $1 "
753+ typealias Row = ( Int , String )
754+
755+ var state : String
756+
757+ func makeBindings( ) -> PostgresBindings {
758+ var bindings = PostgresBindings ( )
759+ bindings. append ( self . state)
760+ return bindings
761+ }
762+
763+ func decodeRow( _ row: PostgresNIO . PostgresRow ) throws -> Row {
764+ try row. decode ( Row . self)
765+ }
766+ }
767+
768+ do {
769+ let preparedStatement = TestPreparedStatement ( state: " active " )
770+ _ = try await connection. execute ( preparedStatement, logger: . psqlTest)
771+ XCTFail ( " Expected to fail " )
772+ } catch let error as ChannelError {
773+ XCTAssertEqual ( error, . ioOnClosedChannel)
774+ }
775+ }
776+
777+ func testExecutePreparedStatementWithVoidRowFailsIfConnectionIsClosed( ) async throws {
778+ let ( connection, channel) = try await self . makeTestConnectionWithAsyncTestingChannel ( )
779+
780+ try await connection. closeGracefully ( )
781+
782+ XCTAssertEqual ( channel. isActive, false )
783+
784+ struct TestPreparedStatement : PostgresPreparedStatement {
785+ static let sql = " SELECT * FROM pg_stat_activity WHERE state = $1 "
786+ typealias Row = ( )
787+
788+ var state : String
789+
790+ func makeBindings( ) -> PostgresBindings {
791+ var bindings = PostgresBindings ( )
792+ bindings. append ( self . state)
793+ return bindings
794+ }
795+
796+ func decodeRow( _ row: PostgresNIO . PostgresRow ) throws -> Row {
797+ ( )
798+ }
799+ }
800+
801+ do {
802+ let preparedStatement = TestPreparedStatement ( state: " active " )
803+ _ = try await connection. execute ( preparedStatement, logger: . psqlTest)
804+ XCTFail ( " Expected to fail " )
805+ } catch let error as ChannelError {
806+ XCTAssertEqual ( error, . ioOnClosedChannel)
807+ }
808+ }
809+
641810 func makeTestConnectionWithAsyncTestingChannel( ) async throws -> ( PostgresConnection , NIOAsyncTestingChannel ) {
642811 let eventLoop = NIOAsyncTestingEventLoop ( )
643812 let channel = await NIOAsyncTestingChannel ( handlers: [
0 commit comments