@@ -317,6 +317,40 @@ class PostgresConnectionTests: XCTestCase {
317317 }
318318 }
319319
320+ func testCloseImmediatelyWithSimpleQuery( ) async throws {
321+ let ( connection, channel) = try await self . makeTestConnectionWithAsyncTestingChannel ( )
322+
323+ try await withThrowingTaskGroup ( of: Void . self) { [ logger] taskGroup async throws -> ( ) in
324+ for _ in 1 ... 2 {
325+ taskGroup. addTask {
326+ try await connection. __simpleQuery ( " SELECT 1; " , logger: logger)
327+ }
328+ }
329+
330+ let query = try await channel. waitForSimpleQueryRequest ( )
331+ XCTAssertEqual ( query. query, " SELECT 1; " )
332+
333+ async let close : ( ) = connection. close ( )
334+
335+ try await channel. closeFuture. get ( )
336+ XCTAssertEqual ( channel. isActive, false )
337+
338+ try await close
339+
340+ while let taskResult = await taskGroup. nextResult ( ) {
341+ switch taskResult {
342+ case . success:
343+ XCTFail ( " Expected queries to fail " )
344+ case . failure( let failure) :
345+ guard let error = failure as? PSQLError else {
346+ return XCTFail ( " Unexpected error type: \( failure) " )
347+ }
348+ XCTAssertEqual ( error. code, . clientClosedConnection)
349+ }
350+ }
351+ }
352+ }
353+
320354 func testIfServerJustClosesTheErrorReflectsThat( ) async throws {
321355 let ( connection, channel) = try await self . makeTestConnectionWithAsyncTestingChannel ( )
322356 let logger = self . logger
@@ -346,6 +380,35 @@ class PostgresConnectionTests: XCTestCase {
346380 }
347381 }
348382
383+ func testIfServerJustClosesTheErrorReflectsThatInSimpleQuery( ) async throws {
384+ let ( connection, channel) = try await self . makeTestConnectionWithAsyncTestingChannel ( )
385+ let logger = self . logger
386+
387+ async let response = try await connection. __simpleQuery ( " SELECT 1; " , logger: logger)
388+
389+ let query = try await channel. waitForSimpleQueryRequest ( )
390+ XCTAssertEqual ( query. query, " SELECT 1; " )
391+
392+ try await channel. testingEventLoop. executeInContext { channel. pipeline. fireChannelInactive ( ) }
393+ try await channel. testingEventLoop. executeInContext { channel. pipeline. fireChannelUnregistered ( ) }
394+
395+ do {
396+ _ = try await response
397+ XCTFail ( " Expected to throw " )
398+ } catch {
399+ XCTAssertEqual ( ( error as? PSQLError ) ? . code, . serverClosedConnection)
400+ }
401+
402+ // retry on same connection
403+
404+ do {
405+ _ = try await connection. __simpleQuery ( " SELECT 1; " , logger: self . logger)
406+ XCTFail ( " Expected to throw " )
407+ } catch {
408+ XCTAssertEqual ( ( error as? PSQLError ) ? . code, . serverClosedConnection)
409+ }
410+ }
411+
349412 struct TestPrepareStatement : PostgresPreparedStatement {
350413 static let sql = " SELECT datname FROM pg_stat_activity WHERE state = $1 "
351414 typealias Row = String
@@ -692,6 +755,14 @@ extension NIOAsyncTestingChannel {
692755 return UnpreparedRequest ( parse: parse, describe: describe, bind: bind, execute: execute)
693756 }
694757
758+ func waitForSimpleQueryRequest( ) async throws -> PostgresFrontendMessage . Query {
759+ let query = try await self . waitForOutboundWrite ( as: PostgresFrontendMessage . self)
760+ guard case . query( let query) = query else {
761+ fatalError ( )
762+ }
763+ return query
764+ }
765+
695766 func waitForPrepareRequest( ) async throws -> PrepareRequest {
696767 let parse = try await self . waitForOutboundWrite ( as: PostgresFrontendMessage . self)
697768 let describe = try await self . waitForOutboundWrite ( as: PostgresFrontendMessage . self)
0 commit comments