Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion Sources/PostgresNIO/Connection/PostgresConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,14 @@ extension PostgresConnection {

let task = HandlerTask.startListening(listener)

self.channel.write(task, promise: nil)
let promise = self.channel.eventLoop.makePromise(of: Void.self)
promise.futureResult.whenFailure { error in
self.logger.debug("Channel error in listen()",
metadata: [.error: "\(error)"])
listener.failed(PSQLError(code: .listenFailed))
}

self.channel.write(task, promise: promise)
}
} onCancel: {
let task = HandlerTask.cancelListening(channel, id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ extension ListenStateMachine {
mutating func fail(_ error: Error) -> FailAction {
switch self.state {
case .initialized:
fatalError("Invalid state: \(self.state)")
return .none

case .starting(let listeners), .listening(let listeners), .stopping(let listeners):
self.state = .failed(error)
Expand Down
11 changes: 11 additions & 0 deletions Sources/PostgresNIO/New/NotificationListener.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@ final class NotificationListener: @unchecked Sendable {
case done
}

deinit {
switch self.state {
case .streamInitialized:
preconditionFailure("Notification continuation had not been used")
case .closure:
preconditionFailure("Notification closure had not been used")
case .streamListening, .done:
break
}
}

init(
channel: String,
id: Int,
Expand Down
1 change: 1 addition & 0 deletions Sources/PostgresNIO/New/PostgresChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
psqlTask = .extendedQuery(query)

case .startListening(let listener):
defer { promise?.succeed(()) }
switch self.listenState.startListening(listener) {
case .startListening(let channel):
psqlTask = self.makeStartListeningQuery(channel: channel, context: context)
Expand Down
95 changes: 95 additions & 0 deletions Tests/IntegrationTests/AsyncTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,101 @@ final class AsyncPostgresConnectionTests: XCTestCase {
}
}

func testListenTwiceChannel() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
let eventLoop = eventLoopGroup.next()

try await self.withTestConnection(on: eventLoop) { connection in
// Concurrently listen on a channel that is initially closed
async let stream1later = connection.listen("same-channel")
async let stream2later = connection.listen("same-channel")
let (stream1, stream2) = try await (stream1later, stream2later)

try await self.withTestConnection(on: eventLoop) { other in
try await other.query(#"NOTIFY "\#(unescaped: "same-channel")";"#, logger: .psqlTest)
}

var stream1EventReceived = false
var stream2EventReceived = false

for try await _ in stream1 {
stream1EventReceived = true
break
}

for try await _ in stream2 {
stream2EventReceived = true
break
}

XCTAssertTrue(stream1EventReceived)
XCTAssertTrue(stream2EventReceived)
}
}

func testListenOnClosedChannel() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
let eventLoop = eventLoopGroup.next()

try await self.withTestConnection(on: eventLoop) { connection in
try await connection.close()
do {
_ = try await connection.listen("futile")
XCTFail("Expected not to get any events")
} catch let error as PSQLError where error.code == .listenFailed {
// Expected
}
}
}

func testListenThenCloseChannel() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
let eventLoop = eventLoopGroup.next()

try await self.withTestConnection(on: eventLoop) { connection in
let stream = try await connection.listen("hopeful")
try await connection.close()
do {
for try await _ in stream {
XCTFail("Expected not to get any events")
}
XCTFail("Expected not to have reached the end of stream")
} catch is PSQLError {
// Expected
}
}
}

func testListenThenClosingChannel() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
let eventLoop = eventLoopGroup.next()

try await self.withTestConnection(on: eventLoop) { connection in
_ = try await connection.listen("initial")
async let asyncClose: () = connection.close()
let stream: PostgresNotificationSequence
do {
stream = try await connection.listen("hopeful")
} catch let error as PSQLError where error.code == .listenFailed {
// Expected
return
}
try await asyncClose
do {
for try await _ in stream {
XCTFail("Expected not to get any events")
}
XCTFail("Expected not to have reached the end of stream")
} catch is PSQLError {
// Expected
}
}
}

#if canImport(Network)
func testSelect10kRowsNetworkFramework() async throws {
let eventLoopGroup = NIOTSEventLoopGroup()
Expand Down