Skip to content

Commit eab8f7c

Browse files
authored
Add fake server channel handler to test server close (#192)
Signed-off-by: Adam Fowler <adamfowler71@gmail.com>
1 parent 9a64b25 commit eab8f7c

File tree

3 files changed

+127
-4
lines changed

3 files changed

+127
-4
lines changed

Benchmarks/ValkeyBenchmarks/Shared.swift

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,10 @@ func makeLocalServer(commandHandler: some BenchmarkCommandHandler = BenchmarkGet
5151
try await ServerBootstrap(group: NIOSingletons.posixEventLoopGroup)
5252
.serverChannelOption(.socketOption(.so_reuseaddr), value: 1)
5353
.childChannelInitializer { channel in
54-
do {
54+
channel.eventLoop.makeCompletedFuture {
5555
try channel.pipeline.syncOperations.addHandler(
5656
ValkeyServerChannelHandler(commandHandler: commandHandler)
5757
)
58-
return channel.eventLoop.makeSucceededVoidFuture()
59-
} catch {
60-
return channel.eventLoop.makeFailedFuture(error)
6158
}
6259
}
6360
.bind(host: "127.0.0.1", port: 0)
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the valkey-swift open source project
4+
//
5+
// Copyright (c) 2025 the valkey-swift project authors
6+
// Licensed under Apache License v2.0
7+
//
8+
// See LICENSE.txt for license information
9+
// See CONTRIBUTORS.txt for the list of valkey-swift project authors
10+
//
11+
// SPDX-License-Identifier: Apache-2.0
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
import NIOCore
16+
import Valkey
17+
18+
// Fake Valkey server channel handler
19+
final class TestValkeyServerChannelHandler: ChannelInboundHandler {
20+
typealias InboundIn = ByteBuffer
21+
typealias OutboundOut = ByteBuffer
22+
23+
private var decoder = NIOSingleStepByteToMessageProcessor(RESPTokenDecoder())
24+
private let commandHandler: (String, ArraySlice<String>, (ByteBuffer) -> Void) -> Void
25+
26+
static private let helloCommand = RESPToken.Value.bulkString(ByteBuffer(string: "HELLO"))
27+
static private let helloResponse = ByteBuffer(string: "%1\r\n+server\r\n+fake\r\n")
28+
static private let pingCommand = RESPToken.Value.bulkString(ByteBuffer(string: "PING"))
29+
static private let pongResponse = ByteBuffer(string: "$4\r\nPONG\r\n")
30+
static private let clientCommand = RESPToken.Value.bulkString(ByteBuffer(string: "CLIENT"))
31+
static private let setInfoSubCommand = RESPToken.Value.bulkString(ByteBuffer(string: "SETINFO"))
32+
static private let okResponse = ByteBuffer(string: "+2OK\r\n")
33+
34+
static private let response = ByteBuffer(string: "$3\r\nBar\r\n")
35+
static func defaultHandler(command: String, parameters: ArraySlice<String>, write: (ByteBuffer) -> Void) {
36+
guard command == "GET" else {
37+
fatalError("Unexpected command: \(command)")
38+
}
39+
write(Self.response)
40+
}
41+
42+
init(commandHandler: @escaping (String, ArraySlice<String>, (ByteBuffer) -> Void) -> Void = defaultHandler) {
43+
self.commandHandler = commandHandler
44+
}
45+
46+
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
47+
try! self.decoder.process(buffer: self.unwrapInboundIn(data)) { token in
48+
self.handleToken(context: context, token: token)
49+
}
50+
}
51+
52+
func handleToken(context: ChannelHandlerContext, token: RESPToken) {
53+
guard let fullCommand = try? token.decode(as: [String].self) else {
54+
fatalError()
55+
}
56+
guard let command = fullCommand.first else {
57+
fatalError()
58+
}
59+
let parameters = fullCommand.dropFirst()
60+
switch command {
61+
case "HELLO":
62+
context.writeAndFlush(self.wrapOutboundOut(Self.helloResponse), promise: nil)
63+
64+
case "PING":
65+
context.writeAndFlush(self.wrapOutboundOut(Self.pongResponse), promise: nil)
66+
67+
case "CLIENT":
68+
switch parameters.first {
69+
case "SETINFO":
70+
context.writeAndFlush(self.wrapOutboundOut(Self.okResponse), promise: nil)
71+
default:
72+
commandHandler(command, parameters) {
73+
context.writeAndFlush(self.wrapOutboundOut($0), promise: nil)
74+
}
75+
}
76+
77+
default:
78+
commandHandler(command, parameters) {
79+
context.writeAndFlush(self.wrapOutboundOut($0), promise: nil)
80+
}
81+
}
82+
}
83+
}

Tests/ValkeyTests/ValkeyConnectionTests.swift

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import Logging
1616
import NIOCore
1717
import NIOEmbedded
18+
import NIOPosix
1819
import Testing
1920

2021
@testable import Valkey
@@ -449,4 +450,46 @@ struct ConnectionTests {
449450
// verify connection hasnt been closed
450451
#expect(channel.isActive == true)
451452
}
453+
454+
@Test
455+
@available(valkeySwift 1.0, *)
456+
func testCloseOnServeClose() async throws {
457+
let channel = try await ServerBootstrap(group: NIOSingletons.posixEventLoopGroup)
458+
.serverChannelOption(.socketOption(.so_reuseaddr), value: 1)
459+
.childChannelOption(ChannelOptions.allowRemoteHalfClosure, value: true)
460+
.childChannelInitializer { channel in
461+
channel.eventLoop.makeCompletedFuture {
462+
try channel.pipeline.syncOperations.addHandler(
463+
TestValkeyServerChannelHandler { command, _, write in
464+
switch command {
465+
case "QUIT":
466+
write(ByteBuffer(string: "+2OK\r\n"))
467+
channel.close(mode: .output, promise: nil)
468+
default:
469+
fatalError("Unexpected command: \(command)")
470+
}
471+
472+
}
473+
)
474+
}
475+
}
476+
.bind(host: "127.0.0.1", port: 0)
477+
.get()
478+
let port = channel.localAddress!.port!
479+
try await ValkeyConnection.withConnection(
480+
address: .hostname("127.0.0.1", port: port),
481+
configuration: .init(),
482+
eventLoop: MultiThreadedEventLoopGroup.singleton.any(),
483+
logger: Logger(label: "test")
484+
) { connection in
485+
let clientChannel = await connection.channel
486+
try await connection.quit()
487+
await withCheckedContinuation { cont in
488+
clientChannel.closeFuture.whenComplete { _ in
489+
cont.resume()
490+
}
491+
}
492+
}
493+
try await channel.close()
494+
}
452495
}

0 commit comments

Comments
 (0)