diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index fc48fa31..7e2d6f63 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -455,7 +455,14 @@ extension PostgresConnection { let task = HandlerTask.startListening(listener) - self.channel.write(task, promise: nil) + let promise = self.channel.eventLoop.makePromise(of: Void.self) + promise.futureResult.whenFailure { error in + self.logger.debug("Channel error in listen()", + metadata: [.error: "\(error)"]) + listener.failed(PSQLError(code: .listenFailed)) + } + + self.channel.write(task, promise: promise) } } onCancel: { let task = HandlerTask.cancelListening(channel, id) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift index 89f40469..2cc446c5 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ListenStateMachine.swift @@ -224,7 +224,7 @@ extension ListenStateMachine { mutating func fail(_ error: Error) -> FailAction { switch self.state { case .initialized: - fatalError("Invalid state: \(self.state)") + return .none case .starting(let listeners), .listening(let listeners), .stopping(let listeners): self.state = .failed(error) diff --git a/Sources/PostgresNIO/New/NotificationListener.swift b/Sources/PostgresNIO/New/NotificationListener.swift index 2f784e33..668bb670 100644 --- a/Sources/PostgresNIO/New/NotificationListener.swift +++ b/Sources/PostgresNIO/New/NotificationListener.swift @@ -17,6 +17,17 @@ final class NotificationListener: @unchecked Sendable { case done } + deinit { + switch self.state { + case .streamInitialized: + preconditionFailure("Notification continuation had not been used") + case .closure: + preconditionFailure("Notification closure had not been used") + case .streamListening, .done: + break + } + } + init( channel: String, id: Int, diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index bc256203..05a2a840 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -209,6 +209,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { psqlTask = .extendedQuery(query) case .startListening(let listener): + defer { promise?.succeed(()) } switch self.listenState.startListening(listener) { case .startListening(let channel): psqlTask = self.makeStartListeningQuery(channel: channel, context: context) diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index b4c8e93f..92c62fd4 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -284,6 +284,101 @@ final class AsyncPostgresConnectionTests: XCTestCase { } } + func testListenTwiceChannel() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + try await self.withTestConnection(on: eventLoop) { connection in + // Concurrently listen on a channel that is initially closed + async let stream1later = connection.listen("same-channel") + async let stream2later = connection.listen("same-channel") + let (stream1, stream2) = try await (stream1later, stream2later) + + try await self.withTestConnection(on: eventLoop) { other in + try await other.query(#"NOTIFY "\#(unescaped: "same-channel")";"#, logger: .psqlTest) + } + + var stream1EventReceived = false + var stream2EventReceived = false + + for try await _ in stream1 { + stream1EventReceived = true + break + } + + for try await _ in stream2 { + stream2EventReceived = true + break + } + + XCTAssertTrue(stream1EventReceived) + XCTAssertTrue(stream2EventReceived) + } + } + + func testListenOnClosedChannel() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + try await self.withTestConnection(on: eventLoop) { connection in + try await connection.close() + do { + _ = try await connection.listen("futile") + XCTFail("Expected not to get any events") + } catch let error as PSQLError where error.code == .listenFailed { + // Expected + } + } + } + + func testListenThenCloseChannel() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + try await self.withTestConnection(on: eventLoop) { connection in + let stream = try await connection.listen("hopeful") + try await connection.close() + do { + for try await _ in stream { + XCTFail("Expected not to get any events") + } + XCTFail("Expected not to have reached the end of stream") + } catch is PSQLError { + // Expected + } + } + } + + func testListenThenClosingChannel() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + try await self.withTestConnection(on: eventLoop) { connection in + _ = try await connection.listen("initial") + async let asyncClose: () = connection.close() + let stream: PostgresNotificationSequence + do { + stream = try await connection.listen("hopeful") + } catch let error as PSQLError where error.code == .listenFailed { + // Expected + return + } + try await asyncClose + do { + for try await _ in stream { + XCTFail("Expected not to get any events") + } + XCTFail("Expected not to have reached the end of stream") + } catch is PSQLError { + // Expected + } + } + } + #if canImport(Network) func testSelect10kRowsNetworkFramework() async throws { let eventLoopGroup = NIOTSEventLoopGroup()