diff --git a/Cargo.lock b/Cargo.lock index 2eb398c..70fe0a9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1655,6 +1655,7 @@ dependencies = [ "rustls", "serde", "smallvec", + "testresult", "thiserror 2.0.12", "thousands", "tokio", @@ -3300,6 +3301,12 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "testresult" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "614b328ff036a4ef882c61570f72918f7e9c5bee1da33f8e7f91e01daee7e56c" + [[package]] name = "thiserror" version = "1.0.69" diff --git a/Cargo.toml b/Cargo.toml index 616d9c6..09c5076 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,6 +57,7 @@ tokio = { workspace = true, features = ["full"] } thousands = "0.2.0" # macro tests trybuild = "1.0.104" +testresult = "0.4.1" [features] # enable the remote transport diff --git a/examples/compute.rs b/examples/compute.rs index 0ec3446..cb7037a 100644 --- a/examples/compute.rs +++ b/examples/compute.rs @@ -7,11 +7,11 @@ use std::{ use anyhow::bail; use futures_buffered::BufferedStreamExt; use irpc::{ - channel::{oneshot, spsc}, + channel::{mpsc, oneshot}, rpc::{listen, Handler}, rpc_requests, util::{make_client_endpoint, make_server_endpoint}, - Client, LocalSender, Request, Service, WithChannels, + Client, LocalSender, Request, RequestSender, Service, }; use n0_future::{ stream::StreamExt, @@ -59,13 +59,13 @@ enum ComputeRequest { #[rpc_requests(ComputeService, message = ComputeMessage)] #[derive(Serialize, Deserialize)] enum ComputeProtocol { - #[rpc(tx=oneshot::Sender)] + #[rpc(reply=oneshot::Sender)] Sqr(Sqr), - #[rpc(rx=spsc::Receiver, tx=oneshot::Sender)] + #[rpc(updates=mpsc::Receiver, reply=oneshot::Sender)] Sum(Sum), - #[rpc(tx=spsc::Sender)] + #[rpc(reply=mpsc::Sender)] Fibonacci(Fibonacci), - #[rpc(rx=spsc::Receiver, tx=spsc::Sender)] + #[rpc(updates=mpsc::Receiver, reply=mpsc::Sender)] Multiply(Multiply), } @@ -76,10 +76,10 @@ struct ComputeActor { impl ComputeActor { pub fn local() -> ComputeApi { - let (tx, rx) = tokio::sync::mpsc::channel(128); - let actor = Self { recv: rx }; + let (reply, request) = tokio::sync::mpsc::channel(128); + let actor = Self { recv: request }; n0_future::task::spawn(actor.run()); - let local = LocalSender::::from(tx); + let local = LocalSender::::from(reply); ComputeApi { inner: local.into(), } @@ -99,34 +99,45 @@ impl ComputeActor { match msg { ComputeMessage::Sqr(sqr) => { trace!("sqr {:?}", sqr); - let WithChannels { - tx, inner, span, .. + let Request { + reply, + message, + span, + .. } = sqr; let _entered = span.enter(); - let result = (inner.num as u128) * (inner.num as u128); - tx.send(result).await?; + let result = (message.num as u128) * (message.num as u128); + reply.send(result).await?; } ComputeMessage::Sum(sum) => { trace!("sum {:?}", sum); - let WithChannels { rx, tx, span, .. } = sum; + let Request { + updates, + reply, + span, + .. + } = sum; let _entered = span.enter(); - let mut receiver = rx; + let mut receiver = updates; let mut total = 0; while let Some(num) = receiver.recv().await? { total += num; } - tx.send(total).await?; + reply.send(total).await?; } ComputeMessage::Fibonacci(fib) => { trace!("fibonacci {:?}", fib); - let WithChannels { - tx, inner, span, .. + let Request { + reply, + message, + span, + .. } = fib; let _entered = span.enter(); - let mut sender = tx; + let sender = reply; let mut a = 0u64; let mut b = 1u64; - while a <= inner.max { + while a <= message.max { sender.send(a).await?; let next = a + b; a = b; @@ -135,17 +146,17 @@ impl ComputeActor { } ComputeMessage::Multiply(mult) => { trace!("multiply {:?}", mult); - let WithChannels { - rx, - tx, - inner, + let Request { + updates, + reply, + message, span, .. } = mult; let _entered = span.enter(); - let mut receiver = rx; - let mut sender = tx; - let multiplier = inner.initial; + let mut receiver = updates; + let sender = reply; + let multiplier = message.initial; while let Some(num) = receiver.recv().await? { sender.send(multiplier * num).await?; } @@ -171,13 +182,13 @@ impl ComputeApi { let Some(local) = self.inner.local() else { bail!("cannot listen on a remote service"); }; - let handler: Handler = Arc::new(move |msg, rx, tx| { + let handler: Handler = Arc::new(move |msg, updates, reply| { let local = local.clone(); Box::pin(match msg { - ComputeProtocol::Sqr(msg) => local.send((msg, tx)), - ComputeProtocol::Sum(msg) => local.send((msg, tx, rx)), - ComputeProtocol::Fibonacci(msg) => local.send((msg, tx)), - ComputeProtocol::Multiply(msg) => local.send((msg, tx, rx)), + ComputeProtocol::Sqr(msg) => local.send((msg, reply)), + ComputeProtocol::Sum(msg) => local.send((msg, reply, updates)), + ComputeProtocol::Fibonacci(msg) => local.send((msg, reply)), + ComputeProtocol::Multiply(msg) => local.send((msg, reply, updates)), }) }); Ok(AbortOnDropHandle::new(task::spawn(listen( @@ -188,44 +199,44 @@ impl ComputeApi { pub async fn sqr(&self, num: u64) -> anyhow::Result> { let msg = Sqr { num }; match self.inner.request().await? { - Request::Local(request) => { + RequestSender::Local(sender) => { let (tx, rx) = oneshot::channel(); - request.send((msg, tx)).await?; + sender.send((msg, tx)).await?; Ok(rx) } - Request::Remote(request) => { - let (_tx, rx) = request.write(msg).await?; + RequestSender::Remote(sender) => { + let (_tx, rx) = sender.write(msg).await?; Ok(rx.into()) } } } - pub async fn sum(&self) -> anyhow::Result<(spsc::Sender, oneshot::Receiver)> { + pub async fn sum(&self) -> anyhow::Result<(mpsc::Sender, oneshot::Receiver)> { let msg = Sum; match self.inner.request().await? { - Request::Local(request) => { - let (num_tx, num_rx) = spsc::channel(10); + RequestSender::Local(sender) => { + let (num_tx, num_rx) = mpsc::channel(10); let (sum_tx, sum_rx) = oneshot::channel(); - request.send((msg, sum_tx, num_rx)).await?; + sender.send((msg, sum_tx, num_rx)).await?; Ok((num_tx, sum_rx)) } - Request::Remote(request) => { - let (tx, rx) = request.write(msg).await?; + RequestSender::Remote(sender) => { + let (tx, rx) = sender.write(msg).await?; Ok((tx.into(), rx.into())) } } } - pub async fn fibonacci(&self, max: u64) -> anyhow::Result> { + pub async fn fibonacci(&self, max: u64) -> anyhow::Result> { let msg = Fibonacci { max }; match self.inner.request().await? { - Request::Local(request) => { - let (tx, rx) = spsc::channel(128); - request.send((msg, tx)).await?; + RequestSender::Local(sender) => { + let (tx, rx) = mpsc::channel(128); + sender.send((msg, tx)).await?; Ok(rx) } - Request::Remote(request) => { - let (_tx, rx) = request.write(msg).await?; + RequestSender::Remote(sender) => { + let (_tx, rx) = sender.write(msg).await?; Ok(rx.into()) } } @@ -234,17 +245,17 @@ impl ComputeApi { pub async fn multiply( &self, initial: u64, - ) -> anyhow::Result<(spsc::Sender, spsc::Receiver)> { + ) -> anyhow::Result<(mpsc::Sender, mpsc::Receiver)> { let msg = Multiply { initial }; match self.inner.request().await? { - Request::Local(request) => { - let (in_tx, in_rx) = spsc::channel(128); - let (out_tx, out_rx) = spsc::channel(128); - request.send((msg, out_tx, in_rx)).await?; + RequestSender::Local(sender) => { + let (in_tx, in_rx) = mpsc::channel(128); + let (out_tx, out_rx) = mpsc::channel(128); + sender.send((msg, out_tx, in_rx)).await?; Ok((in_tx, out_rx)) } - Request::Remote(request) => { - let (tx, rx) = request.write(msg).await?; + RequestSender::Remote(sender) => { + let (tx, rx) = sender.write(msg).await?; Ok((tx.into(), rx.into())) } } @@ -260,7 +271,7 @@ async fn local() -> anyhow::Result<()> { println!("Local: 5^2 = {}", rx.await?); // Test Sum - let (mut tx, rx) = api.sum().await?; + let (tx, rx) = api.sum().await?; tx.send(1).await?; tx.send(2).await?; tx.send(3).await?; @@ -276,7 +287,7 @@ async fn local() -> anyhow::Result<()> { println!(); // Test Multiply - let (mut in_tx, mut out_rx) = api.multiply(3).await?; + let (in_tx, mut out_rx) = api.multiply(3).await?; in_tx.send(2).await?; in_tx.send(4).await?; in_tx.send(6).await?; @@ -311,7 +322,7 @@ async fn remote() -> anyhow::Result<()> { println!("Remote: 4^2 = {}", rx.await?); // Test Sum - let (mut tx, rx) = api.sum().await?; + let (tx, rx) = api.sum().await?; tx.send(4).await?; tx.send(5).await?; tx.send(6).await?; @@ -327,7 +338,7 @@ async fn remote() -> anyhow::Result<()> { println!(); // Test Multiply - let (mut in_tx, mut out_rx) = api.multiply(5).await?; + let (in_tx, mut out_rx) = api.multiply(5).await?; in_tx.send(1).await?; in_tx.send(2).await?; in_tx.send(3).await?; @@ -380,7 +391,7 @@ async fn bench(api: ComputeApi, n: u64) -> anyhow::Result<()> { // Sequential streaming (using Multiply instead of MultiplyUpdate) { let t0 = std::time::Instant::now(); - let (mut send, mut recv) = api.multiply(2).await?; + let (send, mut recv) = api.multiply(2).await?; let handle = tokio::task::spawn(async move { for i in 0..n { send.send(i).await?; diff --git a/examples/derive.rs b/examples/derive.rs index e03f39f..2743e07 100644 --- a/examples/derive.rs +++ b/examples/derive.rs @@ -6,11 +6,11 @@ use std::{ use anyhow::{Context, Result}; use irpc::{ - channel::{oneshot, spsc}, + channel::{mpsc, oneshot}, rpc::Handler, rpc_requests, util::{make_client_endpoint, make_server_endpoint}, - Client, LocalSender, Service, WithChannels, + Client, LocalSender, Request, Service, }; // Import the macro use n0_future::task::{self, AbortOnDropHandle}; @@ -51,13 +51,13 @@ struct SetMany; #[rpc_requests(StorageService, message = StorageMessage)] #[derive(Serialize, Deserialize)] enum StorageProtocol { - #[rpc(tx=oneshot::Sender>)] + #[rpc(reply=oneshot::Sender>)] Get(Get), - #[rpc(tx=oneshot::Sender<()>)] + #[rpc(reply=oneshot::Sender<()>)] Set(Set), - #[rpc(tx=oneshot::Sender, rx=spsc::Receiver<(String, String)>)] + #[rpc(reply=oneshot::Sender, updates=mpsc::Receiver<(String, String)>)] SetMany(SetMany), - #[rpc(tx=spsc::Sender)] + #[rpc(reply=mpsc::Sender)] List(List), } @@ -68,13 +68,13 @@ struct StorageActor { impl StorageActor { pub fn spawn() -> StorageApi { - let (tx, rx) = tokio::sync::mpsc::channel(1); + let (reply, request) = tokio::sync::mpsc::channel(1); let actor = Self { - recv: rx, + recv: request, state: BTreeMap::new(), }; n0_future::task::spawn(actor.run()); - let local = LocalSender::::from(tx); + let local = LocalSender::::from(reply); StorageApi { inner: local.into(), } @@ -90,30 +90,32 @@ impl StorageActor { match msg { StorageMessage::Get(get) => { info!("get {:?}", get); - let WithChannels { tx, inner, .. } = get; - tx.send(self.state.get(&inner.key).cloned()).await.ok(); + let Request { reply, message, .. } = get; + reply.send(self.state.get(&message.key).cloned()).await.ok(); } StorageMessage::Set(set) => { info!("set {:?}", set); - let WithChannels { tx, inner, .. } = set; - self.state.insert(inner.key, inner.value); - tx.send(()).await.ok(); + let Request { reply, message, .. } = set; + self.state.insert(message.key, message.value); + reply.send(()).await.ok(); } StorageMessage::SetMany(set) => { info!("set-many {:?}", set); - let WithChannels { mut rx, tx, .. } = set; + let Request { + mut updates, reply, .. + } = set; let mut count = 0; - while let Ok(Some((key, value))) = rx.recv().await { + while let Ok(Some((key, value))) = updates.recv().await { self.state.insert(key, value); count += 1; } - tx.send(count).await.ok(); + reply.send(count).await.ok(); } StorageMessage::List(list) => { info!("list {:?}", list); - let WithChannels { mut tx, .. } = list; + let Request { reply, .. } = list; for (key, value) in &self.state { - if tx.send(format!("{key}={value}")).await.is_err() { + if reply.send(format!("{key}={value}")).await.is_err() { break; } } @@ -135,13 +137,13 @@ impl StorageApi { pub fn listen(&self, endpoint: quinn::Endpoint) -> Result> { let local = self.inner.local().context("cannot listen on remote API")?; - let handler: Handler = Arc::new(move |msg, rx, tx| { + let handler: Handler = Arc::new(move |msg, updates, reply| { let local = local.clone(); Box::pin(match msg { - StorageProtocol::Get(msg) => local.send((msg, tx)), - StorageProtocol::Set(msg) => local.send((msg, tx)), - StorageProtocol::SetMany(msg) => local.send((msg, tx, rx)), - StorageProtocol::List(msg) => local.send((msg, tx)), + StorageProtocol::Get(msg) => local.send((msg, reply)), + StorageProtocol::Set(msg) => local.send((msg, reply)), + StorageProtocol::SetMany(msg) => local.send((msg, reply, updates)), + StorageProtocol::List(msg) => local.send((msg, reply)), }) }); let join_handle = task::spawn(irpc::rpc::listen(endpoint, handler)); @@ -152,7 +154,7 @@ impl StorageApi { self.inner.rpc(Get { key }).await } - pub async fn list(&self) -> irpc::Result> { + pub async fn list(&self) -> irpc::Result> { self.inner.server_streaming(List, 16).await } @@ -162,7 +164,7 @@ impl StorageApi { pub async fn set_many( &self, - ) -> irpc::Result<(spsc::Sender<(String, String)>, oneshot::Receiver)> { + ) -> irpc::Result<(mpsc::Sender<(String, String)>, oneshot::Receiver)> { self.inner.client_streaming(SetMany, 4).await } } @@ -172,7 +174,7 @@ async fn client_demo(api: StorageApi) -> Result<()> { let value = api.get("hello".to_string()).await?; println!("get: hello = {:?}", value); - let (mut tx, rx) = api.set_many().await?; + let (tx, rx) = api.set_many().await?; for i in 0..3 { tx.send((format!("key{i}"), format!("value{i}"))).await?; } diff --git a/examples/storage.rs b/examples/storage.rs index d73f29f..7120d5d 100644 --- a/examples/storage.rs +++ b/examples/storage.rs @@ -6,10 +6,10 @@ use std::{ use anyhow::bail; use irpc::{ - channel::{none::NoReceiver, oneshot, spsc}, + channel::{mpsc, none::NoReceiver, oneshot}, rpc::{listen, Handler}, util::{make_client_endpoint, make_server_endpoint}, - Channels, Client, LocalSender, Request, Service, WithChannels, + Channels, Client, LocalSender, Request, RequestSender, Service, }; use n0_future::task::{self, AbortOnDropHandle}; use serde::{Deserialize, Serialize}; @@ -27,16 +27,16 @@ struct Get { } impl Channels for Get { - type Rx = NoReceiver; - type Tx = oneshot::Sender>; + type Updates = NoReceiver; + type Reply = oneshot::Sender>; } #[derive(Debug, Serialize, Deserialize)] struct List; impl Channels for List { - type Rx = NoReceiver; - type Tx = spsc::Sender; + type Updates = NoReceiver; + type Reply = mpsc::Sender; } #[derive(Debug, Serialize, Deserialize)] @@ -46,8 +46,8 @@ struct Set { } impl Channels for Set { - type Rx = NoReceiver; - type Tx = oneshot::Sender<()>; + type Updates = NoReceiver; + type Reply = oneshot::Sender<()>; } #[derive(derive_more::From, Serialize, Deserialize)] @@ -59,9 +59,9 @@ enum StorageProtocol { #[derive(derive_more::From)] enum StorageMessage { - Get(WithChannels), - Set(WithChannels), - List(WithChannels), + Get(Request), + Set(Request), + List(Request), } struct StorageActor { @@ -71,13 +71,13 @@ struct StorageActor { impl StorageActor { pub fn local() -> StorageApi { - let (tx, rx) = tokio::sync::mpsc::channel(1); + let (reply, request) = tokio::sync::mpsc::channel(1); let actor = Self { - recv: rx, + recv: request, state: BTreeMap::new(), }; n0_future::task::spawn(actor.run()); - let local = LocalSender::::from(tx); + let local = LocalSender::::from(reply); StorageApi { inner: local.into(), } @@ -93,20 +93,20 @@ impl StorageActor { match msg { StorageMessage::Get(get) => { info!("get {:?}", get); - let WithChannels { tx, inner, .. } = get; - tx.send(self.state.get(&inner.key).cloned()).await.ok(); + let Request { reply, message, .. } = get; + reply.send(self.state.get(&message.key).cloned()).await.ok(); } StorageMessage::Set(set) => { info!("set {:?}", set); - let WithChannels { tx, inner, .. } = set; - self.state.insert(inner.key, inner.value); - tx.send(()).await.ok(); + let Request { reply, message, .. } = set; + self.state.insert(message.key, message.value); + reply.send(()).await.ok(); } StorageMessage::List(list) => { info!("list {:?}", list); - let WithChannels { mut tx, .. } = list; + let Request { reply, .. } = list; for (key, value) in &self.state { - if tx.send(format!("{key}={value}")).await.is_err() { + if reply.send(format!("{key}={value}")).await.is_err() { break; } } @@ -129,12 +129,12 @@ impl StorageApi { let Some(local) = self.inner.local() else { bail!("cannot listen on a remote service"); }; - let handler: Handler = Arc::new(move |msg, _rx, tx| { + let handler: Handler = Arc::new(move |msg, _request, reply| { let local = local.clone(); Box::pin(match msg { - StorageProtocol::Get(msg) => local.send((msg, tx)), - StorageProtocol::Set(msg) => local.send((msg, tx)), - StorageProtocol::List(msg) => local.send((msg, tx)), + StorageProtocol::Get(msg) => local.send((msg, reply)), + StorageProtocol::Set(msg) => local.send((msg, reply)), + StorageProtocol::List(msg) => local.send((msg, reply)), }) }); Ok(AbortOnDropHandle::new(task::spawn(listen( @@ -145,29 +145,29 @@ impl StorageApi { pub async fn get(&self, key: String) -> anyhow::Result>> { let msg = Get { key }; match self.inner.request().await? { - Request::Local(request) => { - let (tx, rx) = oneshot::channel(); - request.send((msg, tx)).await?; - Ok(rx) + RequestSender::Local(sender) => { + let (reply, request) = oneshot::channel(); + sender.send((msg, reply)).await?; + Ok(request) } - Request::Remote(request) => { - let (_tx, rx) = request.write(msg).await?; - Ok(rx.into()) + RequestSender::Remote(sender) => { + let (_reply, request) = sender.write(msg).await?; + Ok(request.into()) } } } - pub async fn list(&self) -> anyhow::Result> { + pub async fn list(&self) -> anyhow::Result> { let msg = List; match self.inner.request().await? { - Request::Local(request) => { - let (tx, rx) = spsc::channel(10); - request.send((msg, tx)).await?; - Ok(rx) + RequestSender::Local(sender) => { + let (reply, request) = mpsc::channel(10); + sender.send((msg, reply)).await?; + Ok(request) } - Request::Remote(request) => { - let (_tx, rx) = request.write(msg).await?; - Ok(rx.into()) + RequestSender::Remote(sender) => { + let (_reply, request) = sender.write(msg).await?; + Ok(request.into()) } } } @@ -175,14 +175,14 @@ impl StorageApi { pub async fn set(&self, key: String, value: String) -> anyhow::Result> { let msg = Set { key, value }; match self.inner.request().await? { - Request::Local(request) => { - let (tx, rx) = oneshot::channel(); - request.send((msg, tx)).await?; - Ok(rx) + RequestSender::Local(sender) => { + let (reply, request) = oneshot::channel(); + sender.send((msg, reply)).await?; + Ok(request) } - Request::Remote(request) => { - let (_tx, rx) = request.write(msg).await?; - Ok(rx.into()) + RequestSender::Remote(sender) => { + let (_reply, request) = sender.write(msg).await?; + Ok(request.into()) } } } diff --git a/irpc-derive/src/lib.rs b/irpc-derive/src/lib.rs index 754889e..66334c8 100644 --- a/irpc-derive/src/lib.rs +++ b/irpc-derive/src/lib.rs @@ -17,11 +17,11 @@ fn error_tokens(span: Span, message: &str) -> TokenStream { /// The only attribute we care about const ATTR_NAME: &str = "rpc"; -/// the tx type name -const TX_ATTR: &str = "tx"; -/// the rx type name -const RX_ATTR: &str = "rx"; -/// Fully qualified path to the default rx type +/// the reply type name +const TX_ATTR: &str = "reply"; +/// the request type name +const RX_ATTR: &str = "updates"; +/// Fully qualified path to the default request type const DEFAULT_RX_TYPE: &str = "::irpc::channel::none::NoReceiver"; /// Generate parent span method for an enum @@ -45,18 +45,18 @@ fn generate_channels_impl( request_type: &Type, attr_span: Span, ) -> syn::Result { - // Try to get rx, default to NoReceiver if not present + // Try to get updates, default to NoReceiver if not present // Use unwrap_or_else for a cleaner default - let rx = args.types.remove(RX_ATTR).unwrap_or_else(|| { + let updates = args.types.remove(RX_ATTR).unwrap_or_else(|| { // We can safely unwrap here because this is a known valid type - syn::parse_str::(DEFAULT_RX_TYPE).expect("Failed to parse default rx type") + syn::parse_str::(DEFAULT_RX_TYPE).expect("Failed to parse default request type") }); - let tx = args.get(TX_ATTR, attr_span)?; + let reply = args.get(TX_ATTR, attr_span)?; let res = quote! { impl ::irpc::Channels<#service_name> for #request_type { - type Tx = #tx; - type Rx = #rx; + type Reply = #reply; + type Updates = #updates; } }; @@ -98,11 +98,11 @@ fn generate_message_enum_from_impls( ) -> TokenStream2 { let mut impls = quote! {}; - // Generate From> implementations for each case with an rpc attribute + // Generate From> implementations for each case with an rpc attribute for (variant_name, inner_type) in variants_with_attr { let impl_tokens = quote! { - impl From<::irpc::WithChannels<#inner_type, #service_name>> for #message_enum_name { - fn from(value: ::irpc::WithChannels<#inner_type, #service_name>) -> Self { + impl From<::irpc::Request<#inner_type, #service_name>> for #message_enum_name { + fn from(value: ::irpc::Request<#inner_type, #service_name>) -> Self { #message_enum_name::#variant_name(value) } } @@ -117,7 +117,7 @@ fn generate_message_enum_from_impls( impls } -/// Generate type aliases for WithChannels +/// Generate type aliases for Request fn generate_type_aliases( variants: &[(Ident, Type)], service_name: &Ident, @@ -132,8 +132,8 @@ fn generate_type_aliases( let type_ident = Ident::new(&type_name, variant_name.span()); let alias = quote! { - /// Type alias for WithChannels<#inner_type, #service_name> - pub type #type_ident = ::irpc::WithChannels<#inner_type, #service_name>; + /// Type alias for Request<#message_type, #service_name> + pub type #type_ident = ::irpc::Request<#inner_type, #service_name>; }; aliases = quote! { @@ -153,17 +153,17 @@ fn generate_type_aliases( /// # Macro Arguments /// /// * First positional argument (required): The service type that will handle these requests -/// * `message` (optional): Generate an extended enum wrapping each type in `WithChannels` -/// * `alias` (optional): Generate type aliases with the given suffix for each `WithChannels` +/// * `message` (optional): Generate an extended enum wrapping each type in `Request` +/// * `alias` (optional): Generate type aliases with the given suffix for each `Request` /// /// # Variant Attributes /// /// Individual enum variants can be annotated with the `#[rpc(...)]` attribute to specify channel types: /// -/// * `#[rpc(tx=SomeType)]`: Specify the transmitter/sender channel type (required) -/// * `#[rpc(tx=SomeType, rx=OtherType)]`: Also specify a receiver channel type (optional) +/// * `#[rpc(reply=SomeType)]`: Specify the transmitter/sender channel type (required) +/// * `#[rpc(reply=SomeType, updates=OtherType)]`: Also specify a receiver channel type (optional) /// -/// If `rx` is not specified, it defaults to `NoReceiver`. +/// If `request` is not specified, it defaults to `NoReceiver`. /// /// # Examples /// @@ -171,9 +171,9 @@ fn generate_type_aliases( /// ``` /// #[rpc_requests(ComputeService)] /// enum ComputeProtocol { -/// #[rpc(tx=oneshot::Sender)] +/// #[rpc(reply=oneshot::Sender)] /// Sqr(Sqr), -/// #[rpc(tx=oneshot::Sender)] +/// #[rpc(reply=oneshot::Sender)] /// Sum(Sum), /// } /// ``` @@ -182,9 +182,9 @@ fn generate_type_aliases( /// ``` /// #[rpc_requests(ComputeService, message = ComputeMessage)] /// enum ComputeProtocol { -/// #[rpc(tx=oneshot::Sender)] +/// #[rpc(reply=oneshot::Sender)] /// Sqr(Sqr), -/// #[rpc(tx=oneshot::Sender)] +/// #[rpc(reply=oneshot::Sender)] /// Sum(Sum), /// } /// ``` @@ -193,10 +193,10 @@ fn generate_type_aliases( /// ``` /// #[rpc_requests(ComputeService, alias = "Msg")] /// enum ComputeProtocol { -/// #[rpc(tx=oneshot::Sender)] -/// Sqr(Sqr), // Generates type SqrMsg = WithChannels -/// #[rpc(tx=oneshot::Sender)] -/// Sum(Sum), // Generates type SumMsg = WithChannels +/// #[rpc(reply=oneshot::Sender)] +/// Sqr(Sqr), // Generates type SqrMsg = Request +/// #[rpc(reply=oneshot::Sender)] +/// Sum(Sum), // Generates type SumMsg = Request /// } /// ``` #[proc_macro_attribute] @@ -302,7 +302,7 @@ pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream { .map(|(variant_name, inner_type)| { quote! { #[allow(missing_docs)] - #variant_name(::irpc::WithChannels<#inner_type, #service_name>) + #variant_name(::irpc::Request<#inner_type, #service_name>) } }) .collect::>(); @@ -349,7 +349,7 @@ pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream { // From implementations for the original enum #original_from_impls - // Type aliases for WithChannels + // Type aliases for Request #type_aliases // Extended enum and its implementations diff --git a/irpc-iroh/examples/auth.rs b/irpc-iroh/examples/auth.rs index 88944a7..39ea5a4 100644 --- a/irpc-iroh/examples/auth.rs +++ b/irpc-iroh/examples/auth.rs @@ -72,8 +72,8 @@ mod storage { Endpoint, }; use irpc::{ - channel::{oneshot, spsc}, - Client, Service, WithChannels, + channel::{mpsc, oneshot}, + Client, Request, Service, }; // Import the macro use irpc_derive::rpc_requests; @@ -116,15 +116,15 @@ mod storage { #[rpc_requests(StorageService, message = StorageMessage)] #[derive(Serialize, Deserialize)] enum StorageProtocol { - #[rpc(tx=oneshot::Sender>)] + #[rpc(reply=oneshot::Sender>)] Auth(Auth), - #[rpc(tx=oneshot::Sender>)] + #[rpc(reply=oneshot::Sender>)] Get(Get), - #[rpc(tx=oneshot::Sender<()>)] + #[rpc(reply=oneshot::Sender<()>)] Set(Set), - #[rpc(tx=oneshot::Sender, rx=spsc::Receiver<(String, String)>)] + #[rpc(reply=oneshot::Sender, updates=mpsc::Receiver<(String, String)>)] SetMany(SetMany), - #[rpc(tx=spsc::Sender)] + #[rpc(reply=mpsc::Sender)] List(List), } @@ -139,20 +139,20 @@ mod storage { let this = self.clone(); Box::pin(async move { let mut authed = false; - while let Some((msg, rx, tx)) = read_request(&conn).await? { - let msg_with_channels = upcast_message(msg, rx, tx); + while let Some((msg, request, reply)) = read_request(&conn).await? { + let msg_with_channels = upcast_message(msg, request, reply); match msg_with_channels { StorageMessage::Auth(msg) => { - let WithChannels { inner, tx, .. } = msg; + let Request { message, reply, .. } = msg; if authed { conn.close(1u32.into(), b"invalid message"); break; - } else if inner.token != this.auth_token { + } else if message.token != this.auth_token { conn.close(1u32.into(), b"permission denied"); break; } else { authed = true; - tx.send(Ok(())).await.ok(); + reply.send(Ok(())).await.ok(); } } _ => { @@ -171,13 +171,17 @@ mod storage { } } - fn upcast_message(msg: StorageProtocol, rx: RecvStream, tx: SendStream) -> StorageMessage { + fn upcast_message( + msg: StorageProtocol, + updates: RecvStream, + reply: SendStream, + ) -> StorageMessage { match msg { - StorageProtocol::Auth(msg) => WithChannels::from((msg, tx, rx)).into(), - StorageProtocol::Get(msg) => WithChannels::from((msg, tx, rx)).into(), - StorageProtocol::Set(msg) => WithChannels::from((msg, tx, rx)).into(), - StorageProtocol::SetMany(msg) => WithChannels::from((msg, tx, rx)).into(), - StorageProtocol::List(msg) => WithChannels::from((msg, tx, rx)).into(), + StorageProtocol::Auth(msg) => Request::from((msg, reply, updates)).into(), + StorageProtocol::Get(msg) => Request::from((msg, reply, updates)).into(), + StorageProtocol::Set(msg) => Request::from((msg, reply, updates)).into(), + StorageProtocol::SetMany(msg) => Request::from((msg, reply, updates)).into(), + StorageProtocol::List(msg) => Request::from((msg, reply, updates)).into(), } } @@ -196,29 +200,34 @@ mod storage { StorageMessage::Auth(_) => unreachable!("handled in ProtocolHandler::accept"), StorageMessage::Get(get) => { info!("get {:?}", get); - let WithChannels { tx, inner, .. } = get; - let res = self.state.lock().unwrap().get(&inner.key).cloned(); - tx.send(res).await.ok(); + let Request { reply, message, .. } = get; + let res = self.state.lock().unwrap().get(&message.key).cloned(); + reply.send(res).await.ok(); } StorageMessage::Set(set) => { info!("set {:?}", set); - let WithChannels { tx, inner, .. } = set; - self.state.lock().unwrap().insert(inner.key, inner.value); - tx.send(()).await.ok(); + let Request { reply, message, .. } = set; + self.state + .lock() + .unwrap() + .insert(message.key, message.value); + reply.send(()).await.ok(); } StorageMessage::SetMany(list) => { - let WithChannels { tx, mut rx, .. } = list; + let Request { + reply, mut updates, .. + } = list; let mut i = 0; - while let Ok(Some((key, value))) = rx.recv().await { + while let Ok(Some((key, value))) = updates.recv().await { let mut state = self.state.lock().unwrap(); state.insert(key, value); i += 1; } - tx.send(i).await.ok(); + reply.send(i).await.ok(); } StorageMessage::List(list) => { info!("list {:?}", list); - let WithChannels { mut tx, .. } = list; + let Request { reply, .. } = list; let values = { let state = self.state.lock().unwrap(); // TODO: use async lock to not clone here. @@ -229,7 +238,7 @@ mod storage { values }; for value in values { - if tx.send(value).await.is_err() { + if reply.send(value).await.is_err() { break; } } @@ -265,7 +274,7 @@ mod storage { self.inner.rpc(Get { key }).await } - pub async fn list(&self) -> Result, irpc::Error> { + pub async fn list(&self) -> Result, irpc::Error> { self.inner.server_streaming(List, 10).await } diff --git a/irpc-iroh/examples/derive.rs b/irpc-iroh/examples/derive.rs index f348654..1db185f 100644 --- a/irpc-iroh/examples/derive.rs +++ b/irpc-iroh/examples/derive.rs @@ -60,9 +60,9 @@ mod storage { use anyhow::{Context, Result}; use iroh::{protocol::ProtocolHandler, Endpoint}; use irpc::{ - channel::{oneshot, spsc}, + channel::{mpsc, oneshot}, rpc::Handler, - rpc_requests, Client, LocalSender, Service, WithChannels, + rpc_requests, Client, LocalSender, Request, Service, }; // Import the macro use irpc_iroh::{IrohProtocol, IrohRemoteConnection}; @@ -93,11 +93,11 @@ mod storage { #[rpc_requests(StorageService, message = StorageMessage)] #[derive(Serialize, Deserialize)] enum StorageProtocol { - #[rpc(tx=oneshot::Sender>)] + #[rpc(reply=oneshot::Sender>)] Get(Get), - #[rpc(tx=oneshot::Sender<()>)] + #[rpc(reply=oneshot::Sender<()>)] Set(Set), - #[rpc(tx=spsc::Sender)] + #[rpc(reply=mpsc::Sender)] List(List), } @@ -130,20 +130,20 @@ mod storage { match msg { StorageMessage::Get(get) => { info!("get {:?}", get); - let WithChannels { tx, inner, .. } = get; - tx.send(self.state.get(&inner.key).cloned()).await.ok(); + let Request { reply, message, .. } = get; + reply.send(self.state.get(&message.key).cloned()).await.ok(); } StorageMessage::Set(set) => { info!("set {:?}", set); - let WithChannels { tx, inner, .. } = set; - self.state.insert(inner.key, inner.value); - tx.send(()).await.ok(); + let Request { reply, message, .. } = set; + self.state.insert(message.key, message.value); + reply.send(()).await.ok(); } StorageMessage::List(list) => { info!("list {:?}", list); - let WithChannels { mut tx, .. } = list; + let Request { reply, .. } = list; for (key, value) in &self.state { - if tx.send(format!("{key}={value}")).await.is_err() { + if reply.send(format!("{key}={value}")).await.is_err() { break; } } @@ -175,12 +175,12 @@ mod storage { .inner .local() .context("can not listen on remote service")?; - let handler: Handler = Arc::new(move |msg, _rx, tx| { + let handler: Handler = Arc::new(move |msg, _updates, reply| { let local = local.clone(); Box::pin(match msg { - StorageProtocol::Get(msg) => local.send((msg, tx)), - StorageProtocol::Set(msg) => local.send((msg, tx)), - StorageProtocol::List(msg) => local.send((msg, tx)), + StorageProtocol::Get(msg) => local.send((msg, reply)), + StorageProtocol::Set(msg) => local.send((msg, reply)), + StorageProtocol::List(msg) => local.send((msg, reply)), }) }); Ok(IrohProtocol::new(handler)) @@ -190,7 +190,7 @@ mod storage { self.inner.rpc(Get { key }).await } - pub async fn list(&self) -> irpc::Result> { + pub async fn list(&self) -> irpc::Result> { self.inner.server_streaming(List, 10).await } diff --git a/irpc-iroh/src/lib.rs b/irpc-iroh/src/lib.rs index 2d8d065..9c852db 100644 --- a/irpc-iroh/src/lib.rs +++ b/irpc-iroh/src/lib.rs @@ -128,10 +128,10 @@ pub async fn handle_connection( handler: Handler, ) -> io::Result<()> { loop { - let Some((msg, rx, tx)) = read_request(&connection).await? else { + let Some((msg, updates, reply)) = read_request(&connection).await? else { return Ok(()); }; - handler(msg, rx, tx).await?; + handler(msg, updates, reply).await?; } } @@ -166,9 +166,9 @@ pub async fn read_request( .map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?; let msg: R = postcard::from_bytes(&buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - let rx = recv; - let tx = send; - Ok(Some((msg, rx, tx))) + let updates = recv; + let reply = send; + Ok(Some((msg, updates, reply))) } /// Utility function to listen for incoming connections and handle them with the provided handler diff --git a/src/lib.rs b/src/lib.rs index 3d594f3..09257aa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,17 +20,17 @@ //! //! ## Interaction patterns //! -//! For each request, there can be a response and update channel. Each channel +//! For each request, there can be a reply and update channel. Each channel //! can be either oneshot, carry multiple messages, or be disabled. This enables //! the typical interaction patterns known from libraries like grpc: //! -//! - rpc: 1 request, 1 response -//! - server streaming: 1 request, multiple responses -//! - client streaming: multiple requests, 1 response -//! - bidi streaming: multiple requests, multiple responses +//! - rpc: 1 request, 1 reply +//! - server streaming: 1 request, multiple replys +//! - client streaming: multiple requests, 1 reply +//! - bidi streaming: multiple requests, multiple replys //! //! as well as more complex patterns. It is however not possible to have multiple -//! differently typed tx channels for a single message type. +//! differently typed reply channels for a single message type. //! //! ## Transports //! @@ -111,7 +111,7 @@ impl RpcMessage for T where /// This is usually implemented by a zero-sized struct. /// It has various bounds to make derives easier. /// -/// A service acts as a scope for defining the tx and rx channels for each +/// A service acts as a scope for defining the reply and request channels for each /// message type, and provides some type safety when sending messages. pub trait Service: Send + Sync + Debug + Clone + 'static {} @@ -127,12 +127,12 @@ pub trait Receiver: Debug + Sealed {} /// Trait to specify channels for a message and service pub trait Channels { - /// The sender type, can be either spsc, oneshot or none - type Tx: Sender; - /// The receiver type, can be either spsc, oneshot or none + /// The sender type, can be either mpsc, oneshot or none + type Reply: Sender; + /// The receiver type, can be either mpsc, oneshot or none /// /// For many services, the receiver is not needed, so it can be set to [`NoReceiver`]. - type Rx: Receiver; + type Updates: Receiver; } /// Channels that abstract over local or remote sending @@ -152,8 +152,8 @@ pub mod channel { /// /// This is currently using a tokio channel pair internally. pub fn channel() -> (Sender, Receiver) { - let (tx, rx) = tokio::sync::oneshot::channel(); - (tx.into(), rx.into()) + let (reply, request) = tokio::sync::oneshot::channel(); + (reply.into(), request.into()) } /// A generic boxed sender. @@ -206,8 +206,8 @@ pub mod channel { } impl From> for Sender { - fn from(tx: tokio::sync::oneshot::Sender) -> Self { - Self::Tokio(tx) + fn from(reply: tokio::sync::oneshot::Sender) -> Self { + Self::Tokio(reply) } } @@ -216,7 +216,7 @@ pub mod channel { fn try_from(value: Sender) -> Result { match value { - Sender::Tokio(tx) => Ok(tx), + Sender::Tokio(reply) => Ok(reply), Sender::Boxed(_) => Err(value), } } @@ -229,7 +229,9 @@ pub mod channel { /// Local senders will never yield, but can fail if the receiver has been closed. pub async fn send(self, value: T) -> std::result::Result<(), SendError> { match self { - Sender::Tokio(tx) => tx.send(value).map_err(|_| SendError::ReceiverClosed), + Sender::Tokio(reply) => { + reply.send(value).map_err(|_| SendError::ReceiverClosed) + } Sender::Boxed(f) => f(value).await.map_err(SendError::from), } } @@ -265,16 +267,18 @@ pub mod channel { fn poll(self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll { match self.get_mut() { - Self::Tokio(rx) => Pin::new(rx).poll(cx).map_err(|_| RecvError::SenderClosed), - Self::Boxed(rx) => Pin::new(rx).poll(cx).map_err(RecvError::Io), + Self::Tokio(request) => Pin::new(request) + .poll(cx) + .map_err(|_| RecvError::SenderClosed), + Self::Boxed(request) => Pin::new(request).poll(cx).map_err(RecvError::Io), } } } /// Convert a tokio oneshot receiver to a receiver for this crate impl From> for Receiver { - fn from(rx: tokio::sync::oneshot::Receiver) -> Self { - Self::Tokio(FusedOneshotReceiver(rx)) + fn from(request: tokio::sync::oneshot::Receiver) -> Self { + Self::Tokio(FusedOneshotReceiver(request)) } } @@ -283,7 +287,7 @@ pub mod channel { fn try_from(value: Receiver) -> Result { match value { - Receiver::Tokio(tx) => Ok(tx.0), + Receiver::Tokio(reply) => Ok(reply.0), Receiver::Boxed(_) => Err(value), } } @@ -315,32 +319,28 @@ pub mod channel { /// SPSC channel, similar to tokio's mpsc channel /// - /// For the rpc case, the send side can not be cloned, hence spsc instead of mpsc. - pub mod spsc { - use std::{fmt::Debug, future::Future, io, pin::Pin}; + /// For the rpc case, the send side can not be cloned, hence mpsc instead of mpsc. + pub mod mpsc { + use std::{fmt::Debug, future::Future, io, pin::Pin, sync::Arc}; use super::{RecvError, SendError}; use crate::RpcMessage; - /// Create a local spsc sender and receiver pair, with the given buffer size. + /// Create a local mpsc sender and receiver pair, with the given buffer size. /// /// This is currently using a tokio channel pair internally. pub fn channel(buffer: usize) -> (Sender, Receiver) { - let (tx, rx) = tokio::sync::mpsc::channel(buffer); - (tx.into(), rx.into()) + let (reply, request) = tokio::sync::mpsc::channel(buffer); + (reply.into(), request.into()) } /// Single producer, single consumer sender. /// - /// For the local case, this wraps a tokio::sync::mpsc::Sender. However, - /// due to the fact that a stream to a remote service can not be cloned, - /// this can also not be cloned. - /// - /// This forces you to use senders in a linear way, passing out references - /// to the sender to other tasks instead of cloning it. + /// For the local case, this wraps a tokio::sync::mpsc::Sender. + #[derive(Clone)] pub enum Sender { Tokio(tokio::sync::mpsc::Sender), - Boxed(Box>), + Boxed(Arc>), } impl Sender { @@ -354,12 +354,12 @@ pub mod channel { } } - pub async fn closed(&mut self) + pub async fn closed(&self) where T: RpcMessage, { match self { - Sender::Tokio(tx) => tx.closed().await, + Sender::Tokio(reply) => reply.closed().await, Sender::Boxed(sink) => sink.closed().await, } } @@ -369,7 +369,7 @@ pub mod channel { where T: RpcMessage, { - futures_util::sink::unfold(self, |mut sink, value| async move { + futures_util::sink::unfold(self, |sink, value| async move { sink.send(value).await?; Ok(sink) }) @@ -377,8 +377,8 @@ pub mod channel { } impl From> for Sender { - fn from(tx: tokio::sync::mpsc::Sender) -> Self { - Self::Tokio(tx) + fn from(reply: tokio::sync::mpsc::Sender) -> Self { + Self::Tokio(reply) } } @@ -387,22 +387,22 @@ pub mod channel { fn try_from(value: Sender) -> Result { match value { - Sender::Tokio(tx) => Ok(tx), + Sender::Tokio(reply) => Ok(reply), Sender::Boxed(_) => Err(value), } } } - /// A sender that can be wrapped in a `Box>`. + /// A sender that can be wrapped in a `Arc>`. pub trait DynSender: Debug + Send + Sync + 'static { /// Send a message. /// /// For the remote case, if the message can not be completely sent, /// this must return an error and disable the channel. fn send( - &mut self, + &self, value: T, - ) -> Pin> + Send + '_>>; + ) -> Pin> + Send + Sync + '_>>; /// Try to send a message, returning as fast as possible if sending /// is not currently possible. @@ -410,12 +410,12 @@ pub mod channel { /// For the remote case, it must be guaranteed that the message is /// either completely sent or not at all. fn try_send( - &mut self, + &self, value: T, - ) -> Pin> + Send + '_>>; + ) -> Pin> + Send + Sync + '_>>; /// Await the sender close - fn closed(&mut self) -> Pin + Send + '_>>; + fn closed(&self) -> Pin + Send + Sync + '_>>; /// True if this is a remote sender fn is_rpc(&self) -> bool; @@ -425,7 +425,14 @@ pub mod channel { pub trait DynReceiver: Debug + Send + Sync + 'static { fn recv( &mut self, - ) -> Pin, RecvError>> + Send + '_>>; + ) -> Pin< + Box< + dyn Future, RecvError>> + + Send + + Sync + + '_, + >, + >; } impl Debug for Sender { @@ -436,18 +443,26 @@ pub mod channel { .field("avail", &x.capacity()) .field("cap", &x.max_capacity()) .finish(), - Self::Boxed(inner) => f.debug_tuple("Boxed").field(&inner).finish(), + Self::Boxed(message) => f.debug_tuple("Boxed").field(&message).finish(), } } } impl Sender { /// Send a message and yield until either it is sent or an error occurs. - pub async fn send(&mut self, value: T) -> std::result::Result<(), SendError> { + /// + /// ## Cancellation safety + /// + /// If the future is dropped before completion, and if this is a remote sender, + /// then the sender will be closed and further sends will return an [`io::Error`] + /// with [`io::ErrorKind::BrokenPipe`]. Therefore, make sure to always poll the + /// future until completion if you want to reuse the sender or any clone afterwards. + pub async fn send(&self, value: T) -> std::result::Result<(), SendError> { match self { - Sender::Tokio(tx) => { - tx.send(value).await.map_err(|_| SendError::ReceiverClosed) - } + Sender::Tokio(reply) => reply + .send(value) + .await + .map_err(|_| SendError::ReceiverClosed), Sender::Boxed(sink) => sink.send(value).await.map_err(SendError::from), } } @@ -466,9 +481,16 @@ pub mod channel { /// all. /// /// Returns true if the message was sent. + /// + /// ## Cancellation safety + /// + /// If the future is dropped before completion, and if this is a remote sender, + /// then the sender will be closed and further sends will return an [`io::Error`] + /// with [`io::ErrorKind::BrokenPipe`]. Therefore, make sure to always poll the + /// future until completion if you want to reuse the sender or any clone afterwards. pub async fn try_send(&mut self, value: T) -> std::result::Result { match self { - Sender::Tokio(tx) => match tx.try_send(value) { + Sender::Tokio(reply) => match reply.try_send(value) { Ok(()) => Ok(true), Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => { Err(SendError::ReceiverClosed) @@ -497,15 +519,15 @@ pub mod channel { /// Returns an an io error if there was an error receiving the message. pub async fn recv(&mut self) -> std::result::Result, RecvError> { match self { - Self::Tokio(rx) => Ok(rx.recv().await), - Self::Boxed(rx) => Ok(rx.recv().await?), + Self::Tokio(request) => Ok(request.recv().await), + Self::Boxed(request) => Ok(request.recv().await?), } } #[cfg(feature = "stream")] pub fn into_stream( self, - ) -> impl n0_future::Stream> + Send + 'static + ) -> impl n0_future::Stream> + Send + Sync + 'static { n0_future::stream::unfold(self, |mut recv| async move { recv.recv().await.transpose().map(|msg| (msg, recv)) @@ -514,8 +536,8 @@ pub mod channel { } impl From> for Receiver { - fn from(rx: tokio::sync::mpsc::Receiver) -> Self { - Self::Tokio(rx) + fn from(request: tokio::sync::mpsc::Receiver) -> Self { + Self::Tokio(request) } } @@ -524,7 +546,7 @@ pub mod channel { fn try_from(value: Receiver) -> Result { match value { - Receiver::Tokio(tx) => Ok(tx), + Receiver::Tokio(reply) => Ok(reply), Receiver::Boxed(_) => Err(value), } } @@ -533,12 +555,12 @@ pub mod channel { impl Debug for Receiver { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Tokio(inner) => f + Self::Tokio(message) => f .debug_struct("Tokio") - .field("avail", &inner.capacity()) - .field("cap", &inner.max_capacity()) + .field("avail", &message.capacity()) + .field("cap", &message.max_capacity()) .finish(), - Self::Boxed(inner) => f.debug_tuple("Boxed").field(&inner).finish(), + Self::Boxed(message) => f.debug_tuple("Boxed").field(&message).finish(), } } } @@ -565,7 +587,7 @@ pub mod channel { impl crate::Receiver for NoReceiver {} } - /// Error when sending a oneshot or spsc message. For local communication, + /// Error when sending a oneshot or mpsc message. For local communication, /// the only thing that can go wrong is that the receiver has been dropped. /// /// For rpc communication, there can be any number of errors, so this is a @@ -591,7 +613,7 @@ pub mod channel { } } - /// Error when receiving a oneshot or spsc message. For local communication, + /// Error when receiving a oneshot or mpsc message. For local communication, /// the only thing that can go wrong is that the sender has been closed. /// /// For rpc communication, there can be any number of errors, so this is a @@ -622,35 +644,35 @@ pub mod channel { /// This expands the protocol message to a full message that includes the /// active and unserializable channels. /// -/// The channel kind for rx and tx is defined by implementing the `Channels` +/// The channel kind for request and reply is defined by implementing the `Channels` /// trait, either manually or using a macro. /// /// When the `message_spans` feature is enabled, this also includes a tracing /// span to carry the tracing context during message passing. -pub struct WithChannels, S: Service> { - /// The inner message. - pub inner: I, - /// The return channel to send the response to. Can be set to [`crate::channel::none::NoSender`] if not needed. - pub tx: >::Tx, +pub struct Request, S: Service> { + /// The request message. + pub message: I, + /// The return channel to send the reply to. Can be set to [`crate::channel::none::NoSender`] if not needed. + pub reply: >::Reply, /// The request channel to receive the request from. Can be set to [`NoReceiver`] if not needed. - pub rx: >::Rx, + pub updates: >::Updates, /// The current span where the full message was created. #[cfg(feature = "message_spans")] #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "message_spans")))] pub span: tracing::Span, } -impl + Debug, S: Service> Debug for WithChannels { +impl + Debug, S: Service> Debug for Request { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_tuple("") - .field(&self.inner) - .field(&self.tx) - .field(&self.rx) + .field(&self.message) + .field(&self.reply) + .field(&self.updates) .finish() } } -impl, S: Service> WithChannels { +impl, S: Service> Request { /// Get the parent span #[cfg(feature = "message_spans")] pub fn parent_span_opt(&self) -> Option<&tracing::Span> { @@ -658,21 +680,21 @@ impl, S: Service> WithChannels { } } -/// Tuple conversion from inner message and tx/rx channels to a WithChannels struct +/// Tuple conversion from message message and reply/request channels to a Request struct /// -/// For the case where you want both tx and rx channels. -impl, S: Service, Tx, Rx> From<(I, Tx, Rx)> for WithChannels +/// For the case where you want both reply and request channels. +impl, S: Service, Reply, Updates> From<(I, Reply, Updates)> for Request where I: Channels, - >::Tx: From, - >::Rx: From, + >::Reply: From, + >::Updates: From, { - fn from(inner: (I, Tx, Rx)) -> Self { - let (inner, tx, rx) = inner; + fn from(message: (I, Reply, Updates)) -> Self { + let (message, reply, updates) = message; Self { - inner, - tx: tx.into(), - rx: rx.into(), + message, + reply: reply.into(), + updates: updates.into(), #[cfg(feature = "message_spans")] #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "message_spans")))] span: tracing::Span::current(), @@ -680,21 +702,21 @@ where } } -/// Tuple conversion from inner message and tx channel to a WithChannels struct +/// Tuple conversion from message message and reply channel to a Request struct /// -/// For the very common case where you just need a tx channel to send the response to. -impl From<(I, Tx)> for WithChannels +/// For the very common case where you just need a reply channel to send the reply to. +impl From<(I, Reply)> for Request where - I: Channels, + I: Channels, S: Service, - >::Tx: From, + >::Reply: From, { - fn from(inner: (I, Tx)) -> Self { - let (inner, tx) = inner; + fn from(message: (I, Reply)) -> Self { + let (message, reply) = message; Self { - inner, - tx: tx.into(), - rx: NoReceiver, + message, + reply: reply.into(), + updates: NoReceiver, #[cfg(feature = "message_spans")] #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "message_spans")))] span: tracing::Span::current(), @@ -702,15 +724,15 @@ where } } -/// Deref so you can access the inner fields directly. +/// Deref so you can access the message fields directly. /// -/// If the inner message has fields named `tx`, `rx` or `span`, you need to use the -/// `inner` field to access them. -impl, S: Service> Deref for WithChannels { +/// If the message message has fields named `reply`, `request` or `span`, you need to use the +/// `message` field to access them. +impl, S: Service> Deref for Request { type Target = I; fn deref(&self) -> &Self::Target { - &self.inner + &self.message } } @@ -721,8 +743,8 @@ impl, S: Service> Deref for WithChannels { /// type. It can be thought of as the definition of the protocol. /// /// `M` is typically an enum with a case for each possible message type, where -/// each case is a `WithChannels` struct that extends the inner protocol message -/// with a local tx and rx channel as well as a tracing span to allow for +/// each case is a `Request` struct that extends the message protocol message +/// with a local reply and request channel as well as a tracing span to allow for /// keeping tracing context across async boundaries. /// /// In some cases, `M` and `R` can be enums for a subset of the protocol. E.g. @@ -740,14 +762,14 @@ impl Clone for Client { } impl From> for Client { - fn from(tx: LocalSender) -> Self { - Self(ClientInner::Local(tx.0), PhantomData) + fn from(reply: LocalSender) -> Self { + Self(ClientInner::Local(reply.0), PhantomData) } } impl From> for Client { - fn from(tx: tokio::sync::mpsc::Sender) -> Self { - LocalSender::from(tx).into() + fn from(reply: tokio::sync::mpsc::Sender) -> Self { + LocalSender::from(reply).into() } } @@ -771,7 +793,7 @@ impl Client { /// requests. pub fn local(&self) -> Option> { match &self.0 { - ClientInner::Local(tx) => Some(tx.clone().into()), + ClientInner::Local(reply) => Some(reply.clone().into()), ClientInner::Remote(..) => None, } } @@ -791,7 +813,10 @@ impl Client { pub fn request( &self, ) -> impl Future< - Output = result::Result, rpc::RemoteSender>, RequestError>, + Output = result::Result< + RequestSender, rpc::RemoteSender>, + RequestError, + >, > + 'static where S: Service, @@ -801,26 +826,26 @@ impl Client { #[cfg(feature = "rpc")] { let cloned = match &self.0 { - ClientInner::Local(tx) => Request::Local(tx.clone()), - ClientInner::Remote(connection) => Request::Remote(connection.clone_boxed()), + ClientInner::Local(reply) => RequestSender::Local(reply.clone()), + ClientInner::Remote(connection) => RequestSender::Remote(connection.clone_boxed()), }; async move { match cloned { - Request::Local(tx) => Ok(Request::Local(tx.into())), - Request::Remote(conn) => { + RequestSender::Local(reply) => Ok(RequestSender::Local(reply.into())), + RequestSender::Remote(conn) => { let (send, recv) = conn.open_bi().await?; - Ok(Request::Remote(rpc::RemoteSender::new(send, recv))) + Ok(RequestSender::Remote(rpc::RemoteSender::new(send, recv))) } } } } #[cfg(not(feature = "rpc"))] { - let ClientInner::Local(tx) = &self.0 else { + let ClientInner::Local(reply) = &self.0 else { unreachable!() }; - let tx = tx.clone().into(); - async move { Ok(Request::Local(tx)) } + let reply = reply.clone().into(); + async move { Ok(RequestSender::Local(reply)) } } } @@ -828,25 +853,27 @@ impl Client { pub fn rpc(&self, msg: Req) -> impl Future> + Send + 'static where S: Service, - M: From> + Send + Sync + Unpin + 'static, + M: From> + Send + Sync + Unpin + 'static, R: From + Serialize + Send + Sync + 'static, - Req: Channels, Rx = NoReceiver> + Send + 'static, + Req: Channels, Updates = NoReceiver> + + Send + + 'static, Res: RpcMessage, { let request = self.request(); async move { let recv: channel::oneshot::Receiver = match request.await? { - Request::Local(request) => { - let (tx, rx) = channel::oneshot::channel(); - request.send((msg, tx)).await?; - rx + RequestSender::Local(tx) => { + let (reply, request) = channel::oneshot::channel(); + tx.send((msg, reply)).await?; + request } #[cfg(not(feature = "rpc"))] - Request::Remote(_request) => unreachable!(), + RequestSender::Remote(_request) => unreachable!(), #[cfg(feature = "rpc")] - Request::Remote(request) => { - let (_tx, rx) = request.write(msg).await?; - rx.into() + RequestSender::Remote(tx) => { + let (_reply, request) = tx.write(msg).await?; + request.into() } }; let res = recv.await?; @@ -854,33 +881,33 @@ impl Client { } } - /// Performs a request for which the server returns a spsc receiver. + /// Performs a request for which the server returns a mpsc receiver. pub fn server_streaming( &self, msg: Req, - local_response_cap: usize, - ) -> impl Future>> + Send + 'static + local_reply_cap: usize, + ) -> impl Future>> + Send + 'static where S: Service, - M: From> + Send + Sync + Unpin + 'static, + M: From> + Send + Sync + Unpin + 'static, R: From + Serialize + Send + Sync + 'static, - Req: Channels, Rx = NoReceiver> + Send + 'static, + Req: Channels, Updates = NoReceiver> + Send + 'static, Res: RpcMessage, { let request = self.request(); async move { - let recv: channel::spsc::Receiver = match request.await? { - Request::Local(request) => { - let (tx, rx) = channel::spsc::channel(local_response_cap); - request.send((msg, tx)).await?; - rx + let recv: channel::mpsc::Receiver = match request.await? { + RequestSender::Local(tx) => { + let (reply, request) = channel::mpsc::channel(local_reply_cap); + tx.send((msg, reply)).await?; + request } #[cfg(not(feature = "rpc"))] - Request::Remote(_request) => unreachable!(), + RequestSender::Remote(_request) => unreachable!(), #[cfg(feature = "rpc")] - Request::Remote(request) => { - let (_tx, rx) = request.write(msg).await?; - rx.into() + RequestSender::Remote(tx) => { + let (_reply, request) = tx.write(msg).await?; + request.into() } }; Ok(recv) @@ -894,80 +921,89 @@ impl Client { local_update_cap: usize, ) -> impl Future< Output = Result<( - channel::spsc::Sender, + channel::mpsc::Sender, channel::oneshot::Receiver, )>, > where S: Service, - M: From> + Send + Sync + Unpin + 'static, + M: From> + Send + Sync + Unpin + 'static, R: From + Serialize + 'static, - Req: Channels, Rx = channel::spsc::Receiver>, + Req: Channels< + S, + Reply = channel::oneshot::Sender, + Updates = channel::mpsc::Receiver, + >, Update: RpcMessage, Res: RpcMessage, { let request = self.request(); async move { - let (update_tx, res_rx): ( - channel::spsc::Sender, + let (update_reply, res_request): ( + channel::mpsc::Sender, channel::oneshot::Receiver, ) = match request.await? { - Request::Local(request) => { - let (req_tx, req_rx) = channel::spsc::channel(local_update_cap); - let (res_tx, res_rx) = channel::oneshot::channel(); - request.send((msg, res_tx, req_rx)).await?; - (req_tx, res_rx) + RequestSender::Local(request) => { + let (req_reply, req_request) = channel::mpsc::channel(local_update_cap); + let (res_reply, res_request) = channel::oneshot::channel(); + request.send((msg, res_reply, req_request)).await?; + (req_reply, res_request) } #[cfg(not(feature = "rpc"))] - Request::Remote(_request) => unreachable!(), + RequestSender::Remote(_request) => unreachable!(), #[cfg(feature = "rpc")] - Request::Remote(request) => { - let (tx, rx) = request.write(msg).await?; - (tx.into(), rx.into()) + RequestSender::Remote(request) => { + let (reply, request) = request.write(msg).await?; + (reply.into(), request.into()) } }; - Ok((update_tx, res_rx)) + Ok((update_reply, res_request)) } } - /// Performs a request for which the client can send updates, and the server returns a spsc receiver. + /// Performs a request for which the client can send updates, and the server returns a mpsc receiver. pub fn bidi_streaming( &self, msg: Req, local_update_cap: usize, - local_response_cap: usize, - ) -> impl Future, channel::spsc::Receiver)>> + local_reply_cap: usize, + ) -> impl Future, channel::mpsc::Receiver)>> + Send + 'static where S: Service, - M: From> + Send + Sync + Unpin + 'static, + M: From> + Send + Sync + Unpin + 'static, R: From + Serialize + Send + 'static, - Req: Channels, Rx = channel::spsc::Receiver> - + Send + Req: Channels< + S, + Reply = channel::mpsc::Sender, + Updates = channel::mpsc::Receiver, + > + Send + 'static, Update: RpcMessage, Res: RpcMessage, { let request = self.request(); async move { - let (update_tx, res_rx): (channel::spsc::Sender, channel::spsc::Receiver) = - match request.await? { - Request::Local(request) => { - let (update_tx, update_rx) = channel::spsc::channel(local_update_cap); - let (res_tx, res_rx) = channel::spsc::channel(local_response_cap); - request.send((msg, res_tx, update_rx)).await?; - (update_tx, res_rx) - } - #[cfg(not(feature = "rpc"))] - Request::Remote(_request) => unreachable!(), - #[cfg(feature = "rpc")] - Request::Remote(request) => { - let (tx, rx) = request.write(msg).await?; - (tx.into(), rx.into()) - } - }; - Ok((update_tx, res_rx)) + let (update_reply, res_request): ( + channel::mpsc::Sender, + channel::mpsc::Receiver, + ) = match request.await? { + RequestSender::Local(request) => { + let (update_reply, update_request) = channel::mpsc::channel(local_update_cap); + let (res_reply, res_request) = channel::mpsc::channel(local_reply_cap); + request.send((msg, res_reply, update_request)).await?; + (update_reply, res_request) + } + #[cfg(not(feature = "rpc"))] + RequestSender::Remote(_request) => unreachable!(), + #[cfg(feature = "rpc")] + RequestSender::Remote(request) => { + let (reply, request) = request.write(msg).await?; + (reply.into(), request.into()) + } + }; + Ok((update_reply, res_request)) } } } @@ -987,7 +1023,7 @@ pub(crate) enum ClientInner { impl Clone for ClientInner { fn clone(&self) -> Self { match self { - Self::Local(tx) => Self::Local(tx.clone()), + Self::Local(reply) => Self::Local(reply.clone()), #[cfg(feature = "rpc")] Self::Remote(conn) => Self::Remote(conn.clone_boxed()), #[cfg(not(feature = "rpc"))] @@ -1063,7 +1099,7 @@ impl From for io::Error { /// /// This is a wrapper around an in-memory channel (currently [`tokio::sync::mpsc::Sender`]), /// that adds nice syntax for sending messages that can be converted into -/// [`WithChannels`]. +/// [`Request`]. #[derive(Debug)] #[repr(transparent)] pub struct LocalSender(tokio::sync::mpsc::Sender, std::marker::PhantomData); @@ -1075,8 +1111,8 @@ impl Clone for LocalSender { } impl From> for LocalSender { - fn from(tx: tokio::sync::mpsc::Sender) -> Self { - Self(tx, PhantomData) + fn from(reply: tokio::sync::mpsc::Sender) -> Self { + Self(reply, PhantomData) } } @@ -1089,7 +1125,9 @@ pub mod rpc { #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))] pub mod rpc { //! Module for cross-process RPC using [`quinn`]. - use std::{fmt::Debug, future::Future, io, marker::PhantomData, pin::Pin, sync::Arc}; + use std::{ + fmt::Debug, future::Future, io, marker::PhantomData, ops::DerefMut, pin::Pin, sync::Arc, + }; use n0_future::{future::Boxed as BoxFuture, task::JoinSet}; use quinn::ConnectionError; @@ -1099,10 +1137,9 @@ pub mod rpc { use crate::{ channel::{ + mpsc::{self, DynReceiver, DynSender}, none::NoSender, - oneshot, - spsc::{self, DynReceiver, DynSender}, - RecvError, SendError, + oneshot, RecvError, SendError, }, util::{now_or_never, AsyncReadVarintExt, WriteVarintExt}, RequestError, RpcMessage, @@ -1271,9 +1308,9 @@ pub mod rpc { } } - impl From for spsc::Receiver { + impl From for mpsc::Receiver { fn from(read: quinn::RecvStream) -> Self { - spsc::Receiver::Boxed(Box::new(QuinnReceiver { + mpsc::Receiver::Boxed(Box::new(QuinnReceiver { recv: read, _marker: PhantomData, })) @@ -1301,13 +1338,15 @@ pub mod rpc { } } - impl From for spsc::Sender { + impl From for mpsc::Sender { fn from(write: quinn::SendStream) -> Self { - spsc::Sender::Boxed(Box::new(QuinnSender { - send: write, - buffer: SmallVec::new(), - _marker: PhantomData, - })) + mpsc::Sender::Boxed(Arc::new(QuinnSender(tokio::sync::Mutex::new( + QuinnSenderState::Open(QuinnSenderInner { + send: write, + buffer: SmallVec::new(), + _marker: PhantomData, + }), + )))) } } @@ -1325,8 +1364,9 @@ pub mod rpc { impl DynReceiver for QuinnReceiver { fn recv( &mut self, - ) -> Pin, RecvError>> + Send + '_>> - { + ) -> Pin< + Box, RecvError>> + Send + Sync + '_>, + > { Box::pin(async { let read = &mut self.recv; let Some(size) = read.read_varint_u64().await? else { @@ -1347,20 +1387,17 @@ pub mod rpc { fn drop(&mut self) {} } - struct QuinnSender { + struct QuinnSenderInner { send: quinn::SendStream, buffer: SmallVec<[u8; 128]>, _marker: std::marker::PhantomData, } - impl Debug for QuinnSender { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("QuinnSender").finish() - } - } - - impl DynSender for QuinnSender { - fn send(&mut self, value: T) -> Pin> + Send + '_>> { + impl QuinnSenderInner { + fn send( + &mut self, + value: T, + ) -> Pin> + Send + Sync + '_>> { Box::pin(async { let value = value; self.buffer.clear(); @@ -1374,7 +1411,7 @@ pub mod rpc { fn try_send( &mut self, value: T, - ) -> Pin> + Send + '_>> { + ) -> Pin> + Send + Sync + '_>> { Box::pin(async { // todo: move the non-async part out of the box. Will require a new return type. let value = value; @@ -1390,20 +1427,81 @@ pub mod rpc { }) } - fn closed(&mut self) -> Pin + Send + '_>> { + fn closed(&mut self) -> Pin + Send + Sync + '_>> { Box::pin(async move { self.send.stopped().await.ok(); }) } + } - fn is_rpc(&self) -> bool { - true + #[derive(Default)] + enum QuinnSenderState { + Open(QuinnSenderInner), + #[default] + Closed, + } + + struct QuinnSender(tokio::sync::Mutex>); + + impl Debug for QuinnSender { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("QuinnSender").finish() } } - impl Drop for QuinnSender { - fn drop(&mut self) { - self.send.finish().ok(); + impl DynSender for QuinnSender { + fn send( + &self, + value: T, + ) -> Pin> + Send + Sync + '_>> { + Box::pin(async { + let mut guard = self.0.lock().await; + let sender = std::mem::take(guard.deref_mut()); + match sender { + QuinnSenderState::Open(mut sender) => { + let res = sender.send(value).await; + if res.is_ok() { + *guard = QuinnSenderState::Open(sender); + } + res + } + QuinnSenderState::Closed => Err(io::ErrorKind::BrokenPipe.into()), + } + }) + } + + fn try_send( + &self, + value: T, + ) -> Pin> + Send + Sync + '_>> { + Box::pin(async { + let mut guard = self.0.lock().await; + let sender = std::mem::take(guard.deref_mut()); + match sender { + QuinnSenderState::Open(mut sender) => { + let res = sender.try_send(value).await; + if res.is_ok() { + *guard = QuinnSenderState::Open(sender); + } + res + } + QuinnSenderState::Closed => Err(io::ErrorKind::BrokenPipe.into()), + } + }) + } + + fn closed(&self) -> Pin + Send + Sync + '_>> { + Box::pin(async { + let mut guard = self.0.lock().await; + match guard.deref_mut() { + QuinnSenderState::Open(sender) => sender.closed().await, + QuinnSenderState::Closed => {} + } + }) + } + + fn is_rpc(&self) -> bool { + true } } @@ -1459,9 +1557,9 @@ pub mod rpc { .map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?; let msg: R = postcard::from_bytes(&buf) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - let rx = recv; - let tx = send; - handler(msg, rx, tx).await?; + let request = recv; + let reply = send; + handler(msg, request, reply).await?; } }; let span = trace_span!("rpc", id = request_id); @@ -1473,7 +1571,7 @@ pub mod rpc { /// A request to a service. This can be either local or remote. #[derive(Debug)] -pub enum Request { +pub enum RequestSender { /// Local in memory request Local(L), /// Remote cross process request @@ -1482,10 +1580,10 @@ pub enum Request { impl LocalSender { /// Send a message to the service - pub fn send(&self, value: impl Into>) -> SendFut + pub fn send(&self, value: impl Into>) -> SendFut where T: Channels, - M: From>, + M: From>, { let value: M = value.into().into(); SendFut::new(self.0.clone(), value) diff --git a/tests/compile_fail/extra_attr_types.rs b/tests/compile_fail/extra_attr_types.rs index d6c288f..d5eea11 100644 --- a/tests/compile_fail/extra_attr_types.rs +++ b/tests/compile_fail/extra_attr_types.rs @@ -2,8 +2,8 @@ use irpc::rpc_requests; #[rpc_requests(Service, Msg)] enum Enum { - #[rpc(tx = NoSender, rx = NoReceiver, fnord = Foo)] + #[rpc(reply = NoSender, updates = NoReceiver, fnord = Foo)] A(u8), } -fn main() {} \ No newline at end of file +fn main() {} diff --git a/tests/compile_fail/extra_attr_types.stderr b/tests/compile_fail/extra_attr_types.stderr index c19048b..b117330 100644 --- a/tests/compile_fail/extra_attr_types.stderr +++ b/tests/compile_fail/extra_attr_types.stderr @@ -1,5 +1,5 @@ error: Unknown arguments provided: ["fnord"] --> tests/compile_fail/extra_attr_types.rs:5:5 | -5 | #[rpc(tx = NoSender, rx = NoReceiver, fnord = Foo)] +5 | #[rpc(reply = NoSender, updates = NoReceiver, fnord = Foo)] | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/compile_fail/wrong_attr_types.stderr b/tests/compile_fail/wrong_attr_types.stderr index f679dbd..fd1c062 100644 --- a/tests/compile_fail/wrong_attr_types.stderr +++ b/tests/compile_fail/wrong_attr_types.stderr @@ -1,4 +1,4 @@ -error: rpc requires a tx type +error: rpc requires a reply type --> tests/compile_fail/wrong_attr_types.rs:5:5 | 5 | #[rpc(fnord = Bla)] diff --git a/tests/derive.rs b/tests/derive.rs index 3e0122f..666cde8 100644 --- a/tests/derive.rs +++ b/tests/derive.rs @@ -39,13 +39,13 @@ fn derive_simple() { #[rpc_requests(Service, message = RequestWithChannels)] #[derive(Debug, Serialize, Deserialize)] enum Request { - #[rpc(tx=oneshot::Sender<()>)] + #[rpc(reply=oneshot::Sender<()>)] Rpc(RpcRequest), - #[rpc(tx=NoSender)] + #[rpc(reply=NoSender)] ServerStreaming(ServerStreamingRequest), - #[rpc(tx=NoSender)] + #[rpc(reply=NoSender)] BidiStreaming(BidiStreamingRequest), - #[rpc(tx=NoSender)] + #[rpc(reply=NoSender)] ClientStreaming(ClientStreamingRequest), } diff --git a/tests/mpsc_sender.rs b/tests/mpsc_sender.rs new file mode 100644 index 0000000..e8382bb --- /dev/null +++ b/tests/mpsc_sender.rs @@ -0,0 +1,117 @@ +use std::{ + io::ErrorKind, + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, + time::Duration, +}; + +use irpc::{ + channel::{mpsc, SendError}, + util::{make_client_endpoint, make_server_endpoint}, +}; +use quinn::Endpoint; +use testresult::TestResult; +use tokio::time::timeout; + +fn create_connected_endpoints() -> TestResult<(Endpoint, Endpoint, SocketAddr)> { + let addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into(); + let (server, cert) = make_server_endpoint(addr)?; + let client = make_client_endpoint(addr, &[cert.as_slice()])?; + let port = server.local_addr()?.port(); + let server_addr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, port).into(); + Ok((server, client, server_addr)) +} + +/// Checks that all clones of a `Sender` will get the closed signal as soon as +/// a send fails with an io error. +#[tokio::test] +async fn mpsc_sender_clone_closed_error() -> TestResult<()> { + tracing_subscriber::fmt::try_init().ok(); + let (server, client, server_addr) = create_connected_endpoints()?; + // accept a single bidi stream on a single connection, then immediately stop it + let server = tokio::spawn(async move { + let conn = server.accept().await.unwrap().await?; + let (_, mut recv) = conn.accept_bi().await?; + recv.stop(1u8.into())?; + TestResult::Ok(()) + }); + let conn = client.connect(server_addr, "localhost")?.await?; + let (send, _) = conn.open_bi().await?; + let send1 = mpsc::Sender::>::from(send); + let send2 = send1.clone(); + let send3 = send1.clone(); + let second_client = tokio::spawn(async move { + send2.closed().await; + }); + let third_client = tokio::spawn(async move { + // this should fail with an io error, since the stream was stopped + loop { + match send3.send(vec![1, 2, 3]).await { + Err(SendError::Io(e)) if e.kind() == ErrorKind::BrokenPipe => break, + _ => {} + }; + } + }); + // send until we get an error because the remote side stopped the stream + while send1.send(vec![1, 2, 3]).await.is_ok() {} + match send1.send(vec![4, 5, 6]).await { + Err(SendError::Io(e)) if e.kind() == ErrorKind::BrokenPipe => {} + e => panic!("Expected SendError::Io with kind BrokenPipe, got {:?}", e), + }; + // check that closed signal was received by the second sender + second_client.await?; + // check that the third sender will get the right kind of io error eventually + third_client.await?; + // server should finish without errors + server.await??; + Ok(()) +} + +/// Checks that all clones of a `Sender` will get the closed signal as soon as +/// a send future gets dropped before completing. +#[tokio::test] +async fn mpsc_sender_clone_drop_error() -> TestResult<()> { + let (server, client, server_addr) = create_connected_endpoints()?; + // accept a single bidi stream on a single connection, then read indefinitely + // until we get an error or the stream is finished + let server = tokio::spawn(async move { + let conn = server.accept().await.unwrap().await?; + let (_, mut recv) = conn.accept_bi().await?; + let mut buf = vec![0u8; 1024]; + while let Ok(Some(_)) = recv.read(&mut buf).await {} + TestResult::Ok(()) + }); + let conn = client.connect(server_addr, "localhost")?.await?; + let (send, _) = conn.open_bi().await?; + let send1 = mpsc::Sender::>::from(send); + let send2 = send1.clone(); + let send3 = send1.clone(); + let second_client = tokio::spawn(async move { + send2.closed().await; + }); + let third_client = tokio::spawn(async move { + // this should fail with an io error, since the stream was stopped + loop { + match send3.send(vec![1, 2, 3]).await { + Err(SendError::Io(e)) if e.kind() == ErrorKind::BrokenPipe => break, + _ => {} + }; + } + }); + // send a lot of data with a tiny timeout, this will cause the send future to be dropped + loop { + let send_future = send1.send(vec![0u8; 1024 * 1024]); + // not sure if there is a better way. I want to poll the future a few times so it has time to + // start sending, but don't want to give it enough time to complete. + // I don't think now_or_never would work, since it wouldn't have time to start sending + if timeout(Duration::from_micros(1), send_future) + .await + .is_err() + { + break; + } + } + server.await??; + second_client.await?; + third_client.await?; + Ok(()) +}