From c59d3152006050ba653e3c9de3790e8a42bf18bc Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 25 Jun 2024 19:04:05 +0200 Subject: [PATCH 01/14] feat: new transport based on tokio::sync::mpsc --- Cargo.lock | 12 ++ Cargo.toml | 4 +- src/transport/async_channel.rs | 362 +++++++++++++++++++++++++++++++++ src/transport/mod.rs | 2 + tests/async_channel.rs | 113 ++++++++++ 5 files changed, 492 insertions(+), 1 deletion(-) create mode 100644 src/transport/async_channel.rs create mode 100644 tests/async_channel.rs diff --git a/Cargo.lock b/Cargo.lock index fb97110..c157d47 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -836,6 +836,7 @@ dependencies = [ "thousands", "tokio", "tokio-serde", + "tokio-stream", "tokio-util", "tracing", "tracing-subscriber", @@ -1297,6 +1298,17 @@ dependencies = [ "serde", ] +[[package]] +name = "tokio-stream" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-util" version = "0.7.11" diff --git a/Cargo.toml b/Cargo.toml index 874bbf3..04e1c2e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ interprocess = { version = "2.1", features = ["tokio"], optional = true } hex = "0.4.3" futures = { version = "0.3.30", optional = true } anyhow = "1.0.73" +tokio-stream = { version = "0.1", optional = true } [dependencies.educe] # This is an unused dependency, it is needed to make the minimal @@ -59,11 +60,12 @@ futures-buffered = "0.2.4" hyper-transport = ["dep:flume", "dep:hyper", "dep:bincode", "dep:bytes", "dep:tokio-serde", "dep:tokio-util"] quinn-transport = ["dep:flume", "dep:quinn", "dep:bincode", "dep:tokio-serde", "dep:tokio-util"] flume-transport = ["dep:flume"] +async-channel-transport = ["dep:tokio-util", "dep:tokio-stream"] interprocess-transport = ["quinn-transport", "quinn-flume-socket", "dep:quinn-udp", "dep:interprocess", "dep:bytes", "dep:tokio-util", "dep:futures"] combined-transport = [] quinn-flume-socket = ["dep:flume", "dep:quinn", "dep:quinn-udp", "dep:bytes", "dep:tokio-util"] macros = [] -default = ["flume-transport"] +default = ["flume-transport", "async-channel-transport"] [package.metadata.docs.rs] all-features = true diff --git a/src/transport/async_channel.rs b/src/transport/async_channel.rs new file mode 100644 index 0000000..f3252f1 --- /dev/null +++ b/src/transport/async_channel.rs @@ -0,0 +1,362 @@ +//! Memory transport implementation using [tokio::sync::mpsc] + +use futures_lite::{Future, Stream}; +use futures_sink::Sink; + +use crate::{ + transport::{Connection, ConnectionErrors, LocalAddr, ServerEndpoint}, + RpcMessage, Service, +}; +use core::fmt; +use std::{error, fmt::Display, marker::PhantomData, pin::Pin, result, sync::Arc, task::Poll}; +use tokio::sync::{mpsc, Mutex}; + +use super::ConnectionCommon; + +/// Error when receiving from a channel +/// +/// This type has zero inhabitants, so it is always safe to unwrap a result with this error type. +#[derive(Debug)] +pub enum RecvError {} + +impl fmt::Display for RecvError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +/// Sink for memory channels +pub struct SendSink(pub(crate) tokio_util::sync::PollSender); + +impl fmt::Debug for SendSink { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SendSink").finish() + } +} + +impl Sink for SendSink { + type Error = self::SendError; + + fn poll_ready( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(&mut self.0) + .poll_ready(cx) + .map_err(|_| SendError::ReceiverDropped) + } + + fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { + Pin::new(&mut self.0) + .start_send(item) + .map_err(|_| SendError::ReceiverDropped) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(&mut self.0) + .poll_flush(cx) + .map_err(|_| SendError::ReceiverDropped) + } + + fn poll_close( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(&mut self.0) + .poll_close(cx) + .map_err(|_| SendError::ReceiverDropped) + } +} + +/// Stream for memory channels +pub struct RecvStream(pub(crate) tokio_stream::wrappers::ReceiverStream); + +impl fmt::Debug for RecvStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RecvStream").finish() + } +} + +impl Stream for RecvStream { + type Item = result::Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + match Pin::new(&mut self.0).poll_next(cx) { + Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +impl error::Error for RecvError {} + +/// A `tokio::sync::mpsc` based server endpoint. +/// +/// Created using [connection]. +pub struct MpscServerEndpoint { + #[allow(clippy::type_complexity)] + stream: Arc, RecvStream)>>>, +} + +impl Clone for MpscServerEndpoint { + fn clone(&self) -> Self { + Self { + stream: self.stream.clone(), + } + } +} + +impl fmt::Debug for MpscServerEndpoint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MpscServerEndpoint") + .field("stream", &self.stream) + .finish() + } +} + +impl ConnectionErrors for MpscServerEndpoint { + type SendError = self::SendError; + + type RecvError = self::RecvError; + + type OpenError = self::AcceptBiError; +} + +type Socket = (self::SendSink, self::RecvStream); + +/// Future returned by [MpscConnection::open_bi] +pub struct OpenBiFuture { + inner: OpenBiFutureBox, + res: Option>, +} + +impl fmt::Debug for OpenBiFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OpenBiFuture").finish() + } +} + +type OpenBiFutureBox = Pin< + Box< + dyn Future, RecvStream)>>> + + Send + + 'static, + >, +>; + +impl OpenBiFuture { + fn new(inner: OpenBiFutureBox, res: Socket) -> Self { + Self { + inner, + res: Some(res), + } + } +} + +impl Future for OpenBiFuture { + type Output = result::Result, self::OpenBiError>; + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + match Pin::new(&mut self.inner).poll(cx) { + Poll::Ready(Ok(())) => self + .res + .take() + .map(|x| Poll::Ready(Ok(x))) + .unwrap_or(Poll::Pending), + Poll::Ready(Err(_)) => Poll::Ready(Err(self::OpenBiError::RemoteDropped)), + Poll::Pending => Poll::Pending, + } + } +} + +/// Future returned by [MpscServerEndpoint::accept_bi] +pub struct AcceptBiFuture { + wrapped: + Pin, RecvStream)>> + Send + 'static>>, + _p: PhantomData<(In, Out)>, +} + +impl fmt::Debug for AcceptBiFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AcceptBiFuture").finish() + } +} + +impl Future for AcceptBiFuture { + type Output = result::Result<(SendSink, RecvStream), AcceptBiError>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + match Pin::new(&mut self.wrapped).poll(cx) { + Poll::Ready(Some((send, recv))) => Poll::Ready(Ok((send, recv))), + Poll::Ready(None) => Poll::Ready(Err(AcceptBiError::RemoteDropped)), + Poll::Pending => Poll::Pending, + } + } +} + +impl ConnectionCommon for MpscServerEndpoint { + type SendSink = SendSink; + type RecvStream = RecvStream; +} + +impl ServerEndpoint for MpscServerEndpoint { + #[allow(refining_impl_trait)] + fn accept_bi(&self) -> AcceptBiFuture { + let stream = self.stream.clone(); + let wrapped = Box::pin(async move { stream.lock().await.recv().await }); + + AcceptBiFuture { + wrapped, + _p: PhantomData, + } + } + + fn local_addr(&self) -> &[LocalAddr] { + &[LocalAddr::Mem] + } +} + +impl ConnectionErrors for MpscConnection { + type SendError = self::SendError; + + type RecvError = self::RecvError; + + type OpenError = self::OpenBiError; +} + +impl ConnectionCommon for MpscConnection { + type SendSink = SendSink; + type RecvStream = RecvStream; +} + +impl Connection for MpscConnection { + #[allow(refining_impl_trait)] + fn open_bi(&self) -> OpenBiFuture { + let (local_send, remote_recv) = mpsc::channel::(128); + let (remote_send, local_recv) = mpsc::channel::(128); + let remote_chan = ( + SendSink(tokio_util::sync::PollSender::new(remote_send)), + RecvStream(tokio_stream::wrappers::ReceiverStream::new(remote_recv)), + ); + let local_chan = ( + SendSink(tokio_util::sync::PollSender::new(local_send)), + RecvStream(tokio_stream::wrappers::ReceiverStream::new(local_recv)), + ); + let sink = self.sink.clone(); + OpenBiFuture::new( + Box::pin(async move { sink.send(remote_chan).await }), + local_chan, + ) + } +} + +/// A mpsc based connection to a server endpoint. +/// +/// Created using [connection]. +pub struct MpscConnection { + #[allow(clippy::type_complexity)] + sink: mpsc::Sender<(SendSink, RecvStream)>, +} + +impl Clone for MpscConnection { + fn clone(&self) -> Self { + Self { + sink: self.sink.clone(), + } + } +} + +impl fmt::Debug for MpscConnection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MpscClientChannel") + .field("sink", &self.sink) + .finish() + } +} + +/// AcceptBiError for mem channels. +/// +/// There is not much that can go wrong with mem channels. +#[derive(Debug)] +pub enum AcceptBiError { + /// The remote side of the channel was dropped + RemoteDropped, +} + +impl fmt::Display for AcceptBiError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +impl error::Error for AcceptBiError {} + +/// SendError for mem channels. +/// +/// There is not much that can go wrong with mem channels. +#[derive(Debug)] +pub enum SendError { + /// Receiver was dropped + ReceiverDropped, +} + +impl Display for SendError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +impl std::error::Error for SendError {} + +/// OpenBiError for mem channels. +#[derive(Debug)] +pub enum OpenBiError { + /// The remote side of the channel was dropped + RemoteDropped, +} + +impl Display for OpenBiError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +impl std::error::Error for OpenBiError {} + +/// CreateChannelError for mem channels. +/// +/// You can always create a mem channel, so there is no possible error. +/// Nevertheless we need a type for it. +#[derive(Debug, Clone, Copy)] +pub enum CreateChannelError {} + +impl Display for CreateChannelError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(self, f) + } +} + +impl std::error::Error for CreateChannelError {} + +/// Create a mpsc server endpoint and a connected mpsc client channel. +/// +/// `buffer` the size of the buffer for each channel. Keep this at a low value to get backpressure +pub fn connection(buffer: usize) -> (MpscServerEndpoint, MpscConnection) { + let (sink, stream) = mpsc::channel(buffer); + ( + MpscServerEndpoint { + stream: Arc::new(Mutex::new(stream)), + }, + MpscConnection { sink }, + ) +} diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 654d36d..189d7ac 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -7,6 +7,8 @@ use std::{ fmt::{self, Debug, Display}, net::SocketAddr, }; +#[cfg(feature = "async-channel-transport")] +pub mod async_channel; #[cfg(feature = "flume-transport")] pub mod boxed; #[cfg(feature = "combined-transport")] diff --git a/tests/async_channel.rs b/tests/async_channel.rs new file mode 100644 index 0000000..751bdc6 --- /dev/null +++ b/tests/async_channel.rs @@ -0,0 +1,113 @@ +#![cfg(feature = "async-channel-transport")] +#![allow(non_local_definitions)] +mod math; +use math::*; +use quic_rpc::{ + server::{RpcChannel, RpcServerError}, + transport::async_channel, + RpcClient, RpcServer, Service, +}; + +#[tokio::test] +async fn async_channel_channel_bench() -> anyhow::Result<()> { + tracing_subscriber::fmt::try_init().ok(); + let (server, client) = async_channel::connection::(1); + + let server = RpcServer::::new(server); + let server_handle = tokio::task::spawn(ComputeService::server(server)); + let client = RpcClient::::new(client); + bench(client, 1000000).await?; + // dropping the client will cause the server to terminate + match server_handle.await? { + Err(RpcServerError::Accept(_)) => {} + e => panic!("unexpected termination result {e:?}"), + } + Ok(()) +} + +#[tokio::test] +async fn async_channel_channel_mapped_bench() -> anyhow::Result<()> { + use derive_more::{From, TryInto}; + use serde::{Deserialize, Serialize}; + + tracing_subscriber::fmt::try_init().ok(); + + #[derive(Debug, Serialize, Deserialize, From, TryInto)] + enum OuterRequest { + Inner(InnerRequest), + } + #[derive(Debug, Serialize, Deserialize, From, TryInto)] + enum InnerRequest { + Compute(ComputeRequest), + } + #[derive(Debug, Serialize, Deserialize, From, TryInto)] + enum OuterResponse { + Inner(InnerResponse), + } + #[derive(Debug, Serialize, Deserialize, From, TryInto)] + enum InnerResponse { + Compute(ComputeResponse), + } + #[derive(Debug, Clone)] + struct OuterService; + impl Service for OuterService { + type Req = OuterRequest; + type Res = OuterResponse; + } + #[derive(Debug, Clone)] + struct InnerService; + impl Service for InnerService { + type Req = InnerRequest; + type Res = InnerResponse; + } + let (server, client) = async_channel::connection::(1); + + let server = RpcServer::new(server); + let server_handle: tokio::task::JoinHandle>> = + tokio::task::spawn(async move { + let service = ComputeService; + loop { + let (req, chan) = server.accept().await?; + let service = service.clone(); + tokio::spawn(async move { + let req: OuterRequest = req; + match req { + OuterRequest::Inner(InnerRequest::Compute(req)) => { + let chan: RpcChannel = chan.map(); + let chan: RpcChannel = chan.map(); + ComputeService::handle_rpc_request(service, req, chan).await + } + } + }); + } + }); + + let client = RpcClient::::new(client); + let client: RpcClient = client.map(); + let client: RpcClient = client.map(); + bench(client, 1000000).await?; + // dropping the client will cause the server to terminate + match server_handle.await? { + Err(RpcServerError::Accept(_)) => {} + e => panic!("unexpected termination result {e:?}"), + } + Ok(()) +} + +/// simple happy path test for all 4 patterns +#[tokio::test] +async fn async_channel_channel_smoke() -> anyhow::Result<()> { + tracing_subscriber::fmt::try_init().ok(); + let (server, client) = async_channel::connection::(1); + + let server = RpcServer::::new(server); + let server_handle = tokio::task::spawn(ComputeService::server(server)); + smoke_test(client).await?; + + // dropping the client will cause the server to terminate + match server_handle.await? { + Err(RpcServerError::Accept(_)) => {} + e => panic!("unexpected termination result {e:?}"), + } + Ok(()) +} From e89b51fb2be6ecfad804a83fc57243ef4a053c05 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 25 Jun 2024 19:12:04 +0200 Subject: [PATCH 02/14] type alias --- src/transport/async_channel.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transport/async_channel.rs b/src/transport/async_channel.rs index f3252f1..124ebe2 100644 --- a/src/transport/async_channel.rs +++ b/src/transport/async_channel.rs @@ -179,10 +179,12 @@ impl Future for OpenBiFuture { } } +type AcceptBiFutureBox = + Pin, RecvStream)>> + Send + 'static>>; + /// Future returned by [MpscServerEndpoint::accept_bi] pub struct AcceptBiFuture { - wrapped: - Pin, RecvStream)>> + Send + 'static>>, + wrapped: AcceptBiFutureBox, _p: PhantomData<(In, Out)>, } From bd4158bfad9a5fe65ebc9bd315c89fdda826ffd1 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 25 Jun 2024 19:17:10 +0200 Subject: [PATCH 03/14] avoid box --- src/transport/async_channel.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/transport/async_channel.rs b/src/transport/async_channel.rs index 124ebe2..e851756 100644 --- a/src/transport/async_channel.rs +++ b/src/transport/async_channel.rs @@ -212,15 +212,15 @@ impl ConnectionCommon for MpscServerEndpoint { } impl ServerEndpoint for MpscServerEndpoint { - #[allow(refining_impl_trait)] - fn accept_bi(&self) -> AcceptBiFuture { - let stream = self.stream.clone(); - let wrapped = Box::pin(async move { stream.lock().await.recv().await }); - - AcceptBiFuture { - wrapped, - _p: PhantomData, - } + async fn accept_bi(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptBiError> { + let (send, recv) = self + .stream + .lock() + .await + .recv() + .await + .ok_or_else(|| AcceptBiError::RemoteDropped)?; + Ok((send, recv)) } fn local_addr(&self) -> &[LocalAddr] { From c5ef01fbd4eaaf4b5a18d0c8db895dfb9d349245 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 25 Jun 2024 19:24:01 +0200 Subject: [PATCH 04/14] remove another box --- src/transport/async_channel.rs | 92 +++------------------------------- 1 file changed, 8 insertions(+), 84 deletions(-) diff --git a/src/transport/async_channel.rs b/src/transport/async_channel.rs index e851756..3b278f4 100644 --- a/src/transport/async_channel.rs +++ b/src/transport/async_channel.rs @@ -1,6 +1,6 @@ //! Memory transport implementation using [tokio::sync::mpsc] -use futures_lite::{Future, Stream}; +use futures_lite::Stream; use futures_sink::Sink; use crate::{ @@ -8,7 +8,7 @@ use crate::{ RpcMessage, Service, }; use core::fmt; -use std::{error, fmt::Display, marker::PhantomData, pin::Pin, result, sync::Arc, task::Poll}; +use std::{error, fmt::Display, pin::Pin, result, sync::Arc, task::Poll}; use tokio::sync::{mpsc, Mutex}; use super::ConnectionCommon; @@ -131,81 +131,6 @@ impl ConnectionErrors for MpscServerEndpoint { type Socket = (self::SendSink, self::RecvStream); -/// Future returned by [MpscConnection::open_bi] -pub struct OpenBiFuture { - inner: OpenBiFutureBox, - res: Option>, -} - -impl fmt::Debug for OpenBiFuture { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("OpenBiFuture").finish() - } -} - -type OpenBiFutureBox = Pin< - Box< - dyn Future, RecvStream)>>> - + Send - + 'static, - >, ->; - -impl OpenBiFuture { - fn new(inner: OpenBiFutureBox, res: Socket) -> Self { - Self { - inner, - res: Some(res), - } - } -} - -impl Future for OpenBiFuture { - type Output = result::Result, self::OpenBiError>; - - fn poll( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll { - match Pin::new(&mut self.inner).poll(cx) { - Poll::Ready(Ok(())) => self - .res - .take() - .map(|x| Poll::Ready(Ok(x))) - .unwrap_or(Poll::Pending), - Poll::Ready(Err(_)) => Poll::Ready(Err(self::OpenBiError::RemoteDropped)), - Poll::Pending => Poll::Pending, - } - } -} - -type AcceptBiFutureBox = - Pin, RecvStream)>> + Send + 'static>>; - -/// Future returned by [MpscServerEndpoint::accept_bi] -pub struct AcceptBiFuture { - wrapped: AcceptBiFutureBox, - _p: PhantomData<(In, Out)>, -} - -impl fmt::Debug for AcceptBiFuture { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("AcceptBiFuture").finish() - } -} - -impl Future for AcceptBiFuture { - type Output = result::Result<(SendSink, RecvStream), AcceptBiError>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { - match Pin::new(&mut self.wrapped).poll(cx) { - Poll::Ready(Some((send, recv))) => Poll::Ready(Ok((send, recv))), - Poll::Ready(None) => Poll::Ready(Err(AcceptBiError::RemoteDropped)), - Poll::Pending => Poll::Pending, - } - } -} - impl ConnectionCommon for MpscServerEndpoint { type SendSink = SendSink; type RecvStream = RecvStream; @@ -242,8 +167,7 @@ impl ConnectionCommon for MpscConnection { } impl Connection for MpscConnection { - #[allow(refining_impl_trait)] - fn open_bi(&self) -> OpenBiFuture { + async fn open_bi(&self) -> result::Result, self::OpenBiError> { let (local_send, remote_recv) = mpsc::channel::(128); let (remote_send, local_recv) = mpsc::channel::(128); let remote_chan = ( @@ -254,11 +178,11 @@ impl Connection for MpscConnection { SendSink(tokio_util::sync::PollSender::new(local_send)), RecvStream(tokio_stream::wrappers::ReceiverStream::new(local_recv)), ); - let sink = self.sink.clone(); - OpenBiFuture::new( - Box::pin(async move { sink.send(remote_chan).await }), - local_chan, - ) + self.sink + .send(remote_chan) + .await + .map_err(|_| self::OpenBiError::RemoteDropped)?; + Ok(local_chan) } } From d357445a95f1e3aceb63d67282740162457acbba Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 25 Jun 2024 19:39:11 +0200 Subject: [PATCH 05/14] rename --- Cargo.toml | 4 +- src/transport/mod.rs | 4 +- .../{async_channel.rs => tokio_mpsc.rs} | 38 +++++++++---------- tests/{async_channel.rs => tokio_mpsc.rs} | 0 4 files changed, 23 insertions(+), 23 deletions(-) rename src/transport/{async_channel.rs => tokio_mpsc.rs} (86%) rename tests/{async_channel.rs => tokio_mpsc.rs} (100%) diff --git a/Cargo.toml b/Cargo.toml index 04e1c2e..e1eb6c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,12 +60,12 @@ futures-buffered = "0.2.4" hyper-transport = ["dep:flume", "dep:hyper", "dep:bincode", "dep:bytes", "dep:tokio-serde", "dep:tokio-util"] quinn-transport = ["dep:flume", "dep:quinn", "dep:bincode", "dep:tokio-serde", "dep:tokio-util"] flume-transport = ["dep:flume"] -async-channel-transport = ["dep:tokio-util", "dep:tokio-stream"] +tokio-mpsc-transport = ["dep:tokio-util", "dep:tokio-stream"] interprocess-transport = ["quinn-transport", "quinn-flume-socket", "dep:quinn-udp", "dep:interprocess", "dep:bytes", "dep:tokio-util", "dep:futures"] combined-transport = [] quinn-flume-socket = ["dep:flume", "dep:quinn", "dep:quinn-udp", "dep:bytes", "dep:tokio-util"] macros = [] -default = ["flume-transport", "async-channel-transport"] +default = ["flume-transport", "tokio-mpsc-transport"] [package.metadata.docs.rs] all-features = true diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 189d7ac..69d4983 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -7,8 +7,6 @@ use std::{ fmt::{self, Debug, Display}, net::SocketAddr, }; -#[cfg(feature = "async-channel-transport")] -pub mod async_channel; #[cfg(feature = "flume-transport")] pub mod boxed; #[cfg(feature = "combined-transport")] @@ -23,6 +21,8 @@ pub mod interprocess; pub mod quinn; #[cfg(feature = "quinn-flume-socket")] pub mod quinn_flume_socket; +#[cfg(feature = "tokio-mpsc-transport")] +pub mod tokio_mpsc; pub mod misc; diff --git a/src/transport/async_channel.rs b/src/transport/tokio_mpsc.rs similarity index 86% rename from src/transport/async_channel.rs rename to src/transport/tokio_mpsc.rs index 3b278f4..e3c8788 100644 --- a/src/transport/async_channel.rs +++ b/src/transport/tokio_mpsc.rs @@ -4,7 +4,7 @@ use futures_lite::Stream; use futures_sink::Sink; use crate::{ - transport::{Connection, ConnectionErrors, LocalAddr, ServerEndpoint}, + transport::{self, ConnectionErrors, LocalAddr}, RpcMessage, Service, }; use core::fmt; @@ -100,12 +100,12 @@ impl error::Error for RecvError {} /// A `tokio::sync::mpsc` based server endpoint. /// /// Created using [connection]. -pub struct MpscServerEndpoint { +pub struct ServerEndpoint { #[allow(clippy::type_complexity)] stream: Arc, RecvStream)>>>, } -impl Clone for MpscServerEndpoint { +impl Clone for ServerEndpoint { fn clone(&self) -> Self { Self { stream: self.stream.clone(), @@ -113,15 +113,15 @@ impl Clone for MpscServerEndpoint { } } -impl fmt::Debug for MpscServerEndpoint { +impl fmt::Debug for ServerEndpoint { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("MpscServerEndpoint") + f.debug_struct("ServerEndpoint") .field("stream", &self.stream) .finish() } } -impl ConnectionErrors for MpscServerEndpoint { +impl ConnectionErrors for ServerEndpoint { type SendError = self::SendError; type RecvError = self::RecvError; @@ -131,12 +131,12 @@ impl ConnectionErrors for MpscServerEndpoint { type Socket = (self::SendSink, self::RecvStream); -impl ConnectionCommon for MpscServerEndpoint { +impl ConnectionCommon for ServerEndpoint { type SendSink = SendSink; type RecvStream = RecvStream; } -impl ServerEndpoint for MpscServerEndpoint { +impl transport::ServerEndpoint for ServerEndpoint { async fn accept_bi(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptBiError> { let (send, recv) = self .stream @@ -153,7 +153,7 @@ impl ServerEndpoint for MpscServerEndpoint { } } -impl ConnectionErrors for MpscConnection { +impl ConnectionErrors for Connection { type SendError = self::SendError; type RecvError = self::RecvError; @@ -161,12 +161,12 @@ impl ConnectionErrors for MpscConnection { type OpenError = self::OpenBiError; } -impl ConnectionCommon for MpscConnection { +impl ConnectionCommon for Connection { type SendSink = SendSink; type RecvStream = RecvStream; } -impl Connection for MpscConnection { +impl transport::Connection for Connection { async fn open_bi(&self) -> result::Result, self::OpenBiError> { let (local_send, remote_recv) = mpsc::channel::(128); let (remote_send, local_recv) = mpsc::channel::(128); @@ -186,15 +186,15 @@ impl Connection for MpscConnection { } } -/// A mpsc based connection to a server endpoint. +/// A tokio::sync::mpsc based connection to a server endpoint. /// /// Created using [connection]. -pub struct MpscConnection { +pub struct Connection { #[allow(clippy::type_complexity)] sink: mpsc::Sender<(SendSink, RecvStream)>, } -impl Clone for MpscConnection { +impl Clone for Connection { fn clone(&self) -> Self { Self { sink: self.sink.clone(), @@ -202,9 +202,9 @@ impl Clone for MpscConnection { } } -impl fmt::Debug for MpscConnection { +impl fmt::Debug for Connection { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("MpscClientChannel") + f.debug_struct("ClientChannel") .field("sink", &self.sink) .finish() } @@ -277,12 +277,12 @@ impl std::error::Error for CreateChannelError {} /// Create a mpsc server endpoint and a connected mpsc client channel. /// /// `buffer` the size of the buffer for each channel. Keep this at a low value to get backpressure -pub fn connection(buffer: usize) -> (MpscServerEndpoint, MpscConnection) { +pub fn connection(buffer: usize) -> (ServerEndpoint, Connection) { let (sink, stream) = mpsc::channel(buffer); ( - MpscServerEndpoint { + ServerEndpoint { stream: Arc::new(Mutex::new(stream)), }, - MpscConnection { sink }, + Connection { sink }, ) } diff --git a/tests/async_channel.rs b/tests/tokio_mpsc.rs similarity index 100% rename from tests/async_channel.rs rename to tests/tokio_mpsc.rs From f55ddc5648652030fd6dda57c07b46db955d1031 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 25 Jun 2024 19:57:48 +0200 Subject: [PATCH 06/14] integrate boxed transport --- src/transport/boxed.rs | 89 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 88 insertions(+), 1 deletion(-) diff --git a/src/transport/boxed.rs b/src/transport/boxed.rs index 5cc94d0..4240600 100644 --- a/src/transport/boxed.rs +++ b/src/transport/boxed.rs @@ -8,7 +8,7 @@ use std::{ use futures_lite::FutureExt; use futures_sink::Sink; -#[cfg(feature = "quinn-transport")] +#[cfg(any(feature = "quinn-transport", feature = "tokio-mpsc-transport"))] use futures_util::TryStreamExt; use futures_util::{future::BoxFuture, SinkExt, Stream, StreamExt}; use pin_project::pin_project; @@ -21,6 +21,7 @@ type BoxedFuture<'a, T> = Pin + Send + Sync + 'a>>; enum SendSinkInner { Direct(::flume::r#async::SendSink<'static, T>), + DirectTokio(tokio_util::sync::PollSender), Boxed(Pin + Send + Sync + 'static>>), } @@ -42,6 +43,10 @@ impl SendSink { pub(crate) fn direct(sink: ::flume::r#async::SendSink<'static, T>) -> Self { Self(SendSinkInner::Direct(sink)) } + + pub(crate) fn direct_tokio(sink: tokio_util::sync::PollSender) -> Self { + Self(SendSinkInner::DirectTokio(sink)) + } } impl Sink for SendSink { @@ -53,6 +58,9 @@ impl Sink for SendSink { ) -> Poll> { match self.project().0 { SendSinkInner::Direct(sink) => sink.poll_ready_unpin(cx).map_err(anyhow::Error::from), + SendSinkInner::DirectTokio(sink) => { + sink.poll_ready_unpin(cx).map_err(anyhow::Error::from) + } SendSinkInner::Boxed(sink) => sink.poll_ready_unpin(cx).map_err(anyhow::Error::from), } } @@ -60,6 +68,9 @@ impl Sink for SendSink { fn start_send(self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> { match self.project().0 { SendSinkInner::Direct(sink) => sink.start_send_unpin(item).map_err(anyhow::Error::from), + SendSinkInner::DirectTokio(sink) => { + sink.start_send_unpin(item).map_err(anyhow::Error::from) + } SendSinkInner::Boxed(sink) => sink.start_send_unpin(item).map_err(anyhow::Error::from), } } @@ -70,6 +81,9 @@ impl Sink for SendSink { ) -> Poll> { match self.project().0 { SendSinkInner::Direct(sink) => sink.poll_flush_unpin(cx).map_err(anyhow::Error::from), + SendSinkInner::DirectTokio(sink) => { + sink.poll_flush_unpin(cx).map_err(anyhow::Error::from) + } SendSinkInner::Boxed(sink) => sink.poll_flush_unpin(cx).map_err(anyhow::Error::from), } } @@ -80,6 +94,9 @@ impl Sink for SendSink { ) -> Poll> { match self.project().0 { SendSinkInner::Direct(sink) => sink.poll_close_unpin(cx).map_err(anyhow::Error::from), + SendSinkInner::DirectTokio(sink) => { + sink.poll_close_unpin(cx).map_err(anyhow::Error::from) + } SendSinkInner::Boxed(sink) => sink.poll_close_unpin(cx).map_err(anyhow::Error::from), } } @@ -87,6 +104,7 @@ impl Sink for SendSink { enum RecvStreamInner { Direct(::flume::r#async::RecvStream<'static, T>), + DirectTokio(tokio_stream::wrappers::ReceiverStream), Boxed(Pin> + Send + Sync + 'static>>), } @@ -109,6 +127,11 @@ impl RecvStream { pub(crate) fn direct(stream: ::flume::r#async::RecvStream<'static, T>) -> Self { Self(RecvStreamInner::Direct(stream)) } + + /// Create a new receive stream from a direct flume receive stream + pub(crate) fn direct_tokio(stream: tokio_stream::wrappers::ReceiverStream) -> Self { + Self(RecvStreamInner::DirectTokio(stream)) + } } impl Stream for RecvStream { @@ -121,6 +144,11 @@ impl Stream for RecvStream { Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, }, + RecvStreamInner::DirectTokio(stream) => match stream.poll_next_unpin(cx) { + Poll::Ready(Some(item)) => Poll::Ready(Some(Ok(item))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + }, RecvStreamInner::Boxed(stream) => stream.poll_next_unpin(cx), } } @@ -129,6 +157,8 @@ impl Stream for RecvStream { enum OpenFutureInner<'a, In: RpcMessage, Out: RpcMessage> { /// A direct future (todo) Direct(super::flume::OpenBiFuture), + /// A direct future (todo) + DirectTokio(BoxFuture<'a, anyhow::Result<(SendSink, RecvStream)>>), /// A boxed future Boxed(BoxFuture<'a, anyhow::Result<(SendSink, RecvStream)>>), } @@ -141,6 +171,12 @@ impl<'a, In: RpcMessage, Out: RpcMessage> OpenFuture<'a, In, Out> { fn direct(f: super::flume::OpenBiFuture) -> Self { Self(OpenFutureInner::Direct(f)) } + /// Create a new boxed future + pub fn direct_tokio( + f: impl Future, RecvStream)>> + Send + Sync + 'a, + ) -> Self { + Self(OpenFutureInner::DirectTokio(Box::pin(f))) + } /// Create a new boxed future pub fn boxed( @@ -159,6 +195,7 @@ impl<'a, In: RpcMessage, Out: RpcMessage> Future for OpenFuture<'a, In, Out> { .poll(cx) .map_ok(|(send, recv)| (SendSink::direct(send.0), RecvStream::direct(recv.0))) .map_err(|e| e.into()), + OpenFutureInner::DirectTokio(f) => f.poll(cx), OpenFutureInner::Boxed(f) => f.poll(cx), } } @@ -167,6 +204,8 @@ impl<'a, In: RpcMessage, Out: RpcMessage> Future for OpenFuture<'a, In, Out> { enum AcceptFutureInner<'a, In: RpcMessage, Out: RpcMessage> { /// A direct future Direct(super::flume::AcceptBiFuture), + /// A direct future + DirectTokio(BoxedFuture<'a, anyhow::Result<(SendSink, RecvStream)>>), /// A boxed future Boxed(BoxedFuture<'a, anyhow::Result<(SendSink, RecvStream)>>), } @@ -180,6 +219,13 @@ impl<'a, In: RpcMessage, Out: RpcMessage> AcceptFuture<'a, In, Out> { Self(AcceptFutureInner::Direct(f)) } + /// bla + pub fn direct_tokio( + f: impl Future, RecvStream)>> + Send + Sync + 'a, + ) -> Self { + Self(AcceptFutureInner::DirectTokio(Box::pin(f))) + } + /// Create a new boxed future pub fn boxed( f: impl Future, RecvStream)>> + Send + Sync + 'a, @@ -197,6 +243,7 @@ impl<'a, In: RpcMessage, Out: RpcMessage> Future for AcceptFuture<'a, In, Out> { .poll(cx) .map_ok(|(send, recv)| (SendSink::direct(send.0), RecvStream::direct(recv.0))) .map_err(|e| e.into()), + AcceptFutureInner::DirectTokio(f) => f.poll(cx), AcceptFutureInner::Boxed(f) => f.poll(cx), } } @@ -368,6 +415,46 @@ impl BoxableServerEndpoint for super::flume::FlumeSe } } +#[cfg(feature = "tokio-mpsc-transport")] +impl BoxableConnection for super::tokio_mpsc::Connection { + fn clone_box(&self) -> Box> { + Box::new(self.clone()) + } + + fn open_bi_boxed(&self) -> OpenFuture { + let f = Box::pin(async move { + let (send, recv) = super::Connection::open_bi(self).await?; + // return the boxed streams + anyhow::Ok(( + SendSink::direct_tokio(send.0), + RecvStream::direct_tokio(recv.0), + )) + }); + OpenFuture::direct_tokio(f) + } +} + +#[cfg(feature = "tokio-mpsc-transport")] +impl BoxableServerEndpoint for super::tokio_mpsc::ServerEndpoint { + fn clone_box(&self) -> Box> { + Box::new(self.clone()) + } + + fn accept_bi_boxed(&self) -> AcceptFuture { + let f = async move { + let (send, recv) = super::ServerEndpoint::accept_bi(self).await?; + let send = send.sink_map_err(anyhow::Error::from); + let recv = recv.map_err(anyhow::Error::from); + anyhow::Ok((SendSink::boxed(send), RecvStream::boxed(recv))) + }; + AcceptFuture::direct_tokio(f) + } + + fn local_addr(&self) -> &[super::LocalAddr] { + super::ServerEndpoint::local_addr(self) + } +} + #[cfg(test)] mod tests { use crate::Service; From 32ce2375e844b2721a6459e9b7c36ecfcb1960b2 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 25 Jun 2024 19:59:04 +0200 Subject: [PATCH 07/14] clippy --- src/transport/tokio_mpsc.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transport/tokio_mpsc.rs b/src/transport/tokio_mpsc.rs index e3c8788..44b5810 100644 --- a/src/transport/tokio_mpsc.rs +++ b/src/transport/tokio_mpsc.rs @@ -144,7 +144,7 @@ impl transport::ServerEndpoint for ServerEndpoint .await .recv() .await - .ok_or_else(|| AcceptBiError::RemoteDropped)?; + .ok_or(AcceptBiError::RemoteDropped)?; Ok((send, recv)) } From e004ac4fbbcb9fb4a7cf8da47733a968b7a3eaf6 Mon Sep 17 00:00:00 2001 From: dignifiedquire Date: Tue, 25 Jun 2024 20:04:07 +0200 Subject: [PATCH 08/14] better feature detection --- src/transport/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 69d4983..168222c 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -7,7 +7,7 @@ use std::{ fmt::{self, Debug, Display}, net::SocketAddr, }; -#[cfg(feature = "flume-transport")] +#[cfg(all(feature = "flume-transport", feature = "tokio-mpsc-transport"))] pub mod boxed; #[cfg(feature = "combined-transport")] pub mod combined; From 18cce9c9dacce214ad3eb6e7f0310ab1c765870b Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Tue, 23 Jul 2024 14:30:51 +0300 Subject: [PATCH 09/14] Fix tokio_mpsc test to use the right feature flag --- tests/tokio_mpsc.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/tokio_mpsc.rs b/tests/tokio_mpsc.rs index 751bdc6..f2930c3 100644 --- a/tests/tokio_mpsc.rs +++ b/tests/tokio_mpsc.rs @@ -1,17 +1,17 @@ -#![cfg(feature = "async-channel-transport")] +#![cfg(feature = "tokio-mpsc-transport")] #![allow(non_local_definitions)] mod math; use math::*; use quic_rpc::{ server::{RpcChannel, RpcServerError}, - transport::async_channel, + transport::tokio_mpsc, RpcClient, RpcServer, Service, }; #[tokio::test] async fn async_channel_channel_bench() -> anyhow::Result<()> { tracing_subscriber::fmt::try_init().ok(); - let (server, client) = async_channel::connection::(1); + let (server, client) = tokio_mpsc::connection::(1); let server = RpcServer::::new(server); let server_handle = tokio::task::spawn(ComputeService::server(server)); @@ -60,7 +60,7 @@ async fn async_channel_channel_mapped_bench() -> anyhow::Result<()> { type Req = InnerRequest; type Res = InnerResponse; } - let (server, client) = async_channel::connection::(1); + let (server, client) = tokio_mpsc::connection::(1); let server = RpcServer::new(server); let server_handle: tokio::task::JoinHandle>> = @@ -96,9 +96,9 @@ async fn async_channel_channel_mapped_bench() -> anyhow::Result<()> { /// simple happy path test for all 4 patterns #[tokio::test] -async fn async_channel_channel_smoke() -> anyhow::Result<()> { +async fn tokio_mpsc_channel_smoke() -> anyhow::Result<()> { tracing_subscriber::fmt::try_init().ok(); - let (server, client) = async_channel::connection::(1); + let (server, client) = tokio_mpsc::connection::(1); let server = RpcServer::::new(server); let server_handle = tokio::task::spawn(ComputeService::server(server)); From a714a154e1df6ee9cfb77279c627de1a8eb66e9c Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Tue, 23 Jul 2024 14:39:21 +0300 Subject: [PATCH 10/14] Merge main and fix up --- src/transport/boxed.rs | 6 +++--- src/transport/tokio_mpsc.rs | 4 ++-- tests/tokio_mpsc.rs | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/transport/boxed.rs b/src/transport/boxed.rs index 7285eee..025a378 100644 --- a/src/transport/boxed.rs +++ b/src/transport/boxed.rs @@ -421,9 +421,9 @@ impl BoxableConnection for super::tokio_mpsc::Connec Box::new(self.clone()) } - fn open_bi_boxed(&self) -> OpenFuture { + fn open_boxed(&self) -> OpenFuture { let f = Box::pin(async move { - let (send, recv) = super::Connection::open_bi(self).await?; + let (send, recv) = super::Connection::open(self).await?; // return the boxed streams anyhow::Ok(( SendSink::direct_tokio(send.0), @@ -442,7 +442,7 @@ impl BoxableServerEndpoint for super::tokio_mpsc::Se fn accept_bi_boxed(&self) -> AcceptFuture { let f = async move { - let (send, recv) = super::ServerEndpoint::accept_bi(self).await?; + let (send, recv) = super::ServerEndpoint::accept(self).await?; let send = send.sink_map_err(anyhow::Error::from); let recv = recv.map_err(anyhow::Error::from); anyhow::Ok((SendSink::boxed(send), RecvStream::boxed(recv))) diff --git a/src/transport/tokio_mpsc.rs b/src/transport/tokio_mpsc.rs index 44b5810..67d92c6 100644 --- a/src/transport/tokio_mpsc.rs +++ b/src/transport/tokio_mpsc.rs @@ -137,7 +137,7 @@ impl ConnectionCommon for ServerEndpoint { } impl transport::ServerEndpoint for ServerEndpoint { - async fn accept_bi(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptBiError> { + async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptBiError> { let (send, recv) = self .stream .lock() @@ -167,7 +167,7 @@ impl ConnectionCommon for Connection { } impl transport::Connection for Connection { - async fn open_bi(&self) -> result::Result, self::OpenBiError> { + async fn open(&self) -> result::Result, self::OpenBiError> { let (local_send, remote_recv) = mpsc::channel::(128); let (remote_send, local_recv) = mpsc::channel::(128); let remote_chan = ( diff --git a/tests/tokio_mpsc.rs b/tests/tokio_mpsc.rs index f2930c3..a2f5735 100644 --- a/tests/tokio_mpsc.rs +++ b/tests/tokio_mpsc.rs @@ -9,7 +9,7 @@ use quic_rpc::{ }; #[tokio::test] -async fn async_channel_channel_bench() -> anyhow::Result<()> { +async fn tokio_mpsc_channel_bench() -> anyhow::Result<()> { tracing_subscriber::fmt::try_init().ok(); let (server, client) = tokio_mpsc::connection::(1); @@ -26,7 +26,7 @@ async fn async_channel_channel_bench() -> anyhow::Result<()> { } #[tokio::test] -async fn async_channel_channel_mapped_bench() -> anyhow::Result<()> { +async fn tokio_mpsc_channel_mapped_bench() -> anyhow::Result<()> { use derive_more::{From, TryInto}; use serde::{Deserialize, Serialize}; @@ -67,7 +67,7 @@ async fn async_channel_channel_mapped_bench() -> anyhow::Result<()> { tokio::task::spawn(async move { let service = ComputeService; loop { - let (req, chan) = server.accept().await?; + let (req, chan) = server.accept().await?.read_first().await?; let service = service.clone(); tokio::spawn(async move { let req: OuterRequest = req; From 297848a93f41fe130932d9d76643dac0572a5d8c Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Tue, 23 Jul 2024 15:04:54 +0300 Subject: [PATCH 11/14] add feature flags to direct-tokio-boxed --- src/transport/boxed.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/transport/boxed.rs b/src/transport/boxed.rs index 025a378..b0f7186 100644 --- a/src/transport/boxed.rs +++ b/src/transport/boxed.rs @@ -104,6 +104,7 @@ impl Sink for SendSink { enum RecvStreamInner { Direct(::flume::r#async::RecvStream<'static, T>), + #[cfg(feature = "tokio-mpsc-transport")] DirectTokio(tokio_stream::wrappers::ReceiverStream), Boxed(Pin> + Send + Sync + 'static>>), } @@ -129,6 +130,7 @@ impl RecvStream { } /// Create a new receive stream from a direct flume receive stream + #[cfg(feature = "tokio-mpsc-transport")] pub(crate) fn direct_tokio(stream: tokio_stream::wrappers::ReceiverStream) -> Self { Self(RecvStreamInner::DirectTokio(stream)) } @@ -144,6 +146,7 @@ impl Stream for RecvStream { Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, }, + #[cfg(feature = "tokio-mpsc-transport")] RecvStreamInner::DirectTokio(stream) => match stream.poll_next_unpin(cx) { Poll::Ready(Some(item)) => Poll::Ready(Some(Ok(item))), Poll::Ready(None) => Poll::Ready(None), @@ -158,6 +161,7 @@ enum OpenFutureInner<'a, In: RpcMessage, Out: RpcMessage> { /// A direct future (todo) Direct(super::flume::OpenBiFuture), /// A direct future (todo) + #[cfg(feature = "tokio-mpsc-transport")] DirectTokio(BoxFuture<'a, anyhow::Result<(SendSink, RecvStream)>>), /// A boxed future Boxed(BoxFuture<'a, anyhow::Result<(SendSink, RecvStream)>>), @@ -172,6 +176,7 @@ impl<'a, In: RpcMessage, Out: RpcMessage> OpenFuture<'a, In, Out> { Self(OpenFutureInner::Direct(f)) } /// Create a new boxed future + #[cfg(feature = "tokio-mpsc-transport")] pub fn direct_tokio( f: impl Future, RecvStream)>> + Send + Sync + 'a, ) -> Self { @@ -195,6 +200,7 @@ impl<'a, In: RpcMessage, Out: RpcMessage> Future for OpenFuture<'a, In, Out> { .poll(cx) .map_ok(|(send, recv)| (SendSink::direct(send.0), RecvStream::direct(recv.0))) .map_err(|e| e.into()), + #[cfg(feature = "tokio-mpsc-transport")] OpenFutureInner::DirectTokio(f) => f.poll(cx), OpenFutureInner::Boxed(f) => f.poll(cx), } @@ -205,6 +211,7 @@ enum AcceptFutureInner<'a, In: RpcMessage, Out: RpcMessage> { /// A direct future Direct(super::flume::AcceptBiFuture), /// A direct future + #[cfg(feature = "tokio-mpsc-transport")] DirectTokio(BoxedFuture<'a, anyhow::Result<(SendSink, RecvStream)>>), /// A boxed future Boxed(BoxedFuture<'a, anyhow::Result<(SendSink, RecvStream)>>), @@ -220,6 +227,7 @@ impl<'a, In: RpcMessage, Out: RpcMessage> AcceptFuture<'a, In, Out> { } /// bla + #[cfg(feature = "tokio-mpsc-transport")] pub fn direct_tokio( f: impl Future, RecvStream)>> + Send + Sync + 'a, ) -> Self { @@ -243,6 +251,7 @@ impl<'a, In: RpcMessage, Out: RpcMessage> Future for AcceptFuture<'a, In, Out> { .poll(cx) .map_ok(|(send, recv)| (SendSink::direct(send.0), RecvStream::direct(recv.0))) .map_err(|e| e.into()), + #[cfg(feature = "tokio-mpsc-transport")] AcceptFutureInner::DirectTokio(f) => f.poll(cx), AcceptFutureInner::Boxed(f) => f.poll(cx), } From 64775a4513f719be1a4abd70ed49c92407a73cda Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Tue, 23 Jul 2024 15:12:13 +0300 Subject: [PATCH 12/14] more direct_tokio feature flags --- src/transport/boxed.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/transport/boxed.rs b/src/transport/boxed.rs index b0f7186..eacd11a 100644 --- a/src/transport/boxed.rs +++ b/src/transport/boxed.rs @@ -21,6 +21,7 @@ type BoxedFuture<'a, T> = Pin + Send + Sync + 'a>>; enum SendSinkInner { Direct(::flume::r#async::SendSink<'static, T>), + #[cfg(feature = "tokio-mpsc-transport")] DirectTokio(tokio_util::sync::PollSender), Boxed(Pin + Send + Sync + 'static>>), } @@ -44,6 +45,7 @@ impl SendSink { Self(SendSinkInner::Direct(sink)) } + #[cfg(feature = "tokio-mpsc-transport")] pub(crate) fn direct_tokio(sink: tokio_util::sync::PollSender) -> Self { Self(SendSinkInner::DirectTokio(sink)) } @@ -58,6 +60,7 @@ impl Sink for SendSink { ) -> Poll> { match self.project().0 { SendSinkInner::Direct(sink) => sink.poll_ready_unpin(cx).map_err(anyhow::Error::from), + #[cfg(feature = "tokio-mpsc-transport")] SendSinkInner::DirectTokio(sink) => { sink.poll_ready_unpin(cx).map_err(anyhow::Error::from) } @@ -68,6 +71,7 @@ impl Sink for SendSink { fn start_send(self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> { match self.project().0 { SendSinkInner::Direct(sink) => sink.start_send_unpin(item).map_err(anyhow::Error::from), + #[cfg(feature = "tokio-mpsc-transport")] SendSinkInner::DirectTokio(sink) => { sink.start_send_unpin(item).map_err(anyhow::Error::from) } @@ -81,6 +85,7 @@ impl Sink for SendSink { ) -> Poll> { match self.project().0 { SendSinkInner::Direct(sink) => sink.poll_flush_unpin(cx).map_err(anyhow::Error::from), + #[cfg(feature = "tokio-mpsc-transport")] SendSinkInner::DirectTokio(sink) => { sink.poll_flush_unpin(cx).map_err(anyhow::Error::from) } @@ -94,6 +99,7 @@ impl Sink for SendSink { ) -> Poll> { match self.project().0 { SendSinkInner::Direct(sink) => sink.poll_close_unpin(cx).map_err(anyhow::Error::from), + #[cfg(feature = "tokio-mpsc-transport")] SendSinkInner::DirectTokio(sink) => { sink.poll_close_unpin(cx).map_err(anyhow::Error::from) } From 0b525b0a5641f9f558118184468ee6e961116903 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Tue, 23 Jul 2024 15:16:59 +0300 Subject: [PATCH 13/14] minimal crates fix? --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 12a9ab4..62dcb7a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,7 +33,7 @@ interprocess = { version = "2.1", features = ["tokio"], optional = true } hex = "0.4.3" futures = { version = "0.3.30", optional = true } anyhow = "1.0.73" -tokio-stream = { version = "0.1", optional = true } +tokio-stream = { version = "0.1.15", optional = true } [dependencies.educe] # This is an unused dependency, it is needed to make the minimal From 021088cb0267523a36a2d155bc43fe2490e5e4d8 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 14 Aug 2024 17:12:33 +0300 Subject: [PATCH 14/14] Use concrete OpenBiFuture --- src/transport/tokio_mpsc.rs | 66 +++++++++++++++++++++++++++++++++---- 1 file changed, 59 insertions(+), 7 deletions(-) diff --git a/src/transport/tokio_mpsc.rs b/src/transport/tokio_mpsc.rs index 67d92c6..a56f682 100644 --- a/src/transport/tokio_mpsc.rs +++ b/src/transport/tokio_mpsc.rs @@ -2,13 +2,14 @@ use futures_lite::Stream; use futures_sink::Sink; +use tokio_util::sync::PollSender; use crate::{ transport::{self, ConnectionErrors, LocalAddr}, RpcMessage, Service, }; use core::fmt; -use std::{error, fmt::Display, pin::Pin, result, sync::Arc, task::Poll}; +use std::{error, fmt::Display, future::Future, pin::Pin, result, sync::Arc, task::Poll}; use tokio::sync::{mpsc, Mutex}; use super::ConnectionCommon; @@ -167,7 +168,8 @@ impl ConnectionCommon for Connection { } impl transport::Connection for Connection { - async fn open(&self) -> result::Result, self::OpenBiError> { + #[allow(refining_impl_trait)] + fn open(&self) -> OpenBiFuture { let (local_send, remote_recv) = mpsc::channel::(128); let (remote_send, local_recv) = mpsc::channel::(128); let remote_chan = ( @@ -178,11 +180,61 @@ impl transport::Connection for Connection { SendSink(tokio_util::sync::PollSender::new(local_send)), RecvStream(tokio_stream::wrappers::ReceiverStream::new(local_recv)), ); - self.sink - .send(remote_chan) - .await - .map_err(|_| self::OpenBiError::RemoteDropped)?; - Ok(local_chan) + let sender = PollSender::new(self.sink.clone()); + OpenBiFuture::new(sender, remote_chan, local_chan) + } +} + +/// Future returned by [FlumeConnection::open] +pub struct OpenBiFuture { + inner: PollSender>, + send: Option>, + res: Option>, +} + +impl fmt::Debug for OpenBiFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OpenBiFuture").finish() + } +} + +impl OpenBiFuture { + fn new( + inner: PollSender>, + send: Socket, + res: Socket, + ) -> Self { + Self { + inner, + send: Some(send), + res: Some(res), + } + } +} + +impl Future for OpenBiFuture { + type Output = result::Result, self::OpenBiError>; + + fn poll( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + match Pin::new(&mut self.inner).poll_reserve(cx) { + Poll::Ready(Ok(())) => { + let Some(item) = self.send.take() else { + return Poll::Pending; + }; + let Ok(_) = self.inner.send_item(item) else { + return Poll::Ready(Err(self::OpenBiError::RemoteDropped)); + }; + self.res + .take() + .map(|x| Poll::Ready(Ok(x))) + .unwrap_or(Poll::Pending) + } + Poll::Ready(Err(_)) => Poll::Ready(Err(self::OpenBiError::RemoteDropped)), + Poll::Pending => Poll::Pending, + } } }