Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ import kotlinx.coroutines.flow.Flow
import kotlinx.rpc.RpcCall
import kotlinx.rpc.RpcClient
import kotlinx.rpc.grpc.GrpcMetadata
import kotlinx.rpc.grpc.client.GrpcCallOptions
import kotlinx.rpc.grpc.client.internal.ManagedChannel
import kotlinx.rpc.grpc.client.internal.ManagedChannelBuilder
import kotlinx.rpc.grpc.client.internal.applyConfig
import kotlinx.rpc.grpc.client.internal.bidirectionalStreamingRpc
import kotlinx.rpc.grpc.client.internal.buildChannel
import kotlinx.rpc.grpc.client.internal.clientStreamingRpc
Expand All @@ -27,6 +27,7 @@ import kotlinx.rpc.grpc.descriptor.MethodType
import kotlinx.rpc.grpc.descriptor.methodType
import kotlinx.rpc.internal.utils.map.RpcInternalConcurrentHashMap
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds

private typealias RequestClient = Any

Expand Down Expand Up @@ -180,9 +181,7 @@ private fun GrpcClient(
builder: ManagedChannelBuilder<*>,
config: GrpcClientConfiguration,
): GrpcClient {
val channel = builder.apply {
config.overrideAuthority?.let { overrideAuthority(it) }
}.buildChannel()
val channel = builder.applyConfig(config).buildChannel()
return GrpcClient(channel, config.messageCodecResolver, config.interceptors)
}

Expand All @@ -198,6 +197,7 @@ private fun GrpcClient(
*/
public class GrpcClientConfiguration internal constructor() {
internal val interceptors: MutableList<ClientInterceptor> = mutableListOf()
internal var keepAlive: KeepAlive? = null

/**
* Configurable resolver used to determine the appropriate codec for a given Kotlin type
Expand Down Expand Up @@ -294,4 +294,55 @@ public class GrpcClientConfiguration internal constructor() {
public fun tls(configure: TlsClientCredentialsBuilder.() -> Unit): ClientCredentials =
TlsClientCredentials(configure)

}
/**
* Configures keep-alive settings for the gRPC client.
*
* Keep-alive allows you to fine-tune the behavior of the client to ensure the connection
* between the client and server remains active according to specific parameters.
*
* By default, keep-alive is disabled.
*
* @param configure A lambda to apply custom configurations to the [KeepAlive] instance.
* The [KeepAlive] settings include:
* - `time`: The maximum amount of time that the channel can be idle before a keep-alive
* ping is sent.
* - `timeout`: The time allowed for a keep-alive ping to complete.
* - `withoutCalls`: Whether to send keep-alive pings even when there are no outstanding
* RPCs on the connection.
*
* @see KeepAlive
*/
public fun keepAlive(configure: KeepAlive.() -> Unit) {
keepAlive = KeepAlive().apply(configure)
}

/**
* Represents keep-alive settings for a gRPC client connection.
*
* Keep-alive ensures that the connection between the client and the server remains active.
* It helps detect connection issues proactively before a request is made and facilitates
* maintaining long-lived idle connections.
*
* Client authors must coordinate with service owners for whether a particular client-side
* setting is acceptable.
*
* @property time Specifies the maximum amount of time the channel can remain idle before a
* keep-alive ping is sent to the server to check the connection state.
* The default value is `Duration.INFINITE`, which disables keep-alive pings when idle.
*
* @property timeout Sets the amount of time to wait for a keep-alive ping response.
* If the server does not respond within this timeout, the connection will be considered broken.
* The default value is 20 seconds.
*
* @property withoutCalls Defines whether keep-alive pings will be sent even when there
* are no active RPCs on the connection. If set to `true`, pings will be sent regardless
* of ongoing calls; otherwise, pings are only sent during active RPCs.
* The default value is `false`.
*/
public class KeepAlive internal constructor() {
public var time: Duration = Duration.INFINITE
public var timeout: Duration = 20.seconds
public var withoutCalls: Boolean = false
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package kotlinx.rpc.grpc.client.internal

import kotlinx.rpc.grpc.client.ClientCredentials
import kotlinx.rpc.grpc.client.GrpcClientConfiguration
import kotlinx.rpc.internal.utils.InternalRpcApi
import kotlin.time.Duration

Expand Down Expand Up @@ -71,9 +72,7 @@ public interface ManagedChannel {
* Builder class for [ManagedChannel].
*/
@InternalRpcApi
public expect abstract class ManagedChannelBuilder<T : ManagedChannelBuilder<T>> {
public abstract fun overrideAuthority(authority: String): T
}
public expect abstract class ManagedChannelBuilder<T : ManagedChannelBuilder<T>>

@InternalRpcApi
public expect fun ManagedChannelBuilder(
Expand All @@ -88,5 +87,7 @@ public expect fun ManagedChannelBuilder(
credentials: ClientCredentials? = null,
): ManagedChannelBuilder<*>

internal expect fun ManagedChannelBuilder<*>.applyConfig(config: GrpcClientConfiguration): ManagedChannelBuilder<*>

@InternalRpcApi
public expect fun ManagedChannelBuilder<*>.buildChannel(): ManagedChannel
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import io.grpc.Grpc
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import kotlinx.rpc.grpc.client.ClientCredentials
import kotlinx.rpc.grpc.client.GrpcClientConfiguration
import kotlinx.rpc.internal.utils.InternalRpcApi
import java.util.concurrent.TimeUnit
import kotlin.time.Duration
Expand Down Expand Up @@ -80,3 +81,14 @@ private class JvmManagedChannel(private val channel: io.grpc.ManagedChannel) : M
override val platformApi: ManagedChannelPlatform
get() = channel
}

internal actual fun ManagedChannelBuilder<*>.applyConfig(config: GrpcClientConfiguration): ManagedChannelBuilder<*> {
config.keepAlive?.let {
keepAliveTime(it.time.inWholeMilliseconds, TimeUnit.MILLISECONDS)
keepAliveTimeout(it.timeout.inWholeMilliseconds, TimeUnit.MILLISECONDS)
keepAliveWithoutCalls(it.withoutCalls)
}

config.overrideAuthority?.let { overrideAuthority(it) }
return this
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package kotlinx.rpc.grpc.client.internal

import kotlinx.rpc.grpc.client.ClientCredentials
import kotlinx.rpc.grpc.client.GrpcClientConfiguration
import kotlinx.rpc.grpc.client.TlsClientCredentials
import kotlinx.rpc.grpc.internal.internalError
import kotlinx.rpc.internal.utils.InternalRpcApi
Expand All @@ -22,25 +23,23 @@ public actual abstract class ManagedChannelPlatform : GrpcChannel()
*/
@InternalRpcApi
public actual abstract class ManagedChannelBuilder<T : ManagedChannelBuilder<T>> {
public actual abstract fun overrideAuthority(authority: String): T
internal var config: GrpcClientConfiguration? = null
}

internal class NativeManagedChannelBuilder(
private val target: String,
private var credentials: Lazy<ClientCredentials>,
) : ManagedChannelBuilder<NativeManagedChannelBuilder>() {

private var authority: String? = null

override fun overrideAuthority(authority: String): NativeManagedChannelBuilder {
this.authority = authority
return this
}

fun buildChannel(): NativeManagedChannel {
val keepAlive = config?.keepAlive
keepAlive?.run {
require(time.isPositive()) { "keepalive time must be positive" }
require(timeout.isPositive()) { "keepalive timeout must be positive" }
}
return NativeManagedChannel(
target,
authority = authority,
authority = config?.overrideAuthority,
keepAlive = config?.keepAlive,
credentials = credentials.value,
)
}
Expand Down Expand Up @@ -70,3 +69,7 @@ public actual fun ManagedChannelBuilder(target: String, credentials: ClientCrede
}


internal actual fun ManagedChannelBuilder<*>.applyConfig(config: GrpcClientConfiguration): ManagedChannelBuilder<*> {
this.config = config
return this
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ import cnames.structs.grpc_channel
import kotlinx.atomicfu.atomic
import kotlinx.cinterop.CPointer
import kotlinx.cinterop.ExperimentalForeignApi
import kotlinx.cinterop.MemScope
import kotlinx.cinterop.alloc
import kotlinx.cinterop.allocArray
import kotlinx.cinterop.convert
import kotlinx.cinterop.cstr
import kotlinx.cinterop.memScoped
import kotlinx.cinterop.ptr
Expand All @@ -21,6 +24,7 @@ import kotlinx.coroutines.cancelChildren
import kotlinx.coroutines.withTimeoutOrNull
import kotlinx.rpc.grpc.client.ClientCredentials
import kotlinx.rpc.grpc.client.GrpcCallOptions
import kotlinx.rpc.grpc.client.GrpcClientConfiguration
import kotlinx.rpc.grpc.client.rawDeadline
import kotlinx.rpc.grpc.descriptor.MethodDescriptor
import kotlinx.rpc.grpc.internal.CompletionQueue
Expand Down Expand Up @@ -50,6 +54,7 @@ import kotlin.time.Duration
internal class NativeManagedChannel(
target: String,
val authority: String?,
val keepAlive: GrpcClientConfiguration.KeepAlive?,
// we must store them, otherwise the credentials are getting released
credentials: ClientCredentials,
) : ManagedChannel, ManagedChannelPlatform() {
Expand All @@ -66,22 +71,36 @@ internal class NativeManagedChannel(
private val cq = CompletionQueue()

internal val raw: CPointer<grpc_channel> = memScoped {
val args = authority?.let {
val args = mutableListOf<GrpcArg>()

authority?.let {
// the C Core API doesn't have a way to override the authority (used for TLS SNI) as it
// is available in the Java gRPC implementation.
// instead, it can be done by setting the "grpc.ssl_target_name_override" argument.
val authorityOverride = alloc<grpc_arg> {
type = grpc_arg_type.GRPC_ARG_STRING
key = "grpc.ssl_target_name_override".cstr.ptr
value.string = authority.cstr.ptr
}
args.add(GrpcArg.Str(
key = "grpc.ssl_target_name_override",
value = it
))
}

alloc<grpc_channel_args> {
num_args = 1u
args = authorityOverride.ptr
}
keepAlive?.let {
args.add(GrpcArg.Integer(
key = "grpc.keepalive_time_ms",
value = it.time.inWholeMilliseconds.convert()
))
args.add(GrpcArg.Integer(
key = "grpc.keepalive_timeout_ms",
value = it.timeout.inWholeMilliseconds.convert()
))
args.add(GrpcArg.Integer(
key = "grpc.keepalive_permit_without_calls",
value = if (it.withoutCalls) 1 else 0
))
}
grpc_channel_create(target, credentials.raw, args?.ptr)

var rawArgs = if (args.isNotEmpty()) args.toRaw(this) else null

grpc_channel_create(target, credentials.raw, rawArgs?.ptr)
?: error("Failed to create channel")
}

Expand Down Expand Up @@ -170,3 +189,33 @@ internal class NativeManagedChannel(
}

}

internal sealed class GrpcArg(val key: String) {
internal class Str(key: String, val value: String) : GrpcArg(key)
internal class Integer(key: String, val value: Int) : GrpcArg(key)

internal val rawType: grpc_arg_type
get() = when (this) {
is Str -> grpc_arg_type.GRPC_ARG_STRING
is Integer -> grpc_arg_type.GRPC_ARG_INTEGER
}
}

private fun List<GrpcArg>.toRaw(memScope: MemScope): grpc_channel_args {
with(memScope) {
val arr = allocArray<grpc_arg>(size) {
val arg = get(it)
type = arg.rawType
key = arg.key.cstr.ptr
when (arg) {
is GrpcArg.Str -> value.string = arg.value.cstr.ptr
is GrpcArg.Integer -> value.integer = arg.value.convert()
}
}

return alloc<grpc_channel_args> {
num_args = size.convert()
args = arr
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class GrpcCompressionTest : GrpcProtoTest() {
block()
}
} finally {
clearNativeEnv("GRPC_GRACE")
clearNativeEnv("GRPC_TRACE")
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.rpc.grpc.test.proto

import kotlinx.rpc.RpcServer
import kotlinx.rpc.grpc.test.EchoService
import kotlinx.rpc.grpc.test.EchoServiceImpl
import kotlinx.rpc.registerService
import kotlin.test.Test
import kotlin.test.assertContains
import kotlin.test.assertFailsWith
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds

/**
* Tests that the client can configure the compression of requests.
*
* This test is hard to realize on native, as the gRPC-Core doesn't expose internal headers like
* `grpc-encoding` to the user application. This means we cannot verify that the client or sever
* actually sent those headers on native. Instead, we capture the grpc trace output (written to stderr)
* and verify that the client and server actually used the compression algorithm.
*/
class GrpcKeepAliveTest : GrpcProtoTest() {
override fun RpcServer.registerServices() {
return registerService<EchoService> { EchoServiceImpl() }
}

@Test
fun `test keepalive set - should propagate settings to core libraries`() = testKeepAlive(
time = 15.seconds,
timeout = 5.seconds,
withoutCalls = true,
)

@Test
fun `test keepalive negative time - should fail`() {
val error = assertFailsWith<IllegalArgumentException> {
runGrpcTest(
configure = {
keepAlive {
this.time = (-1).seconds
}
}
) {
// not reached
}
}
assertContains(error.message!!, "keepalive time must be positive")
}

@Test
fun `test keepalive negative timeout - should fail`() {
val error = assertFailsWith<IllegalArgumentException> {
runGrpcTest(
configure = {
keepAlive {
this.timeout = (-1).seconds
}
}
) {
// not reached
}
}
assertContains(error.message!!, "keepalive timeout must be positive")
}
}

expect fun GrpcProtoTest.testKeepAlive(
time: Duration,
timeout: Duration,
withoutCalls: Boolean,
)
Loading
Loading