From 05c2e773c57e59b3b48995818c59d034a036e1dd Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Thu, 14 Nov 2024 16:28:13 +0200 Subject: [PATCH 1/2] Add a function to RpcServer that runs a complete accept loop --- src/server.rs | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/src/server.rs b/src/server.rs index 76644b25..565157bd 100644 --- a/src/server.rs +++ b/src/server.rs @@ -7,13 +7,15 @@ use std::{ marker::PhantomData, pin::Pin, result, + sync::Arc, task::{self, Poll}, }; use futures_lite::{Future, Stream, StreamExt}; use futures_util::{SinkExt, TryStreamExt}; use pin_project::pin_project; -use tokio::sync::oneshot; +use tokio::{sync::oneshot, task::JoinSet}; +use tracing::{error, warn}; use crate::{ transport::{ @@ -211,6 +213,56 @@ impl> RpcServer { pub fn into_inner(self) -> C { self.source } + + /// Run an accept loop for this server. + /// + /// Each request will be handled in a separate task. + /// + /// It is the caller's responsibility to poll the returned future to drive the server. + pub async fn accept_loop(self, handler: Fun) + where + S: Service, + C: Listener, + Fun: Fn(S::Req, RpcChannel) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, + E: Into + 'static, + { + let handler = Arc::new(handler); + let mut tasks = JoinSet::new(); + loop { + tokio::select! { + Some(res) = tasks.join_next(), if !tasks.is_empty() => { + if let Err(e) = res { + if e.is_panic() { + error!("Panic handling RPC request: {e}"); + } + } + } + req = self.accept() => { + let req = match req { + Ok(req) => req, + Err(e) => { + warn!("Error accepting RPC request: {e}"); + continue; + } + }; + let handler = handler.clone(); + tasks.spawn(async move { + let (req, chan) = match req.read_first().await { + Ok((req, chan)) => (req, chan), + Err(e) => { + warn!("Error reading first message: {e}"); + return; + } + }; + if let Err(cause) = handler(req, chan).await { + warn!("Error handling RPC request: {}", cause.into()); + } + }); + } + } + } + } } impl> AsRef for RpcServer { From 6e17b6a97d94c0dd31a31e3470ca9f27b4c12f5d Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Thu, 14 Nov 2024 17:50:20 +0200 Subject: [PATCH 2/2] Add another function that spawns an accept loop. It doesn't do much, but saves quite a bit of boilerplate at the call site, and also avoids having the task leaking by using AbortOnDropHandle. --- Cargo.toml | 9 +++++---- examples/modularize.rs | 21 ++++----------------- src/server.rs | 13 +++++++++++++ tests/flume.rs | 16 +++------------- tests/hyper.rs | 21 +++++---------------- tests/iroh-net.rs | 38 +++++++++++++++----------------------- tests/math.rs | 23 +++++++++-------------- tests/quinn.rs | 39 ++++++++++++++++++++------------------- 8 files changed, 74 insertions(+), 106 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 99da459e..08e5691f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ quinn = { package = "iroh-quinn", version = "0.12", optional = true } serde = { version = "1.0.183", features = ["derive"] } tokio = { version = "1", default-features = false, features = ["macros", "sync"] } tokio-serde = { version = "0.8", features = ["bincode"], optional = true } -tokio-util = { version = "0.7", features = ["codec"], optional = true } +tokio-util = { version = "0.7", features = ["rt"] } tracing = "0.1" hex = "0.4.3" futures = { version = "0.3.30", optional = true } @@ -52,12 +52,13 @@ proc-macro2 = "1.0.66" futures-buffered = "0.2.4" testresult = "0.4.1" nested_enum_utils = "0.1.0" +tokio-util = { version = "0.7", features = ["rt"] } [features] -hyper-transport = ["dep:flume", "dep:hyper", "dep:bincode", "dep:bytes", "dep:tokio-serde", "dep:tokio-util"] -quinn-transport = ["dep:flume", "dep:quinn", "dep:bincode", "dep:tokio-serde", "dep:tokio-util"] +hyper-transport = ["dep:flume", "dep:hyper", "dep:bincode", "dep:bytes", "dep:tokio-serde", "tokio-util/codec"] +quinn-transport = ["dep:flume", "dep:quinn", "dep:bincode", "dep:tokio-serde", "tokio-util/codec"] flume-transport = ["dep:flume"] -iroh-net-transport = ["dep:iroh-net", "dep:flume", "dep:quinn", "dep:bincode", "dep:tokio-serde", "dep:tokio-util"] +iroh-net-transport = ["dep:iroh-net", "dep:flume", "dep:quinn", "dep:bincode", "dep:tokio-serde", "tokio-util/codec"] macros = [] default = ["flume-transport"] diff --git a/examples/modularize.rs b/examples/modularize.rs index 700fc93a..4143b3e5 100644 --- a/examples/modularize.rs +++ b/examples/modularize.rs @@ -12,7 +12,6 @@ use app::AppService; use futures_lite::StreamExt; use futures_util::SinkExt; use quic_rpc::{client::BoxedConnector, transport::flume, Listener, RpcClient, RpcServer}; -use tracing::warn; #[tokio::main] async fn main() -> Result<()> { @@ -32,23 +31,11 @@ async fn main() -> Result<()> { async fn run_server>(server_conn: C, handler: app::Handler) { let server = RpcServer::::new(server_conn); - loop { - let Ok(accepting) = server.accept().await else { - continue; - }; - match accepting.read_first().await { - Err(err) => warn!(?err, "server accept failed"), - Ok((req, chan)) => { - let handler = handler.clone(); - tokio::task::spawn(async move { - if let Err(err) = handler.handle_rpc_request(req, chan).await { - warn!(?err, "internal rpc error"); - } - }); - } - } - } + server + .accept_loop(move |req, chan| handler.clone().handle_rpc_request(req, chan)) + .await } + pub async fn client_demo(conn: BoxedConnector) -> Result<()> { let rpc_client = RpcClient::::new(conn); let client = app::Client::new(rpc_client.clone()); diff --git a/src/server.rs b/src/server.rs index 565157bd..47d1f700 100644 --- a/src/server.rs +++ b/src/server.rs @@ -15,6 +15,7 @@ use futures_lite::{Future, Stream, StreamExt}; use futures_util::{SinkExt, TryStreamExt}; use pin_project::pin_project; use tokio::{sync::oneshot, task::JoinSet}; +use tokio_util::task::AbortOnDropHandle; use tracing::{error, warn}; use crate::{ @@ -263,6 +264,18 @@ impl> RpcServer { } } } + + /// Spawn an accept loop and return a handle to the task. + pub fn spawn_accept_loop(self, handler: Fun) -> AbortOnDropHandle<()> + where + S: Service, + C: Listener, + Fun: Fn(S::Req, RpcChannel) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, + E: Into + 'static, + { + AbortOnDropHandle::new(tokio::spawn(self.accept_loop(handler))) + } } impl> AsRef for RpcServer { diff --git a/tests/flume.rs b/tests/flume.rs index d0b14c8f..34fc9009 100644 --- a/tests/flume.rs +++ b/tests/flume.rs @@ -7,6 +7,7 @@ use quic_rpc::{ transport::flume, RpcClient, RpcServer, Service, }; +use tokio_util::task::AbortOnDropHandle; #[tokio::test] async fn flume_channel_bench() -> anyhow::Result<()> { @@ -14,14 +15,9 @@ async fn flume_channel_bench() -> anyhow::Result<()> { let (server, client) = flume::channel(1); let server = RpcServer::::new(server); - let server_handle = tokio::task::spawn(ComputeService::server(server)); + let _server_handle = AbortOnDropHandle::new(tokio::spawn(ComputeService::server(server))); let client = RpcClient::::new(client); bench(client, 1000000).await?; - // dropping the client will cause the server to terminate - match server_handle.await? { - Err(RpcServerError::Accept(_)) => {} - e => panic!("unexpected termination result {e:?}"), - } Ok(()) } @@ -101,13 +97,7 @@ async fn flume_channel_smoke() -> anyhow::Result<()> { let (server, client) = flume::channel(1); let server = RpcServer::::new(server); - let server_handle = tokio::task::spawn(ComputeService::server(server)); + let _server_handle = AbortOnDropHandle::new(tokio::spawn(ComputeService::server(server))); smoke_test(client).await?; - - // dropping the client will cause the server to terminate - match server_handle.await? { - Err(RpcServerError::Accept(_)) => {} - e => panic!("unexpected termination result {e:?}"), - } Ok(()) } diff --git a/tests/hyper.rs b/tests/hyper.rs index 0e5766dd..ab8fbc2c 100644 --- a/tests/hyper.rs +++ b/tests/hyper.rs @@ -15,19 +15,13 @@ use tokio::task::JoinHandle; mod math; use math::*; +use tokio_util::task::AbortOnDropHandle; mod util; -fn run_server(addr: &SocketAddr) -> JoinHandle> { +fn run_server(addr: &SocketAddr) -> AbortOnDropHandle<()> { let channel = HyperListener::serve(addr).unwrap(); let server = RpcServer::new(channel); - tokio::spawn(async move { - loop { - let server = server.clone(); - ComputeService::server(server).await?; - } - #[allow(unreachable_code)] - anyhow::Ok(()) - }) + ComputeService::server(server) } #[derive(Debug, Serialize, Deserialize, From, TryInto)] @@ -133,13 +127,11 @@ impl TestService { async fn hyper_channel_bench() -> anyhow::Result<()> { let addr: SocketAddr = "127.0.0.1:3000".parse()?; let uri: Uri = "http://127.0.0.1:3000".parse()?; - let server_handle = run_server(&addr); + let _server_handle = run_server(&addr); let client = HyperConnector::new(uri); let client = RpcClient::new(client); bench(client, 50000).await?; println!("terminating server"); - server_handle.abort(); - let _ = server_handle.await; Ok(()) } @@ -147,11 +139,9 @@ async fn hyper_channel_bench() -> anyhow::Result<()> { async fn hyper_channel_smoke() -> anyhow::Result<()> { let addr: SocketAddr = "127.0.0.1:3001".parse()?; let uri: Uri = "http://127.0.0.1:3001".parse()?; - let server_handle = run_server(&addr); + let _server_handle = run_server(&addr); let client = HyperConnector::new(uri); smoke_test(client).await?; - server_handle.abort(); - let _ = server_handle.await; Ok(()) } @@ -302,6 +292,5 @@ async fn hyper_channel_errors() -> anyhow::Result<()> { println!("terminating server"); server_handle.abort(); - let _ = server_handle.await; Ok(()) } diff --git a/tests/iroh-net.rs b/tests/iroh-net.rs index c416c597..4d31ad65 100644 --- a/tests/iroh-net.rs +++ b/tests/iroh-net.rs @@ -2,10 +2,13 @@ use iroh_net::{key::SecretKey, NodeAddr}; use quic_rpc::{transport, RpcClient, RpcServer}; -use tokio::task::JoinHandle; +use testresult::TestResult; + +use crate::transport::iroh_net::{IrohNetConnector, IrohNetListener}; mod math; use math::*; +use tokio_util::task::AbortOnDropHandle; mod util; const ALPN: &[u8] = b"quic-rpc/iroh-net/test"; @@ -44,13 +47,10 @@ impl Endpoints { } } -fn run_server(server: iroh_net::Endpoint) -> JoinHandle> { - tokio::task::spawn(async move { - let connection = transport::iroh_net::IrohNetListener::new(server)?; - let server = RpcServer::new(connection); - ComputeService::server(server).await?; - anyhow::Ok(()) - }) +fn run_server(server: iroh_net::Endpoint) -> AbortOnDropHandle<()> { + let connection = IrohNetListener::new(server).unwrap(); + let server = RpcServer::new(connection); + ComputeService::server(server) } // #[tokio::test(flavor = "multi_thread", worker_threads = 2)] @@ -64,17 +64,12 @@ async fn iroh_net_channel_bench() -> anyhow::Result<()> { server_node_addr, } = Endpoints::new().await?; tracing::debug!("Starting server"); - let server_handle = run_server(server); + let _server_handle = run_server(server); tracing::debug!("Starting client"); - let client = RpcClient::new(transport::iroh_net::IrohNetConnector::new( - client, - server_node_addr, - ALPN.into(), - )); + let client = RpcClient::new(IrohNetConnector::new(client, server_node_addr, ALPN.into())); tracing::debug!("Starting benchmark"); bench(client, 50000).await?; - server_handle.abort(); Ok(()) } @@ -86,11 +81,9 @@ async fn iroh_net_channel_smoke() -> anyhow::Result<()> { server, server_node_addr, } = Endpoints::new().await?; - let server_handle = run_server(server); - let client_connection = - transport::iroh_net::IrohNetConnector::new(client, server_node_addr, ALPN.into()); + let _server_handle = run_server(server); + let client_connection = IrohNetConnector::new(client, server_node_addr, ALPN.into()); smoke_test(client_connection).await?; - server_handle.abort(); Ok(()) } @@ -99,7 +92,7 @@ async fn iroh_net_channel_smoke() -> anyhow::Result<()> { /// /// This is a regression test. #[tokio::test] -async fn server_away_and_back() -> anyhow::Result<()> { +async fn server_away_and_back() -> TestResult<()> { tracing_subscriber::fmt::try_init().ok(); tracing::info!("Creating endpoints"); @@ -128,7 +121,7 @@ async fn server_away_and_back() -> anyhow::Result<()> { // create the RPC Server let connection = transport::iroh_net::IrohNetListener::new(server_endpoint.clone())?; let server = RpcServer::new(connection); - let server_handle = tokio::task::spawn(ComputeService::server_bounded(server, 1)); + let server_handle = tokio::spawn(ComputeService::server_bounded(server, 1)); // wait a bit for connection due to Windows test failing on CI tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; @@ -151,7 +144,7 @@ async fn server_away_and_back() -> anyhow::Result<()> { // make the server run again let connection = transport::iroh_net::IrohNetListener::new(server_endpoint.clone())?; let server = RpcServer::new(connection); - let server_handle = tokio::task::spawn(ComputeService::server_bounded(server, 5)); + let server_handle = tokio::spawn(ComputeService::server_bounded(server, 5)); // wait a bit for connection due to Windows test failing on CI tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; @@ -163,7 +156,6 @@ async fn server_away_and_back() -> anyhow::Result<()> { // server is running, this should work let SqrResponse(response) = client.rpc(Sqr(3)).await?; assert_eq!(response, 9); - server_handle.abort(); Ok(()) } diff --git a/tests/math.rs b/tests/math.rs index fd851669..e79d95aa 100644 --- a/tests/math.rs +++ b/tests/math.rs @@ -26,6 +26,7 @@ use quic_rpc::{ }; use serde::{Deserialize, Serialize}; use thousands::Separable; +use tokio_util::task::AbortOnDropHandle; /// compute the square of a number #[derive(Debug, Serialize, Deserialize)] @@ -163,20 +164,14 @@ impl ComputeService { } } - pub async fn server>( + pub fn server>( server: RpcServer, - ) -> result::Result<(), RpcServerError> { - let s = server; - let service = ComputeService; - loop { - let (req, chan) = s.accept().await?.read_first().await?; - let service = service.clone(); - tokio::spawn(async move { Self::handle_rpc_request(service, req, chan).await }); - } + ) -> AbortOnDropHandle<()> { + server.spawn_accept_loop(|req, chan| Self::handle_rpc_request(ComputeService, req, chan)) } pub async fn handle_rpc_request( - service: ComputeService, + self, req: ComputeRequest, chan: RpcChannel, ) -> Result<(), RpcServerError> @@ -186,10 +181,10 @@ impl ComputeService { use ComputeRequest::*; #[rustfmt::skip] match req { - Sqr(msg) => chan.rpc(msg, service, ComputeService::sqr).await, - Sum(msg) => chan.client_streaming(msg, service, ComputeService::sum).await, - Fibonacci(msg) => chan.server_streaming(msg, service, ComputeService::fibonacci).await, - Multiply(msg) => chan.bidi_streaming(msg, service, ComputeService::multiply).await, + Sqr(msg) => chan.rpc(msg, self, Self::sqr).await, + Sum(msg) => chan.client_streaming(msg, self, Self::sum).await, + Fibonacci(msg) => chan.server_streaming(msg, self, Self::fibonacci).await, + Multiply(msg) => chan.bidi_streaming(msg, self, Self::multiply).await, MultiplyUpdate(_) => Err(RpcServerError::UnexpectedStartMessage)?, SumUpdate(_) => Err(RpcServerError::UnexpectedStartMessage)?, }?; diff --git a/tests/quinn.rs b/tests/quinn.rs index be445b77..40c3e552 100644 --- a/tests/quinn.rs +++ b/tests/quinn.rs @@ -4,15 +4,22 @@ use std::{ sync::Arc, }; -use quic_rpc::{transport, RpcClient, RpcServer}; +use quic_rpc::{ + transport::{ + self, + quinn::{QuinnConnector, QuinnListener}, + }, + RpcClient, RpcServer, +}; use quinn::{ crypto::rustls::{QuicClientConfig, QuicServerConfig}, rustls, ClientConfig, Endpoint, ServerConfig, }; -use tokio::task::JoinHandle; mod math; use math::*; +use testresult::TestResult; +use tokio_util::task::AbortOnDropHandle; mod util; /// Constructs a QUIC endpoint configured for use a client only. @@ -112,13 +119,10 @@ pub fn make_endpoints(port: u16) -> anyhow::Result { }) } -fn run_server(server: quinn::Endpoint) -> JoinHandle> { - tokio::task::spawn(async move { - let connection = transport::quinn::QuinnListener::new(server)?; - let server = RpcServer::new(connection); - ComputeService::server(server).await?; - anyhow::Ok(()) - }) +fn run_server(server: quinn::Endpoint) -> AbortOnDropHandle<()> { + let listener = QuinnListener::new(server).unwrap(); + let listener = RpcServer::new(listener); + ComputeService::server(listener) } // #[tokio::test(flavor = "multi_thread", worker_threads = 2)] @@ -131,13 +135,12 @@ async fn quinn_channel_bench() -> anyhow::Result<()> { server_addr, } = make_endpoints(12345)?; tracing::debug!("Starting server"); - let server_handle = run_server(server); + let _server_handle = run_server(server); tracing::debug!("Starting client"); - let client = transport::quinn::QuinnConnector::new(client, server_addr, "localhost".into()); + let client = QuinnConnector::new(client, server_addr, "localhost".into()); let client = RpcClient::new(client); tracing::debug!("Starting benchmark"); bench(client, 50000).await?; - server_handle.abort(); Ok(()) } @@ -149,11 +152,10 @@ async fn quinn_channel_smoke() -> anyhow::Result<()> { server, server_addr, } = make_endpoints(12346)?; - let server_handle = run_server(server); + let _server_handle = run_server(server); let client_connection = transport::quinn::QuinnConnector::new(client, server_addr, "localhost".into()); smoke_test(client_connection).await?; - server_handle.abort(); Ok(()) } @@ -162,7 +164,7 @@ async fn quinn_channel_smoke() -> anyhow::Result<()> { /// /// This is a regression test. #[tokio::test] -async fn server_away_and_back() -> anyhow::Result<()> { +async fn server_away_and_back() -> TestResult<()> { tracing_subscriber::fmt::try_init().ok(); tracing::info!("Creating endpoints"); @@ -185,10 +187,10 @@ async fn server_away_and_back() -> anyhow::Result<()> { let server_handle = tokio::task::spawn(ComputeService::server_bounded(server, 1)); // send the first request and wait for the response to ensure everything works as expected - let SqrResponse(response) = client.rpc(Sqr(4)).await.unwrap(); + let SqrResponse(response) = client.rpc(Sqr(4)).await?; assert_eq!(response, 16); - let server = server_handle.await.unwrap().unwrap(); + let server = server_handle.await??; drop(server); // wait for drop to free the socket tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; @@ -200,9 +202,8 @@ async fn server_away_and_back() -> anyhow::Result<()> { let server_handle = tokio::task::spawn(ComputeService::server_bounded(server, 5)); // server is running, this should work - let SqrResponse(response) = client.rpc(Sqr(3)).await.unwrap(); + let SqrResponse(response) = client.rpc(Sqr(3)).await?; assert_eq!(response, 9); - server_handle.abort(); Ok(()) }