From f540b50da8553c0c70a2d04c9ffb6c746a680524 Mon Sep 17 00:00:00 2001 From: Frando Date: Fri, 6 Jun 2025 13:31:39 +0200 Subject: [PATCH 01/17] fix: add Sync bound to DynReceiver::recv so that receiver stream can be Sync --- src/lib.rs | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 3d594f3..1aa2ccb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -425,7 +425,14 @@ pub mod channel { pub trait DynReceiver: Debug + Send + Sync + 'static { fn recv( &mut self, - ) -> Pin, RecvError>> + Send + '_>>; + ) -> Pin< + Box< + dyn Future, RecvError>> + + Send + + Sync + + '_, + >, + >; } impl Debug for Sender { @@ -505,7 +512,7 @@ pub mod channel { #[cfg(feature = "stream")] pub fn into_stream( self, - ) -> impl n0_future::Stream> + Send + 'static + ) -> impl n0_future::Stream> + Send + Sync + 'static { n0_future::stream::unfold(self, |mut recv| async move { recv.recv().await.transpose().map(|msg| (msg, recv)) @@ -1325,8 +1332,9 @@ pub mod rpc { impl DynReceiver for QuinnReceiver { fn recv( &mut self, - ) -> Pin, RecvError>> + Send + '_>> - { + ) -> Pin< + Box, RecvError>> + Send + Sync + '_>, + > { Box::pin(async { let read = &mut self.recv; let Some(size) = read.read_varint_u64().await? else { From 0fd89a24ee3afa6d3b2c632f3130cfd8409ea104 Mon Sep 17 00:00:00 2001 From: Frando Date: Fri, 6 Jun 2025 14:26:39 +0200 Subject: [PATCH 02/17] refactor: make sender clone --- examples/compute.rs | 14 +++--- examples/derive.rs | 4 +- examples/storage.rs | 2 +- irpc-iroh/examples/derive.rs | 2 +- src/lib.rs | 84 +++++++++++++++++++++--------------- 5 files changed, 60 insertions(+), 46 deletions(-) diff --git a/examples/compute.rs b/examples/compute.rs index 0ec3446..95e4f45 100644 --- a/examples/compute.rs +++ b/examples/compute.rs @@ -123,7 +123,7 @@ impl ComputeActor { tx, inner, span, .. } = fib; let _entered = span.enter(); - let mut sender = tx; + let sender = tx; let mut a = 0u64; let mut b = 1u64; while a <= inner.max { @@ -144,7 +144,7 @@ impl ComputeActor { } = mult; let _entered = span.enter(); let mut receiver = rx; - let mut sender = tx; + let sender = tx; let multiplier = inner.initial; while let Some(num) = receiver.recv().await? { sender.send(multiplier * num).await?; @@ -260,7 +260,7 @@ async fn local() -> anyhow::Result<()> { println!("Local: 5^2 = {}", rx.await?); // Test Sum - let (mut tx, rx) = api.sum().await?; + let (tx, rx) = api.sum().await?; tx.send(1).await?; tx.send(2).await?; tx.send(3).await?; @@ -276,7 +276,7 @@ async fn local() -> anyhow::Result<()> { println!(); // Test Multiply - let (mut in_tx, mut out_rx) = api.multiply(3).await?; + let (in_tx, mut out_rx) = api.multiply(3).await?; in_tx.send(2).await?; in_tx.send(4).await?; in_tx.send(6).await?; @@ -311,7 +311,7 @@ async fn remote() -> anyhow::Result<()> { println!("Remote: 4^2 = {}", rx.await?); // Test Sum - let (mut tx, rx) = api.sum().await?; + let (tx, rx) = api.sum().await?; tx.send(4).await?; tx.send(5).await?; tx.send(6).await?; @@ -327,7 +327,7 @@ async fn remote() -> anyhow::Result<()> { println!(); // Test Multiply - let (mut in_tx, mut out_rx) = api.multiply(5).await?; + let (in_tx, mut out_rx) = api.multiply(5).await?; in_tx.send(1).await?; in_tx.send(2).await?; in_tx.send(3).await?; @@ -380,7 +380,7 @@ async fn bench(api: ComputeApi, n: u64) -> anyhow::Result<()> { // Sequential streaming (using Multiply instead of MultiplyUpdate) { let t0 = std::time::Instant::now(); - let (mut send, mut recv) = api.multiply(2).await?; + let (send, mut recv) = api.multiply(2).await?; let handle = tokio::task::spawn(async move { for i in 0..n { send.send(i).await?; diff --git a/examples/derive.rs b/examples/derive.rs index e03f39f..0f482e4 100644 --- a/examples/derive.rs +++ b/examples/derive.rs @@ -111,7 +111,7 @@ impl StorageActor { } StorageMessage::List(list) => { info!("list {:?}", list); - let WithChannels { mut tx, .. } = list; + let WithChannels { tx, .. } = list; for (key, value) in &self.state { if tx.send(format!("{key}={value}")).await.is_err() { break; @@ -172,7 +172,7 @@ async fn client_demo(api: StorageApi) -> Result<()> { let value = api.get("hello".to_string()).await?; println!("get: hello = {:?}", value); - let (mut tx, rx) = api.set_many().await?; + let (tx, rx) = api.set_many().await?; for i in 0..3 { tx.send((format!("key{i}"), format!("value{i}"))).await?; } diff --git a/examples/storage.rs b/examples/storage.rs index d73f29f..29b07b1 100644 --- a/examples/storage.rs +++ b/examples/storage.rs @@ -104,7 +104,7 @@ impl StorageActor { } StorageMessage::List(list) => { info!("list {:?}", list); - let WithChannels { mut tx, .. } = list; + let WithChannels { tx, .. } = list; for (key, value) in &self.state { if tx.send(format!("{key}={value}")).await.is_err() { break; diff --git a/irpc-iroh/examples/derive.rs b/irpc-iroh/examples/derive.rs index f348654..b381cb7 100644 --- a/irpc-iroh/examples/derive.rs +++ b/irpc-iroh/examples/derive.rs @@ -141,7 +141,7 @@ mod storage { } StorageMessage::List(list) => { info!("list {:?}", list); - let WithChannels { mut tx, .. } = list; + let WithChannels { tx, .. } = list; for (key, value) in &self.state { if tx.send(format!("{key}={value}")).await.is_err() { break; diff --git a/src/lib.rs b/src/lib.rs index 1aa2ccb..10ebffe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -317,7 +317,7 @@ pub mod channel { /// /// For the rpc case, the send side can not be cloned, hence spsc instead of mpsc. pub mod spsc { - use std::{fmt::Debug, future::Future, io, pin::Pin}; + use std::{fmt::Debug, future::Future, io, pin::Pin, sync::Arc}; use super::{RecvError, SendError}; use crate::RpcMessage; @@ -332,15 +332,11 @@ pub mod channel { /// Single producer, single consumer sender. /// - /// For the local case, this wraps a tokio::sync::mpsc::Sender. However, - /// due to the fact that a stream to a remote service can not be cloned, - /// this can also not be cloned. - /// - /// This forces you to use senders in a linear way, passing out references - /// to the sender to other tasks instead of cloning it. + /// For the local case, this wraps a tokio::sync::mpsc::Sender. + #[derive(Clone)] pub enum Sender { Tokio(tokio::sync::mpsc::Sender), - Boxed(Box>), + Boxed(Arc>), } impl Sender { @@ -354,7 +350,7 @@ pub mod channel { } } - pub async fn closed(&mut self) + pub async fn closed(&self) where T: RpcMessage, { @@ -369,7 +365,7 @@ pub mod channel { where T: RpcMessage, { - futures_util::sink::unfold(self, |mut sink, value| async move { + futures_util::sink::unfold(self, |sink, value| async move { sink.send(value).await?; Ok(sink) }) @@ -399,10 +395,7 @@ pub mod channel { /// /// For the remote case, if the message can not be completely sent, /// this must return an error and disable the channel. - fn send( - &mut self, - value: T, - ) -> Pin> + Send + '_>>; + fn send(&self, value: T) -> Pin> + Send + '_>>; /// Try to send a message, returning as fast as possible if sending /// is not currently possible. @@ -410,12 +403,12 @@ pub mod channel { /// For the remote case, it must be guaranteed that the message is /// either completely sent or not at all. fn try_send( - &mut self, + &self, value: T, ) -> Pin> + Send + '_>>; /// Await the sender close - fn closed(&mut self) -> Pin + Send + '_>>; + fn closed(&self) -> Pin + Send + '_>>; /// True if this is a remote sender fn is_rpc(&self) -> bool; @@ -450,7 +443,7 @@ pub mod channel { impl Sender { /// Send a message and yield until either it is sent or an error occurs. - pub async fn send(&mut self, value: T) -> std::result::Result<(), SendError> { + pub async fn send(&self, value: T) -> std::result::Result<(), SendError> { match self { Sender::Tokio(tx) => { tx.send(value).await.map_err(|_| SendError::ReceiverClosed) @@ -1310,11 +1303,13 @@ pub mod rpc { impl From for spsc::Sender { fn from(write: quinn::SendStream) -> Self { - spsc::Sender::Boxed(Box::new(QuinnSender { - send: write, - buffer: SmallVec::new(), - _marker: PhantomData, - })) + spsc::Sender::Boxed(Arc::new(QuinnSender(tokio::sync::Mutex::new( + QuinnSenderInner { + send: write, + buffer: SmallVec::new(), + _marker: PhantomData, + }, + )))) } } @@ -1355,19 +1350,13 @@ pub mod rpc { fn drop(&mut self) {} } - struct QuinnSender { + struct QuinnSenderInner { send: quinn::SendStream, buffer: SmallVec<[u8; 128]>, _marker: std::marker::PhantomData, } - impl Debug for QuinnSender { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("QuinnSender").finish() - } - } - - impl DynSender for QuinnSender { + impl QuinnSenderInner { fn send(&mut self, value: T) -> Pin> + Send + '_>> { Box::pin(async { let value = value; @@ -1403,18 +1392,43 @@ pub mod rpc { self.send.stopped().await.ok(); }) } + } - fn is_rpc(&self) -> bool { - true + struct QuinnSender(tokio::sync::Mutex>); + + impl Debug for QuinnSender { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("QuinnSender").finish() } } - impl Drop for QuinnSender { - fn drop(&mut self) { - self.send.finish().ok(); + impl DynSender for QuinnSender { + fn send(&self, value: T) -> Pin> + Send + '_>> { + Box::pin(async { self.0.lock().await.send(value).await }) + } + + fn try_send( + &self, + value: T, + ) -> Pin> + Send + '_>> { + Box::pin(async { self.0.lock().await.try_send(value).await }) + } + + fn closed(&self) -> Pin + Send + '_>> { + Box::pin(async { self.0.lock().await.closed().await }) + } + + fn is_rpc(&self) -> bool { + true } } + // impl Drop for QuinnSender { + // fn drop(&mut self) { + // self.send.finish().ok(); + // } + // } + /// Type alias for a handler fn for remote requests pub type Handler = Arc< dyn Fn( From 846f23e0a9bae139441d2880aa784a7f2c5ac5d5 Mon Sep 17 00:00:00 2001 From: Frando Date: Fri, 6 Jun 2025 14:45:53 +0200 Subject: [PATCH 03/17] refactor: add Sync bounds for send methods too --- src/lib.rs | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 10ebffe..14737d7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -395,7 +395,10 @@ pub mod channel { /// /// For the remote case, if the message can not be completely sent, /// this must return an error and disable the channel. - fn send(&self, value: T) -> Pin> + Send + '_>>; + fn send( + &self, + value: T, + ) -> Pin> + Send + Sync + '_>>; /// Try to send a message, returning as fast as possible if sending /// is not currently possible. @@ -405,10 +408,10 @@ pub mod channel { fn try_send( &self, value: T, - ) -> Pin> + Send + '_>>; + ) -> Pin> + Send + Sync + '_>>; /// Await the sender close - fn closed(&self) -> Pin + Send + '_>>; + fn closed(&self) -> Pin + Send + Sync + '_>>; /// True if this is a remote sender fn is_rpc(&self) -> bool; @@ -1357,7 +1360,10 @@ pub mod rpc { } impl QuinnSenderInner { - fn send(&mut self, value: T) -> Pin> + Send + '_>> { + fn send( + &mut self, + value: T, + ) -> Pin> + Send + Sync + '_>> { Box::pin(async { let value = value; self.buffer.clear(); @@ -1371,7 +1377,7 @@ pub mod rpc { fn try_send( &mut self, value: T, - ) -> Pin> + Send + '_>> { + ) -> Pin> + Send + Sync + '_>> { Box::pin(async { // todo: move the non-async part out of the box. Will require a new return type. let value = value; @@ -1387,7 +1393,7 @@ pub mod rpc { }) } - fn closed(&mut self) -> Pin + Send + '_>> { + fn closed(&mut self) -> Pin + Send + Sync + '_>> { Box::pin(async move { self.send.stopped().await.ok(); }) @@ -1403,18 +1409,21 @@ pub mod rpc { } impl DynSender for QuinnSender { - fn send(&self, value: T) -> Pin> + Send + '_>> { + fn send( + &self, + value: T, + ) -> Pin> + Send + Sync + '_>> { Box::pin(async { self.0.lock().await.send(value).await }) } fn try_send( &self, value: T, - ) -> Pin> + Send + '_>> { + ) -> Pin> + Send + Sync + '_>> { Box::pin(async { self.0.lock().await.try_send(value).await }) } - fn closed(&self) -> Pin + Send + '_>> { + fn closed(&self) -> Pin + Send + Sync + '_>> { Box::pin(async { self.0.lock().await.closed().await }) } From b36a4b4dab81d7cf65116d631f32d5564cadeebb Mon Sep 17 00:00:00 2001 From: Frando Date: Fri, 6 Jun 2025 14:47:40 +0200 Subject: [PATCH 04/17] cleanup --- src/lib.rs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 14737d7..4b88bf8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1432,12 +1432,6 @@ pub mod rpc { } } - // impl Drop for QuinnSender { - // fn drop(&mut self) { - // self.send.finish().ok(); - // } - // } - /// Type alias for a handler fn for remote requests pub type Handler = Arc< dyn Fn( From c4b0d4eb38b384b517591190881fe34fcff8269d Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Fri, 13 Jun 2025 10:24:00 +0200 Subject: [PATCH 05/17] Poison the spsc (now mpsc) sender... if the future gets dropped before completion or if it returns an io error. --- src/lib.rs | 63 +++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 55 insertions(+), 8 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 4b88bf8..f86421e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -389,7 +389,7 @@ pub mod channel { } } - /// A sender that can be wrapped in a `Box>`. + /// A sender that can be wrapped in a `Arc>`. pub trait DynSender: Debug + Send + Sync + 'static { /// Send a message. /// @@ -1092,7 +1092,9 @@ pub mod rpc { #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "rpc")))] pub mod rpc { //! Module for cross-process RPC using [`quinn`]. - use std::{fmt::Debug, future::Future, io, marker::PhantomData, pin::Pin, sync::Arc}; + use std::{ + fmt::Debug, future::Future, io, marker::PhantomData, ops::DerefMut, pin::Pin, sync::Arc, + }; use n0_future::{future::Boxed as BoxFuture, task::JoinSet}; use quinn::ConnectionError; @@ -1307,11 +1309,11 @@ pub mod rpc { impl From for spsc::Sender { fn from(write: quinn::SendStream) -> Self { spsc::Sender::Boxed(Arc::new(QuinnSender(tokio::sync::Mutex::new( - QuinnSenderInner { + QuinnSenderState::Open(QuinnSenderInner { send: write, buffer: SmallVec::new(), _marker: PhantomData, - }, + }), )))) } } @@ -1400,7 +1402,14 @@ pub mod rpc { } } - struct QuinnSender(tokio::sync::Mutex>); + #[derive(Default)] + enum QuinnSenderState { + Open(QuinnSenderInner), + #[default] + Closed, + } + + struct QuinnSender(tokio::sync::Mutex>); impl Debug for QuinnSender { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -1413,18 +1422,56 @@ pub mod rpc { &self, value: T, ) -> Pin> + Send + Sync + '_>> { - Box::pin(async { self.0.lock().await.send(value).await }) + Box::pin(async { + let mut guard = self.0.lock().await; + let sender = std::mem::take(guard.deref_mut()); + match sender { + QuinnSenderState::Open(mut sender) => { + let res = sender.send(value).await; + if res.is_ok() { + *guard = QuinnSenderState::Open(sender); + } + res + } + QuinnSenderState::Closed => Err(io::Error::new( + io::ErrorKind::BrokenPipe, + SendError::ReceiverClosed, + )), + } + }) } fn try_send( &self, value: T, ) -> Pin> + Send + Sync + '_>> { - Box::pin(async { self.0.lock().await.try_send(value).await }) + Box::pin(async { + let mut guard = self.0.lock().await; + let sender = std::mem::take(guard.deref_mut()); + match sender { + QuinnSenderState::Open(mut sender) => { + let res = sender.try_send(value).await; + if res.is_ok() { + *guard = QuinnSenderState::Open(sender); + } + res + } + QuinnSenderState::Closed => Err(io::Error::new( + io::ErrorKind::BrokenPipe, + SendError::ReceiverClosed, + )), + } + }) } fn closed(&self) -> Pin + Send + Sync + '_>> { - Box::pin(async { self.0.lock().await.closed().await }) + Box::pin(async { + let mut guard = self.0.lock().await; + match guard.deref_mut() { + QuinnSenderState::Open(sender) => sender.closed().await, + QuinnSenderState::Closed => {} + } + }) } fn is_rpc(&self) -> bool { From 3ec5bdd98073a781a0e0c8cac4c525a6b4abdca2 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Fri, 13 Jun 2025 11:28:06 +0200 Subject: [PATCH 06/17] Add some tests to make sure sender clones behave correctly. --- Cargo.lock | 7 ++++ Cargo.toml | 1 + tests/mpsc_sender.rs | 91 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 99 insertions(+) create mode 100644 tests/mpsc_sender.rs diff --git a/Cargo.lock b/Cargo.lock index cca823f..c98ed22 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1655,6 +1655,7 @@ dependencies = [ "rustls", "serde", "smallvec", + "testresult", "thiserror 2.0.12", "thousands", "tokio", @@ -3300,6 +3301,12 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "testresult" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "614b328ff036a4ef882c61570f72918f7e9c5bee1da33f8e7f91e01daee7e56c" + [[package]] name = "thiserror" version = "1.0.69" diff --git a/Cargo.toml b/Cargo.toml index 457a84e..ca078dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,6 +57,7 @@ tokio = { workspace = true, features = ["full"] } thousands = "0.2.0" # macro tests trybuild = "1.0.104" +testresult = "0.4.1" [features] # enable the remote transport diff --git a/tests/mpsc_sender.rs b/tests/mpsc_sender.rs new file mode 100644 index 0000000..aa66a9d --- /dev/null +++ b/tests/mpsc_sender.rs @@ -0,0 +1,91 @@ +use std::{ + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, + time::Duration, +}; + +use irpc::{ + channel::spsc, + util::{make_client_endpoint, make_server_endpoint}, +}; +use quinn::Endpoint; +use testresult::TestResult; +use tokio::time::timeout; + +fn create_connected_endpoints() -> TestResult<(Endpoint, Endpoint, SocketAddr)> { + let addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into(); + let (server, cert) = make_server_endpoint(addr)?; + let client = make_client_endpoint(addr, &[cert.as_slice()])?; + let port = server.local_addr()?.port(); + let server_addr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, port).into(); + Ok((server, client, server_addr)) +} + +/// Checks that all clones of a `Sender` will get the closed signal as soon as +/// a send fails with an io error. +#[tokio::test] +async fn mpsc_sender_clone_closed_error() -> TestResult<()> { + let (server, client, server_addr) = create_connected_endpoints()?; + // accept a single bidi stream on a single connection, then immediately stop it + let server = tokio::spawn(async move { + let conn = server.accept().await.unwrap().await?; + let (_, mut recv) = conn.accept_bi().await?; + recv.stop(1u8.into())?; + TestResult::Ok(()) + }); + let conn = client.connect(server_addr, "localhost")?.await?; + let (send, _) = conn.open_bi().await?; + let send1 = spsc::Sender::>::from(send); + let send2 = send1.clone(); + let second_client = tokio::spawn(async move { + send2.closed().await; + }); + // send until we get an error because the remote side stopped the stream + while send1.send(vec![1, 2, 3]).await.is_ok() {} + // check that closed signal was received by the other sender + second_client.await?; + // server should finish without errors + server.await??; + Ok(()) +} + +/// Checks that all clones of a `Sender` will get the closed signal as soon as +/// a send future gets dropped before completing. +#[tokio::test] +async fn mpsc_sender_clone_drop_error() -> TestResult<()> { + let (server, client, server_addr) = create_connected_endpoints()?; + // accept a single bidi stream on a single connection, then read indefinitely + // until we get an error + let server = tokio::spawn(async move { + let conn = server.accept().await.unwrap().await?; + let (_, mut recv) = conn.accept_bi().await?; + let mut buf = vec![0u8; 1024]; + while recv.read(&mut buf).await.is_ok() {} + TestResult::Ok(()) + }); + let conn = client.connect(server_addr, "localhost")?.await?; + let (send, _) = conn.open_bi().await?; + let send1 = spsc::Sender::>::from(send); + let send2 = send1.clone(); + let second_client = tokio::spawn(async move { + send2.closed().await; + // why do I have to do this? + // Shouldn't dropping the quinn:SendStream call finish, so the server would get an io error? + conn.close(1u8.into(), b""); + }); + // send a lot of data with a tiny timeout, this will cause the send future to be dropped + loop { + let send_future = send1.send(vec![0u8; 1024 * 1024]); + // not sure if there is a better way. I want to poll the future a few times so it has time to + // start sending, but don't want to give it enough time to complete. + // I don't think now_or_never would work, since it wouldn't have time to start sending + if timeout(Duration::from_micros(1), send_future) + .await + .is_err() + { + break; + } + } + server.await??; + second_client.await?; + Ok(()) +} From fb3d2569edfe109e5675335c22b6433fbebbb5c3 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Fri, 13 Jun 2025 11:38:49 +0200 Subject: [PATCH 07/17] Add tests that cloned senders behave correctly --- src/lib.rs | 10 ++-------- tests/mpsc_sender.rs | 33 +++++++++++++++++++++++++++++++-- 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index f86421e..7fa18b3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1433,10 +1433,7 @@ pub mod rpc { } res } - QuinnSenderState::Closed => Err(io::Error::new( - io::ErrorKind::BrokenPipe, - SendError::ReceiverClosed, - )), + QuinnSenderState::Closed => Err(io::ErrorKind::BrokenPipe.into()), } }) } @@ -1456,10 +1453,7 @@ pub mod rpc { } res } - QuinnSenderState::Closed => Err(io::Error::new( - io::ErrorKind::BrokenPipe, - SendError::ReceiverClosed, - )), + QuinnSenderState::Closed => Err(io::ErrorKind::BrokenPipe.into()), } }) } diff --git a/tests/mpsc_sender.rs b/tests/mpsc_sender.rs index aa66a9d..eb22b20 100644 --- a/tests/mpsc_sender.rs +++ b/tests/mpsc_sender.rs @@ -1,10 +1,11 @@ use std::{ + io::ErrorKind, net::{Ipv4Addr, SocketAddr, SocketAddrV4}, time::Duration, }; use irpc::{ - channel::spsc, + channel::{spsc, SendError}, util::{make_client_endpoint, make_server_endpoint}, }; use quinn::Endpoint; @@ -24,6 +25,7 @@ fn create_connected_endpoints() -> TestResult<(Endpoint, Endpoint, SocketAddr)> /// a send fails with an io error. #[tokio::test] async fn mpsc_sender_clone_closed_error() -> TestResult<()> { + tracing_subscriber::fmt::try_init().ok(); let (server, client, server_addr) = create_connected_endpoints()?; // accept a single bidi stream on a single connection, then immediately stop it let server = tokio::spawn(async move { @@ -36,13 +38,29 @@ async fn mpsc_sender_clone_closed_error() -> TestResult<()> { let (send, _) = conn.open_bi().await?; let send1 = spsc::Sender::>::from(send); let send2 = send1.clone(); + let send3 = send1.clone(); let second_client = tokio::spawn(async move { send2.closed().await; }); + let third_client = tokio::spawn(async move { + // this should fail with an io error, since the stream was stopped + loop { + match send3.send(vec![1, 2, 3]).await { + Err(SendError::Io(e)) if e.kind() == ErrorKind::BrokenPipe => break, + _ => {} + }; + } + }); // send until we get an error because the remote side stopped the stream while send1.send(vec![1, 2, 3]).await.is_ok() {} - // check that closed signal was received by the other sender + match send1.send(vec![4, 5, 6]).await { + Err(SendError::Io(e)) if e.kind() == ErrorKind::BrokenPipe => {} + e => panic!("Expected SendError::Io with kind BrokenPipe, got {:?}", e), + }; + // check that closed signal was received by the second sender second_client.await?; + // check that the third sender will get the right kind of io error eventually + third_client.await?; // server should finish without errors server.await??; Ok(()) @@ -66,12 +84,22 @@ async fn mpsc_sender_clone_drop_error() -> TestResult<()> { let (send, _) = conn.open_bi().await?; let send1 = spsc::Sender::>::from(send); let send2 = send1.clone(); + let send3 = send1.clone(); let second_client = tokio::spawn(async move { send2.closed().await; // why do I have to do this? // Shouldn't dropping the quinn:SendStream call finish, so the server would get an io error? conn.close(1u8.into(), b""); }); + let third_client = tokio::spawn(async move { + // this should fail with an io error, since the stream was stopped + loop { + match send3.send(vec![1, 2, 3]).await { + Err(SendError::Io(e)) if e.kind() == ErrorKind::BrokenPipe => break, + _ => {} + }; + } + }); // send a lot of data with a tiny timeout, this will cause the send future to be dropped loop { let send_future = send1.send(vec![0u8; 1024 * 1024]); @@ -87,5 +115,6 @@ async fn mpsc_sender_clone_drop_error() -> TestResult<()> { } server.await??; second_client.await?; + third_client.await?; Ok(()) } From f3e0ac8bbdbc730fb78b3db3c69233f230ec8d04 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Fri, 13 Jun 2025 11:58:17 +0200 Subject: [PATCH 08/17] read until error or finished --- tests/mpsc_sender.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/mpsc_sender.rs b/tests/mpsc_sender.rs index eb22b20..c9a58da 100644 --- a/tests/mpsc_sender.rs +++ b/tests/mpsc_sender.rs @@ -72,12 +72,12 @@ async fn mpsc_sender_clone_closed_error() -> TestResult<()> { async fn mpsc_sender_clone_drop_error() -> TestResult<()> { let (server, client, server_addr) = create_connected_endpoints()?; // accept a single bidi stream on a single connection, then read indefinitely - // until we get an error + // until we get an error or the stream is finished let server = tokio::spawn(async move { let conn = server.accept().await.unwrap().await?; let (_, mut recv) = conn.accept_bi().await?; let mut buf = vec![0u8; 1024]; - while recv.read(&mut buf).await.is_ok() {} + while let Ok(Some(_)) = recv.read(&mut buf).await {} TestResult::Ok(()) }); let conn = client.connect(server_addr, "localhost")?.await?; @@ -87,9 +87,6 @@ async fn mpsc_sender_clone_drop_error() -> TestResult<()> { let send3 = send1.clone(); let second_client = tokio::spawn(async move { send2.closed().await; - // why do I have to do this? - // Shouldn't dropping the quinn:SendStream call finish, so the server would get an io error? - conn.close(1u8.into(), b""); }); let third_client = tokio::spawn(async move { // this should fail with an io error, since the stream was stopped From 10dc0973908e0bde6870870f5bb3a05cf8ae8871 Mon Sep 17 00:00:00 2001 From: Frando Date: Thu, 19 Jun 2025 11:16:35 +0200 Subject: [PATCH 09/17] docs: add docs about cancellation safety of Sender --- src/lib.rs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 7fa18b3..27dbaf1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -446,6 +446,13 @@ pub mod channel { impl Sender { /// Send a message and yield until either it is sent or an error occurs. + /// + /// ## Cancellation safety + /// + /// If the future is dropped before completion, and if this is a remote sender, + /// then the sender will be closed and further sends will return an [`io::Error`] + /// with [`io::ErrorKind::BrokenPipe`]. Therefore, make sure to always poll the + /// future until completion if you want to reuse the sender or any clone afterwards. pub async fn send(&self, value: T) -> std::result::Result<(), SendError> { match self { Sender::Tokio(tx) => { @@ -469,6 +476,13 @@ pub mod channel { /// all. /// /// Returns true if the message was sent. + /// + /// ## Cancellation safety + /// + /// If the future is dropped before completion, and if this is a remote sender, + /// then the sender will be closed and further sends will return an [`io::Error`] + /// with [`io::ErrorKind::BrokenPipe`]. Therefore, make sure to always poll the + /// future until completion if you want to reuse the sender or any clone afterwards. pub async fn try_send(&mut self, value: T) -> std::result::Result { match self { Sender::Tokio(tx) => match tx.try_send(value) { From 1a630b0259545bab1d92f96da76d081d3a895143 Mon Sep 17 00:00:00 2001 From: Frando Date: Thu, 19 Jun 2025 12:17:09 +0200 Subject: [PATCH 10/17] chore: clippy --- irpc-iroh/examples/auth.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/irpc-iroh/examples/auth.rs b/irpc-iroh/examples/auth.rs index 88944a7..6558bb3 100644 --- a/irpc-iroh/examples/auth.rs +++ b/irpc-iroh/examples/auth.rs @@ -218,7 +218,7 @@ mod storage { } StorageMessage::List(list) => { info!("list {:?}", list); - let WithChannels { mut tx, .. } = list; + let WithChannels { tx, .. } = list; let values = { let state = self.state.lock().unwrap(); // TODO: use async lock to not clone here. From a049272645e3bbee650ae93a14c9baf44c46bd6e Mon Sep 17 00:00:00 2001 From: Frando Date: Thu, 19 Jun 2025 12:21:58 +0200 Subject: [PATCH 11/17] refactor: rename spsc module to mpsc --- examples/compute.rs | 22 +++++++-------- examples/derive.rs | 10 +++---- examples/storage.rs | 8 +++--- irpc-iroh/examples/auth.rs | 8 +++--- irpc-iroh/examples/derive.rs | 6 ++-- src/lib.rs | 54 ++++++++++++++++++------------------ tests/mpsc_sender.rs | 6 ++-- 7 files changed, 57 insertions(+), 57 deletions(-) diff --git a/examples/compute.rs b/examples/compute.rs index 95e4f45..f8ac923 100644 --- a/examples/compute.rs +++ b/examples/compute.rs @@ -7,7 +7,7 @@ use std::{ use anyhow::bail; use futures_buffered::BufferedStreamExt; use irpc::{ - channel::{oneshot, spsc}, + channel::{oneshot, mpsc}, rpc::{listen, Handler}, rpc_requests, util::{make_client_endpoint, make_server_endpoint}, @@ -61,11 +61,11 @@ enum ComputeRequest { enum ComputeProtocol { #[rpc(tx=oneshot::Sender)] Sqr(Sqr), - #[rpc(rx=spsc::Receiver, tx=oneshot::Sender)] + #[rpc(rx=mpsc::Receiver, tx=oneshot::Sender)] Sum(Sum), - #[rpc(tx=spsc::Sender)] + #[rpc(tx=mpsc::Sender)] Fibonacci(Fibonacci), - #[rpc(rx=spsc::Receiver, tx=spsc::Sender)] + #[rpc(rx=mpsc::Receiver, tx=mpsc::Sender)] Multiply(Multiply), } @@ -200,11 +200,11 @@ impl ComputeApi { } } - pub async fn sum(&self) -> anyhow::Result<(spsc::Sender, oneshot::Receiver)> { + pub async fn sum(&self) -> anyhow::Result<(mpsc::Sender, oneshot::Receiver)> { let msg = Sum; match self.inner.request().await? { Request::Local(request) => { - let (num_tx, num_rx) = spsc::channel(10); + let (num_tx, num_rx) = mpsc::channel(10); let (sum_tx, sum_rx) = oneshot::channel(); request.send((msg, sum_tx, num_rx)).await?; Ok((num_tx, sum_rx)) @@ -216,11 +216,11 @@ impl ComputeApi { } } - pub async fn fibonacci(&self, max: u64) -> anyhow::Result> { + pub async fn fibonacci(&self, max: u64) -> anyhow::Result> { let msg = Fibonacci { max }; match self.inner.request().await? { Request::Local(request) => { - let (tx, rx) = spsc::channel(128); + let (tx, rx) = mpsc::channel(128); request.send((msg, tx)).await?; Ok(rx) } @@ -234,12 +234,12 @@ impl ComputeApi { pub async fn multiply( &self, initial: u64, - ) -> anyhow::Result<(spsc::Sender, spsc::Receiver)> { + ) -> anyhow::Result<(mpsc::Sender, mpsc::Receiver)> { let msg = Multiply { initial }; match self.inner.request().await? { Request::Local(request) => { - let (in_tx, in_rx) = spsc::channel(128); - let (out_tx, out_rx) = spsc::channel(128); + let (in_tx, in_rx) = mpsc::channel(128); + let (out_tx, out_rx) = mpsc::channel(128); request.send((msg, out_tx, in_rx)).await?; Ok((in_tx, out_rx)) } diff --git a/examples/derive.rs b/examples/derive.rs index 0f482e4..80ebc05 100644 --- a/examples/derive.rs +++ b/examples/derive.rs @@ -6,7 +6,7 @@ use std::{ use anyhow::{Context, Result}; use irpc::{ - channel::{oneshot, spsc}, + channel::{oneshot, mpsc}, rpc::Handler, rpc_requests, util::{make_client_endpoint, make_server_endpoint}, @@ -55,9 +55,9 @@ enum StorageProtocol { Get(Get), #[rpc(tx=oneshot::Sender<()>)] Set(Set), - #[rpc(tx=oneshot::Sender, rx=spsc::Receiver<(String, String)>)] + #[rpc(tx=oneshot::Sender, rx=mpsc::Receiver<(String, String)>)] SetMany(SetMany), - #[rpc(tx=spsc::Sender)] + #[rpc(tx=mpsc::Sender)] List(List), } @@ -152,7 +152,7 @@ impl StorageApi { self.inner.rpc(Get { key }).await } - pub async fn list(&self) -> irpc::Result> { + pub async fn list(&self) -> irpc::Result> { self.inner.server_streaming(List, 16).await } @@ -162,7 +162,7 @@ impl StorageApi { pub async fn set_many( &self, - ) -> irpc::Result<(spsc::Sender<(String, String)>, oneshot::Receiver)> { + ) -> irpc::Result<(mpsc::Sender<(String, String)>, oneshot::Receiver)> { self.inner.client_streaming(SetMany, 4).await } } diff --git a/examples/storage.rs b/examples/storage.rs index 29b07b1..df1c076 100644 --- a/examples/storage.rs +++ b/examples/storage.rs @@ -6,7 +6,7 @@ use std::{ use anyhow::bail; use irpc::{ - channel::{none::NoReceiver, oneshot, spsc}, + channel::{none::NoReceiver, oneshot, mpsc}, rpc::{listen, Handler}, util::{make_client_endpoint, make_server_endpoint}, Channels, Client, LocalSender, Request, Service, WithChannels, @@ -36,7 +36,7 @@ struct List; impl Channels for List { type Rx = NoReceiver; - type Tx = spsc::Sender; + type Tx = mpsc::Sender; } #[derive(Debug, Serialize, Deserialize)] @@ -157,11 +157,11 @@ impl StorageApi { } } - pub async fn list(&self) -> anyhow::Result> { + pub async fn list(&self) -> anyhow::Result> { let msg = List; match self.inner.request().await? { Request::Local(request) => { - let (tx, rx) = spsc::channel(10); + let (tx, rx) = mpsc::channel(10); request.send((msg, tx)).await?; Ok(rx) } diff --git a/irpc-iroh/examples/auth.rs b/irpc-iroh/examples/auth.rs index 6558bb3..48f9dcb 100644 --- a/irpc-iroh/examples/auth.rs +++ b/irpc-iroh/examples/auth.rs @@ -72,7 +72,7 @@ mod storage { Endpoint, }; use irpc::{ - channel::{oneshot, spsc}, + channel::{oneshot, mpsc}, Client, Service, WithChannels, }; // Import the macro @@ -122,9 +122,9 @@ mod storage { Get(Get), #[rpc(tx=oneshot::Sender<()>)] Set(Set), - #[rpc(tx=oneshot::Sender, rx=spsc::Receiver<(String, String)>)] + #[rpc(tx=oneshot::Sender, rx=mpsc::Receiver<(String, String)>)] SetMany(SetMany), - #[rpc(tx=spsc::Sender)] + #[rpc(tx=mpsc::Sender)] List(List), } @@ -265,7 +265,7 @@ mod storage { self.inner.rpc(Get { key }).await } - pub async fn list(&self) -> Result, irpc::Error> { + pub async fn list(&self) -> Result, irpc::Error> { self.inner.server_streaming(List, 10).await } diff --git a/irpc-iroh/examples/derive.rs b/irpc-iroh/examples/derive.rs index b381cb7..b5db905 100644 --- a/irpc-iroh/examples/derive.rs +++ b/irpc-iroh/examples/derive.rs @@ -60,7 +60,7 @@ mod storage { use anyhow::{Context, Result}; use iroh::{protocol::ProtocolHandler, Endpoint}; use irpc::{ - channel::{oneshot, spsc}, + channel::{oneshot, mpsc}, rpc::Handler, rpc_requests, Client, LocalSender, Service, WithChannels, }; @@ -97,7 +97,7 @@ mod storage { Get(Get), #[rpc(tx=oneshot::Sender<()>)] Set(Set), - #[rpc(tx=spsc::Sender)] + #[rpc(tx=mpsc::Sender)] List(List), } @@ -190,7 +190,7 @@ mod storage { self.inner.rpc(Get { key }).await } - pub async fn list(&self) -> irpc::Result> { + pub async fn list(&self) -> irpc::Result> { self.inner.server_streaming(List, 10).await } diff --git a/src/lib.rs b/src/lib.rs index 27dbaf1..50c7b46 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -127,9 +127,9 @@ pub trait Receiver: Debug + Sealed {} /// Trait to specify channels for a message and service pub trait Channels { - /// The sender type, can be either spsc, oneshot or none + /// The sender type, can be either mpsc, oneshot or none type Tx: Sender; - /// The receiver type, can be either spsc, oneshot or none + /// The receiver type, can be either mpsc, oneshot or none /// /// For many services, the receiver is not needed, so it can be set to [`NoReceiver`]. type Rx: Receiver; @@ -315,14 +315,14 @@ pub mod channel { /// SPSC channel, similar to tokio's mpsc channel /// - /// For the rpc case, the send side can not be cloned, hence spsc instead of mpsc. - pub mod spsc { + /// For the rpc case, the send side can not be cloned, hence mpsc instead of mpsc. + pub mod mpsc { use std::{fmt::Debug, future::Future, io, pin::Pin, sync::Arc}; use super::{RecvError, SendError}; use crate::RpcMessage; - /// Create a local spsc sender and receiver pair, with the given buffer size. + /// Create a local mpsc sender and receiver pair, with the given buffer size. /// /// This is currently using a tokio channel pair internally. pub fn channel(buffer: usize) -> (Sender, Receiver) { @@ -582,7 +582,7 @@ pub mod channel { impl crate::Receiver for NoReceiver {} } - /// Error when sending a oneshot or spsc message. For local communication, + /// Error when sending a oneshot or mpsc message. For local communication, /// the only thing that can go wrong is that the receiver has been dropped. /// /// For rpc communication, there can be any number of errors, so this is a @@ -608,7 +608,7 @@ pub mod channel { } } - /// Error when receiving a oneshot or spsc message. For local communication, + /// Error when receiving a oneshot or mpsc message. For local communication, /// the only thing that can go wrong is that the sender has been closed. /// /// For rpc communication, there can be any number of errors, so this is a @@ -871,24 +871,24 @@ impl Client { } } - /// Performs a request for which the server returns a spsc receiver. + /// Performs a request for which the server returns a mpsc receiver. pub fn server_streaming( &self, msg: Req, local_response_cap: usize, - ) -> impl Future>> + Send + 'static + ) -> impl Future>> + Send + 'static where S: Service, M: From> + Send + Sync + Unpin + 'static, R: From + Serialize + Send + Sync + 'static, - Req: Channels, Rx = NoReceiver> + Send + 'static, + Req: Channels, Rx = NoReceiver> + Send + 'static, Res: RpcMessage, { let request = self.request(); async move { - let recv: channel::spsc::Receiver = match request.await? { + let recv: channel::mpsc::Receiver = match request.await? { Request::Local(request) => { - let (tx, rx) = channel::spsc::channel(local_response_cap); + let (tx, rx) = channel::mpsc::channel(local_response_cap); request.send((msg, tx)).await?; rx } @@ -911,7 +911,7 @@ impl Client { local_update_cap: usize, ) -> impl Future< Output = Result<( - channel::spsc::Sender, + channel::mpsc::Sender, channel::oneshot::Receiver, )>, > @@ -919,18 +919,18 @@ impl Client { S: Service, M: From> + Send + Sync + Unpin + 'static, R: From + Serialize + 'static, - Req: Channels, Rx = channel::spsc::Receiver>, + Req: Channels, Rx = channel::mpsc::Receiver>, Update: RpcMessage, Res: RpcMessage, { let request = self.request(); async move { let (update_tx, res_rx): ( - channel::spsc::Sender, + channel::mpsc::Sender, channel::oneshot::Receiver, ) = match request.await? { Request::Local(request) => { - let (req_tx, req_rx) = channel::spsc::channel(local_update_cap); + let (req_tx, req_rx) = channel::mpsc::channel(local_update_cap); let (res_tx, res_rx) = channel::oneshot::channel(); request.send((msg, res_tx, req_rx)).await?; (req_tx, res_rx) @@ -947,20 +947,20 @@ impl Client { } } - /// Performs a request for which the client can send updates, and the server returns a spsc receiver. + /// Performs a request for which the client can send updates, and the server returns a mpsc receiver. pub fn bidi_streaming( &self, msg: Req, local_update_cap: usize, local_response_cap: usize, - ) -> impl Future, channel::spsc::Receiver)>> + ) -> impl Future, channel::mpsc::Receiver)>> + Send + 'static where S: Service, M: From> + Send + Sync + Unpin + 'static, R: From + Serialize + Send + 'static, - Req: Channels, Rx = channel::spsc::Receiver> + Req: Channels, Rx = channel::mpsc::Receiver> + Send + 'static, Update: RpcMessage, @@ -968,11 +968,11 @@ impl Client { { let request = self.request(); async move { - let (update_tx, res_rx): (channel::spsc::Sender, channel::spsc::Receiver) = + let (update_tx, res_rx): (channel::mpsc::Sender, channel::mpsc::Receiver) = match request.await? { Request::Local(request) => { - let (update_tx, update_rx) = channel::spsc::channel(local_update_cap); - let (res_tx, res_rx) = channel::spsc::channel(local_response_cap); + let (update_tx, update_rx) = channel::mpsc::channel(local_update_cap); + let (res_tx, res_rx) = channel::mpsc::channel(local_response_cap); request.send((msg, res_tx, update_rx)).await?; (update_tx, res_rx) } @@ -1120,7 +1120,7 @@ pub mod rpc { channel::{ none::NoSender, oneshot, - spsc::{self, DynReceiver, DynSender}, + mpsc::{self, DynReceiver, DynSender}, RecvError, SendError, }, util::{now_or_never, AsyncReadVarintExt, WriteVarintExt}, @@ -1290,9 +1290,9 @@ pub mod rpc { } } - impl From for spsc::Receiver { + impl From for mpsc::Receiver { fn from(read: quinn::RecvStream) -> Self { - spsc::Receiver::Boxed(Box::new(QuinnReceiver { + mpsc::Receiver::Boxed(Box::new(QuinnReceiver { recv: read, _marker: PhantomData, })) @@ -1320,9 +1320,9 @@ pub mod rpc { } } - impl From for spsc::Sender { + impl From for mpsc::Sender { fn from(write: quinn::SendStream) -> Self { - spsc::Sender::Boxed(Arc::new(QuinnSender(tokio::sync::Mutex::new( + mpsc::Sender::Boxed(Arc::new(QuinnSender(tokio::sync::Mutex::new( QuinnSenderState::Open(QuinnSenderInner { send: write, buffer: SmallVec::new(), diff --git a/tests/mpsc_sender.rs b/tests/mpsc_sender.rs index c9a58da..e8382bb 100644 --- a/tests/mpsc_sender.rs +++ b/tests/mpsc_sender.rs @@ -5,7 +5,7 @@ use std::{ }; use irpc::{ - channel::{spsc, SendError}, + channel::{mpsc, SendError}, util::{make_client_endpoint, make_server_endpoint}, }; use quinn::Endpoint; @@ -36,7 +36,7 @@ async fn mpsc_sender_clone_closed_error() -> TestResult<()> { }); let conn = client.connect(server_addr, "localhost")?.await?; let (send, _) = conn.open_bi().await?; - let send1 = spsc::Sender::>::from(send); + let send1 = mpsc::Sender::>::from(send); let send2 = send1.clone(); let send3 = send1.clone(); let second_client = tokio::spawn(async move { @@ -82,7 +82,7 @@ async fn mpsc_sender_clone_drop_error() -> TestResult<()> { }); let conn = client.connect(server_addr, "localhost")?.await?; let (send, _) = conn.open_bi().await?; - let send1 = spsc::Sender::>::from(send); + let send1 = mpsc::Sender::>::from(send); let send2 = send1.clone(); let send3 = send1.clone(); let second_client = tokio::spawn(async move { From 6eea8ed551a45091a8ce0d9646367c18a93b39e1 Mon Sep 17 00:00:00 2001 From: Frando Date: Thu, 19 Jun 2025 13:03:10 +0200 Subject: [PATCH 12/17] chore: fmt --- examples/compute.rs | 2 +- examples/derive.rs | 2 +- examples/storage.rs | 2 +- irpc-iroh/examples/auth.rs | 2 +- irpc-iroh/examples/derive.rs | 2 +- src/lib.rs | 5 ++--- 6 files changed, 7 insertions(+), 8 deletions(-) diff --git a/examples/compute.rs b/examples/compute.rs index f8ac923..de00a38 100644 --- a/examples/compute.rs +++ b/examples/compute.rs @@ -7,7 +7,7 @@ use std::{ use anyhow::bail; use futures_buffered::BufferedStreamExt; use irpc::{ - channel::{oneshot, mpsc}, + channel::{mpsc, oneshot}, rpc::{listen, Handler}, rpc_requests, util::{make_client_endpoint, make_server_endpoint}, diff --git a/examples/derive.rs b/examples/derive.rs index 80ebc05..9d4324a 100644 --- a/examples/derive.rs +++ b/examples/derive.rs @@ -6,7 +6,7 @@ use std::{ use anyhow::{Context, Result}; use irpc::{ - channel::{oneshot, mpsc}, + channel::{mpsc, oneshot}, rpc::Handler, rpc_requests, util::{make_client_endpoint, make_server_endpoint}, diff --git a/examples/storage.rs b/examples/storage.rs index df1c076..0909494 100644 --- a/examples/storage.rs +++ b/examples/storage.rs @@ -6,7 +6,7 @@ use std::{ use anyhow::bail; use irpc::{ - channel::{none::NoReceiver, oneshot, mpsc}, + channel::{mpsc, none::NoReceiver, oneshot}, rpc::{listen, Handler}, util::{make_client_endpoint, make_server_endpoint}, Channels, Client, LocalSender, Request, Service, WithChannels, diff --git a/irpc-iroh/examples/auth.rs b/irpc-iroh/examples/auth.rs index 48f9dcb..5e952d8 100644 --- a/irpc-iroh/examples/auth.rs +++ b/irpc-iroh/examples/auth.rs @@ -72,7 +72,7 @@ mod storage { Endpoint, }; use irpc::{ - channel::{oneshot, mpsc}, + channel::{mpsc, oneshot}, Client, Service, WithChannels, }; // Import the macro diff --git a/irpc-iroh/examples/derive.rs b/irpc-iroh/examples/derive.rs index b5db905..65881e0 100644 --- a/irpc-iroh/examples/derive.rs +++ b/irpc-iroh/examples/derive.rs @@ -60,7 +60,7 @@ mod storage { use anyhow::{Context, Result}; use iroh::{protocol::ProtocolHandler, Endpoint}; use irpc::{ - channel::{oneshot, mpsc}, + channel::{mpsc, oneshot}, rpc::Handler, rpc_requests, Client, LocalSender, Service, WithChannels, }; diff --git a/src/lib.rs b/src/lib.rs index 50c7b46..0372441 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1118,10 +1118,9 @@ pub mod rpc { use crate::{ channel::{ - none::NoSender, - oneshot, mpsc::{self, DynReceiver, DynSender}, - RecvError, SendError, + none::NoSender, + oneshot, RecvError, SendError, }, util::{now_or_never, AsyncReadVarintExt, WriteVarintExt}, RequestError, RpcMessage, From bbc2f6a4574c7fddc7cc140656dc2f65036e2130 Mon Sep 17 00:00:00 2001 From: Frando Date: Thu, 19 Jun 2025 12:35:23 +0200 Subject: [PATCH 13/17] refactor: rename tx/rx to reply/updates and WithChannels to Request --- examples/compute.rs | 133 ++++---- examples/derive.rs | 56 ++-- examples/storage.rs | 90 +++--- irpc-derive/src/lib.rs | 78 ++--- irpc-iroh/examples/auth.rs | 65 ++-- irpc-iroh/examples/derive.rs | 36 +-- irpc-iroh/src/lib.rs | 10 +- src/lib.rs | 337 +++++++++++---------- src/util.rs | 2 +- tests/compile_fail/extra_attr_types.rs | 2 +- tests/compile_fail/extra_attr_types.stderr | 2 +- tests/compile_fail/wrong_attr_types.stderr | 2 +- tests/derive.rs | 10 +- 13 files changed, 433 insertions(+), 390 deletions(-) diff --git a/examples/compute.rs b/examples/compute.rs index de00a38..f920fd0 100644 --- a/examples/compute.rs +++ b/examples/compute.rs @@ -11,7 +11,7 @@ use irpc::{ rpc::{listen, Handler}, rpc_requests, util::{make_client_endpoint, make_server_endpoint}, - Client, LocalSender, Request, Service, WithChannels, + Client, LocalSender, Request, RequestSender, Service, }; use n0_future::{ stream::StreamExt, @@ -59,13 +59,13 @@ enum ComputeRequest { #[rpc_requests(ComputeService, message = ComputeMessage)] #[derive(Serialize, Deserialize)] enum ComputeProtocol { - #[rpc(tx=oneshot::Sender)] + #[rpc(reply=oneshot::Sender)] Sqr(Sqr), - #[rpc(rx=mpsc::Receiver, tx=oneshot::Sender)] + #[rpc(updates=mpsc::Receiver, reply=oneshot::Sender)] Sum(Sum), - #[rpc(tx=mpsc::Sender)] + #[rpc(reply=mpsc::Sender)] Fibonacci(Fibonacci), - #[rpc(rx=mpsc::Receiver, tx=mpsc::Sender)] + #[rpc(updates=mpsc::Receiver, reply=mpsc::Sender)] Multiply(Multiply), } @@ -76,10 +76,10 @@ struct ComputeActor { impl ComputeActor { pub fn local() -> ComputeApi { - let (tx, rx) = tokio::sync::mpsc::channel(128); - let actor = Self { recv: rx }; + let (reply, request) = tokio::sync::mpsc::channel(128); + let actor = Self { recv: request }; n0_future::task::spawn(actor.run()); - let local = LocalSender::::from(tx); + let local = LocalSender::::from(reply); ComputeApi { inner: local.into(), } @@ -99,34 +99,45 @@ impl ComputeActor { match msg { ComputeMessage::Sqr(sqr) => { trace!("sqr {:?}", sqr); - let WithChannels { - tx, inner, span, .. + let Request { + reply, + message, + span, + .. } = sqr; let _entered = span.enter(); - let result = (inner.num as u128) * (inner.num as u128); - tx.send(result).await?; + let result = (message.num as u128) * (message.num as u128); + reply.send(result).await?; } ComputeMessage::Sum(sum) => { trace!("sum {:?}", sum); - let WithChannels { rx, tx, span, .. } = sum; + let Request { + updates, + reply, + span, + .. + } = sum; let _entered = span.enter(); - let mut receiver = rx; + let mut receiver = updates; let mut total = 0; while let Some(num) = receiver.recv().await? { total += num; } - tx.send(total).await?; + reply.send(total).await?; } ComputeMessage::Fibonacci(fib) => { trace!("fibonacci {:?}", fib); - let WithChannels { - tx, inner, span, .. + let Request { + reply, + message, + span, + .. } = fib; let _entered = span.enter(); - let sender = tx; + let sender = reply; let mut a = 0u64; let mut b = 1u64; - while a <= inner.max { + while a <= message.max { sender.send(a).await?; let next = a + b; a = b; @@ -135,17 +146,17 @@ impl ComputeActor { } ComputeMessage::Multiply(mult) => { trace!("multiply {:?}", mult); - let WithChannels { - rx, - tx, - inner, + let Request { + updates, + reply, + message, span, .. } = mult; let _entered = span.enter(); - let mut receiver = rx; - let sender = tx; - let multiplier = inner.initial; + let mut receiver = updates; + let sender = reply; + let multiplier = message.initial; while let Some(num) = receiver.recv().await? { sender.send(multiplier * num).await?; } @@ -171,13 +182,13 @@ impl ComputeApi { let Some(local) = self.inner.local() else { bail!("cannot listen on a remote service"); }; - let handler: Handler = Arc::new(move |msg, rx, tx| { + let handler: Handler = Arc::new(move |msg, request, reply| { let local = local.clone(); Box::pin(match msg { - ComputeProtocol::Sqr(msg) => local.send((msg, tx)), - ComputeProtocol::Sum(msg) => local.send((msg, tx, rx)), - ComputeProtocol::Fibonacci(msg) => local.send((msg, tx)), - ComputeProtocol::Multiply(msg) => local.send((msg, tx, rx)), + ComputeProtocol::Sqr(msg) => local.send((msg, reply)), + ComputeProtocol::Sum(msg) => local.send((msg, reply, request)), + ComputeProtocol::Fibonacci(msg) => local.send((msg, reply)), + ComputeProtocol::Multiply(msg) => local.send((msg, reply, request)), }) }); Ok(AbortOnDropHandle::new(task::spawn(listen( @@ -188,14 +199,14 @@ impl ComputeApi { pub async fn sqr(&self, num: u64) -> anyhow::Result> { let msg = Sqr { num }; match self.inner.request().await? { - Request::Local(request) => { - let (tx, rx) = oneshot::channel(); - request.send((msg, tx)).await?; - Ok(rx) + RequestSender::Local(sender) => { + let (reply, request) = oneshot::channel(); + sender.send((msg, reply)).await?; + Ok(request) } - Request::Remote(request) => { - let (_tx, rx) = request.write(msg).await?; - Ok(rx.into()) + RequestSender::Remote(sender) => { + let (_reply, request) = sender.write(msg).await?; + Ok(request.into()) } } } @@ -203,15 +214,15 @@ impl ComputeApi { pub async fn sum(&self) -> anyhow::Result<(mpsc::Sender, oneshot::Receiver)> { let msg = Sum; match self.inner.request().await? { - Request::Local(request) => { - let (num_tx, num_rx) = mpsc::channel(10); - let (sum_tx, sum_rx) = oneshot::channel(); - request.send((msg, sum_tx, num_rx)).await?; - Ok((num_tx, sum_rx)) + RequestSender::Local(sender) => { + let (num_reply, num_request) = mpsc::channel(10); + let (sum_reply, sum_request) = oneshot::channel(); + sender.send((msg, sum_reply, num_request)).await?; + Ok((num_reply, sum_request)) } - Request::Remote(request) => { - let (tx, rx) = request.write(msg).await?; - Ok((tx.into(), rx.into())) + RequestSender::Remote(sender) => { + let (reply, request) = sender.write(msg).await?; + Ok((reply.into(), request.into())) } } } @@ -219,14 +230,14 @@ impl ComputeApi { pub async fn fibonacci(&self, max: u64) -> anyhow::Result> { let msg = Fibonacci { max }; match self.inner.request().await? { - Request::Local(request) => { - let (tx, rx) = mpsc::channel(128); - request.send((msg, tx)).await?; - Ok(rx) + RequestSender::Local(sender) => { + let (reply, request) = mpsc::channel(128); + sender.send((msg, reply)).await?; + Ok(request) } - Request::Remote(request) => { - let (_tx, rx) = request.write(msg).await?; - Ok(rx.into()) + RequestSender::Remote(sender) => { + let (_reply, request) = sender.write(msg).await?; + Ok(request.into()) } } } @@ -237,15 +248,15 @@ impl ComputeApi { ) -> anyhow::Result<(mpsc::Sender, mpsc::Receiver)> { let msg = Multiply { initial }; match self.inner.request().await? { - Request::Local(request) => { - let (in_tx, in_rx) = mpsc::channel(128); - let (out_tx, out_rx) = mpsc::channel(128); - request.send((msg, out_tx, in_rx)).await?; - Ok((in_tx, out_rx)) + RequestSender::Local(sender) => { + let (in_reply, in_request) = mpsc::channel(128); + let (out_reply, out_request) = mpsc::channel(128); + sender.send((msg, out_reply, in_request)).await?; + Ok((in_reply, out_request)) } - Request::Remote(request) => { - let (tx, rx) = request.write(msg).await?; - Ok((tx.into(), rx.into())) + RequestSender::Remote(sender) => { + let (reply, request) = sender.write(msg).await?; + Ok((reply.into(), request.into())) } } } diff --git a/examples/derive.rs b/examples/derive.rs index 9d4324a..a8927c5 100644 --- a/examples/derive.rs +++ b/examples/derive.rs @@ -10,7 +10,7 @@ use irpc::{ rpc::Handler, rpc_requests, util::{make_client_endpoint, make_server_endpoint}, - Client, LocalSender, Service, WithChannels, + Client, LocalSender, Request, Service, }; // Import the macro use n0_future::task::{self, AbortOnDropHandle}; @@ -51,13 +51,13 @@ struct SetMany; #[rpc_requests(StorageService, message = StorageMessage)] #[derive(Serialize, Deserialize)] enum StorageProtocol { - #[rpc(tx=oneshot::Sender>)] + #[rpc(reply=oneshot::Sender>)] Get(Get), - #[rpc(tx=oneshot::Sender<()>)] + #[rpc(reply=oneshot::Sender<()>)] Set(Set), - #[rpc(tx=oneshot::Sender, rx=mpsc::Receiver<(String, String)>)] + #[rpc(reply=oneshot::Sender, updates=mpsc::Receiver<(String, String)>)] SetMany(SetMany), - #[rpc(tx=mpsc::Sender)] + #[rpc(reply=mpsc::Sender)] List(List), } @@ -68,13 +68,13 @@ struct StorageActor { impl StorageActor { pub fn spawn() -> StorageApi { - let (tx, rx) = tokio::sync::mpsc::channel(1); + let (reply, request) = tokio::sync::mpsc::channel(1); let actor = Self { - recv: rx, + recv: request, state: BTreeMap::new(), }; n0_future::task::spawn(actor.run()); - let local = LocalSender::::from(tx); + let local = LocalSender::::from(reply); StorageApi { inner: local.into(), } @@ -90,30 +90,32 @@ impl StorageActor { match msg { StorageMessage::Get(get) => { info!("get {:?}", get); - let WithChannels { tx, inner, .. } = get; - tx.send(self.state.get(&inner.key).cloned()).await.ok(); + let Request { reply, message, .. } = get; + reply.send(self.state.get(&message.key).cloned()).await.ok(); } StorageMessage::Set(set) => { info!("set {:?}", set); - let WithChannels { tx, inner, .. } = set; - self.state.insert(inner.key, inner.value); - tx.send(()).await.ok(); + let Request { reply, message, .. } = set; + self.state.insert(message.key, message.value); + reply.send(()).await.ok(); } StorageMessage::SetMany(set) => { info!("set-many {:?}", set); - let WithChannels { mut rx, tx, .. } = set; + let Request { + mut updates, reply, .. + } = set; let mut count = 0; - while let Ok(Some((key, value))) = rx.recv().await { + while let Ok(Some((key, value))) = updates.recv().await { self.state.insert(key, value); count += 1; } - tx.send(count).await.ok(); + reply.send(count).await.ok(); } StorageMessage::List(list) => { info!("list {:?}", list); - let WithChannels { tx, .. } = list; + let Request { reply, .. } = list; for (key, value) in &self.state { - if tx.send(format!("{key}={value}")).await.is_err() { + if reply.send(format!("{key}={value}")).await.is_err() { break; } } @@ -135,13 +137,13 @@ impl StorageApi { pub fn listen(&self, endpoint: quinn::Endpoint) -> Result> { let local = self.inner.local().context("cannot listen on remote API")?; - let handler: Handler = Arc::new(move |msg, rx, tx| { + let handler: Handler = Arc::new(move |msg, request, reply| { let local = local.clone(); Box::pin(match msg { - StorageProtocol::Get(msg) => local.send((msg, tx)), - StorageProtocol::Set(msg) => local.send((msg, tx)), - StorageProtocol::SetMany(msg) => local.send((msg, tx, rx)), - StorageProtocol::List(msg) => local.send((msg, tx)), + StorageProtocol::Get(msg) => local.send((msg, reply)), + StorageProtocol::Set(msg) => local.send((msg, reply)), + StorageProtocol::SetMany(msg) => local.send((msg, reply, request)), + StorageProtocol::List(msg) => local.send((msg, reply)), }) }); let join_handle = task::spawn(irpc::rpc::listen(endpoint, handler)); @@ -172,12 +174,12 @@ async fn client_demo(api: StorageApi) -> Result<()> { let value = api.get("hello".to_string()).await?; println!("get: hello = {:?}", value); - let (tx, rx) = api.set_many().await?; + let (reply, request) = api.set_many().await?; for i in 0..3 { - tx.send((format!("key{i}"), format!("value{i}"))).await?; + reply.send((format!("key{i}"), format!("value{i}"))).await?; } - drop(tx); - let count = rx.await?; + drop(reply); + let count = request.await?; println!("set-many: {count} values set"); let mut list = api.list().await?; diff --git a/examples/storage.rs b/examples/storage.rs index 0909494..33a8d8a 100644 --- a/examples/storage.rs +++ b/examples/storage.rs @@ -9,7 +9,7 @@ use irpc::{ channel::{mpsc, none::NoReceiver, oneshot}, rpc::{listen, Handler}, util::{make_client_endpoint, make_server_endpoint}, - Channels, Client, LocalSender, Request, Service, WithChannels, + Channels, Client, LocalSender, Request, RequestSender, Service, }; use n0_future::task::{self, AbortOnDropHandle}; use serde::{Deserialize, Serialize}; @@ -27,16 +27,16 @@ struct Get { } impl Channels for Get { - type Rx = NoReceiver; - type Tx = oneshot::Sender>; + type Request = NoReceiver; + type Response = oneshot::Sender>; } #[derive(Debug, Serialize, Deserialize)] struct List; impl Channels for List { - type Rx = NoReceiver; - type Tx = mpsc::Sender; + type Request = NoReceiver; + type Response = mpsc::Sender; } #[derive(Debug, Serialize, Deserialize)] @@ -46,8 +46,8 @@ struct Set { } impl Channels for Set { - type Rx = NoReceiver; - type Tx = oneshot::Sender<()>; + type Request = NoReceiver; + type Response = oneshot::Sender<()>; } #[derive(derive_more::From, Serialize, Deserialize)] @@ -59,9 +59,9 @@ enum StorageProtocol { #[derive(derive_more::From)] enum StorageMessage { - Get(WithChannels), - Set(WithChannels), - List(WithChannels), + Get(Request), + Set(Request), + List(Request), } struct StorageActor { @@ -71,13 +71,13 @@ struct StorageActor { impl StorageActor { pub fn local() -> StorageApi { - let (tx, rx) = tokio::sync::mpsc::channel(1); + let (reply, request) = tokio::sync::mpsc::channel(1); let actor = Self { - recv: rx, + recv: request, state: BTreeMap::new(), }; n0_future::task::spawn(actor.run()); - let local = LocalSender::::from(tx); + let local = LocalSender::::from(reply); StorageApi { inner: local.into(), } @@ -93,20 +93,20 @@ impl StorageActor { match msg { StorageMessage::Get(get) => { info!("get {:?}", get); - let WithChannels { tx, inner, .. } = get; - tx.send(self.state.get(&inner.key).cloned()).await.ok(); + let Request { reply, message, .. } = get; + reply.send(self.state.get(&message.key).cloned()).await.ok(); } StorageMessage::Set(set) => { info!("set {:?}", set); - let WithChannels { tx, inner, .. } = set; - self.state.insert(inner.key, inner.value); - tx.send(()).await.ok(); + let Request { reply, message, .. } = set; + self.state.insert(message.key, message.value); + reply.send(()).await.ok(); } StorageMessage::List(list) => { info!("list {:?}", list); - let WithChannels { tx, .. } = list; + let Request { reply, .. } = list; for (key, value) in &self.state { - if tx.send(format!("{key}={value}")).await.is_err() { + if reply.send(format!("{key}={value}")).await.is_err() { break; } } @@ -129,12 +129,12 @@ impl StorageApi { let Some(local) = self.inner.local() else { bail!("cannot listen on a remote service"); }; - let handler: Handler = Arc::new(move |msg, _rx, tx| { + let handler: Handler = Arc::new(move |msg, _request, reply| { let local = local.clone(); Box::pin(match msg { - StorageProtocol::Get(msg) => local.send((msg, tx)), - StorageProtocol::Set(msg) => local.send((msg, tx)), - StorageProtocol::List(msg) => local.send((msg, tx)), + StorageProtocol::Get(msg) => local.send((msg, reply)), + StorageProtocol::Set(msg) => local.send((msg, reply)), + StorageProtocol::List(msg) => local.send((msg, reply)), }) }); Ok(AbortOnDropHandle::new(task::spawn(listen( @@ -145,14 +145,14 @@ impl StorageApi { pub async fn get(&self, key: String) -> anyhow::Result>> { let msg = Get { key }; match self.inner.request().await? { - Request::Local(request) => { - let (tx, rx) = oneshot::channel(); - request.send((msg, tx)).await?; - Ok(rx) + RequestSender::Local(sender) => { + let (reply, request) = oneshot::channel(); + sender.send((msg, reply)).await?; + Ok(request) } - Request::Remote(request) => { - let (_tx, rx) = request.write(msg).await?; - Ok(rx.into()) + RequestSender::Remote(sender) => { + let (_reply, request) = sender.write(msg).await?; + Ok(request.into()) } } } @@ -160,14 +160,14 @@ impl StorageApi { pub async fn list(&self) -> anyhow::Result> { let msg = List; match self.inner.request().await? { - Request::Local(request) => { - let (tx, rx) = mpsc::channel(10); - request.send((msg, tx)).await?; - Ok(rx) + RequestSender::Local(sender) => { + let (reply, request) = mpsc::channel(10); + sender.send((msg, reply)).await?; + Ok(request) } - Request::Remote(request) => { - let (_tx, rx) = request.write(msg).await?; - Ok(rx.into()) + RequestSender::Remote(sender) => { + let (_reply, request) = sender.write(msg).await?; + Ok(request.into()) } } } @@ -175,14 +175,14 @@ impl StorageApi { pub async fn set(&self, key: String, value: String) -> anyhow::Result> { let msg = Set { key, value }; match self.inner.request().await? { - Request::Local(request) => { - let (tx, rx) = oneshot::channel(); - request.send((msg, tx)).await?; - Ok(rx) + RequestSender::Local(sender) => { + let (reply, request) = oneshot::channel(); + sender.send((msg, reply)).await?; + Ok(request) } - Request::Remote(request) => { - let (_tx, rx) = request.write(msg).await?; - Ok(rx.into()) + RequestSender::Remote(sender) => { + let (_reply, request) = sender.write(msg).await?; + Ok(request.into()) } } } diff --git a/irpc-derive/src/lib.rs b/irpc-derive/src/lib.rs index 754889e..0d5b52c 100644 --- a/irpc-derive/src/lib.rs +++ b/irpc-derive/src/lib.rs @@ -17,11 +17,11 @@ fn error_tokens(span: Span, message: &str) -> TokenStream { /// The only attribute we care about const ATTR_NAME: &str = "rpc"; -/// the tx type name -const TX_ATTR: &str = "tx"; -/// the rx type name -const RX_ATTR: &str = "rx"; -/// Fully qualified path to the default rx type +/// the reply type name +const TX_ATTR: &str = "reply"; +/// the request type name +const RX_ATTR: &str = "updates"; +/// Fully qualified path to the default request type const DEFAULT_RX_TYPE: &str = "::irpc::channel::none::NoReceiver"; /// Generate parent span method for an enum @@ -31,7 +31,7 @@ fn generate_parent_span_impl(enum_name: &Ident, variant_names: &[&Ident]) -> Tok /// Get the parent span of the message pub fn parent_span(&self) -> tracing::Span { let span = match self { - #(#enum_name::#variant_names(inner) => inner.parent_span_opt()),* + #(#enum_name::#variant_names(message) => message.parent_span_opt()),* }; span.cloned().unwrap_or_else(|| ::tracing::Span::current()) } @@ -45,18 +45,18 @@ fn generate_channels_impl( request_type: &Type, attr_span: Span, ) -> syn::Result { - // Try to get rx, default to NoReceiver if not present + // Try to get request, default to NoReceiver if not present // Use unwrap_or_else for a cleaner default - let rx = args.types.remove(RX_ATTR).unwrap_or_else(|| { + let request = args.types.remove(RX_ATTR).unwrap_or_else(|| { // We can safely unwrap here because this is a known valid type - syn::parse_str::(DEFAULT_RX_TYPE).expect("Failed to parse default rx type") + syn::parse_str::(DEFAULT_RX_TYPE).expect("Failed to parse default request type") }); - let tx = args.get(TX_ATTR, attr_span)?; + let reply = args.get(TX_ATTR, attr_span)?; let res = quote! { impl ::irpc::Channels<#service_name> for #request_type { - type Tx = #tx; - type Rx = #rx; + type Response = #reply; + type Request = #request; } }; @@ -72,10 +72,10 @@ fn generate_case_from_impls( let mut impls = quote! {}; // Generate From implementations for each case that has an rpc attribute - for (variant_name, inner_type) in variants_with_attr { + for (variant_name, message_type) in variants_with_attr { let impl_tokens = quote! { - impl From<#inner_type> for #enum_name { - fn from(value: #inner_type) -> Self { + impl From<#message_type> for #enum_name { + fn from(value: #message_type) -> Self { #enum_name::#variant_name(value) } } @@ -98,11 +98,11 @@ fn generate_message_enum_from_impls( ) -> TokenStream2 { let mut impls = quote! {}; - // Generate From> implementations for each case with an rpc attribute - for (variant_name, inner_type) in variants_with_attr { + // Generate From> implementations for each case with an rpc attribute + for (variant_name, message_type) in variants_with_attr { let impl_tokens = quote! { - impl From<::irpc::WithChannels<#inner_type, #service_name>> for #message_enum_name { - fn from(value: ::irpc::WithChannels<#inner_type, #service_name>) -> Self { + impl From<::irpc::Request<#message_type, #service_name>> for #message_enum_name { + fn from(value: ::irpc::Request<#message_type, #service_name>) -> Self { #message_enum_name::#variant_name(value) } } @@ -117,7 +117,7 @@ fn generate_message_enum_from_impls( impls } -/// Generate type aliases for WithChannels +/// Generate type aliases for Request fn generate_type_aliases( variants: &[(Ident, Type)], service_name: &Ident, @@ -125,15 +125,15 @@ fn generate_type_aliases( ) -> TokenStream2 { let mut aliases = quote! {}; - for (variant_name, inner_type) in variants { + for (variant_name, message_type) in variants { // Create a type name using the variant name + suffix // For example: Sum + "Msg" = SumMsg let type_name = format!("{}{}", variant_name, suffix); let type_ident = Ident::new(&type_name, variant_name.span()); let alias = quote! { - /// Type alias for WithChannels<#inner_type, #service_name> - pub type #type_ident = ::irpc::WithChannels<#inner_type, #service_name>; + /// Type alias for Request<#message_type, #service_name> + pub type #type_ident = ::irpc::Request<#message_type, #service_name>; }; aliases = quote! { @@ -153,17 +153,17 @@ fn generate_type_aliases( /// # Macro Arguments /// /// * First positional argument (required): The service type that will handle these requests -/// * `message` (optional): Generate an extended enum wrapping each type in `WithChannels` -/// * `alias` (optional): Generate type aliases with the given suffix for each `WithChannels` +/// * `message` (optional): Generate an extended enum wrapping each type in `Request` +/// * `alias` (optional): Generate type aliases with the given suffix for each `Request` /// /// # Variant Attributes /// /// Individual enum variants can be annotated with the `#[rpc(...)]` attribute to specify channel types: /// -/// * `#[rpc(tx=SomeType)]`: Specify the transmitter/sender channel type (required) -/// * `#[rpc(tx=SomeType, rx=OtherType)]`: Also specify a receiver channel type (optional) +/// * `#[rpc(reply=SomeType)]`: Specify the transmitter/sender channel type (required) +/// * `#[rpc(reply=SomeType, updates=OtherType)]`: Also specify a receiver channel type (optional) /// -/// If `rx` is not specified, it defaults to `NoReceiver`. +/// If `request` is not specified, it defaults to `NoReceiver`. /// /// # Examples /// @@ -171,9 +171,9 @@ fn generate_type_aliases( /// ``` /// #[rpc_requests(ComputeService)] /// enum ComputeProtocol { -/// #[rpc(tx=oneshot::Sender)] +/// #[rpc(reply=oneshot::Sender)] /// Sqr(Sqr), -/// #[rpc(tx=oneshot::Sender)] +/// #[rpc(reply=oneshot::Sender)] /// Sum(Sum), /// } /// ``` @@ -182,9 +182,9 @@ fn generate_type_aliases( /// ``` /// #[rpc_requests(ComputeService, message = ComputeMessage)] /// enum ComputeProtocol { -/// #[rpc(tx=oneshot::Sender)] +/// #[rpc(reply=oneshot::Sender)] /// Sqr(Sqr), -/// #[rpc(tx=oneshot::Sender)] +/// #[rpc(reply=oneshot::Sender)] /// Sum(Sum), /// } /// ``` @@ -193,10 +193,10 @@ fn generate_type_aliases( /// ``` /// #[rpc_requests(ComputeService, alias = "Msg")] /// enum ComputeProtocol { -/// #[rpc(tx=oneshot::Sender)] -/// Sqr(Sqr), // Generates type SqrMsg = WithChannels -/// #[rpc(tx=oneshot::Sender)] -/// Sum(Sum), // Generates type SumMsg = WithChannels +/// #[rpc(reply=oneshot::Sender)] +/// Sqr(Sqr), // Generates type SqrMsg = Request +/// #[rpc(reply=oneshot::Sender)] +/// Sum(Sum), // Generates type SumMsg = Request /// } /// ``` #[proc_macro_attribute] @@ -299,10 +299,10 @@ pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream { let extended_enum_code = if let Some(message_enum_name) = message_enum_name { let message_variants = all_variants .iter() - .map(|(variant_name, inner_type)| { + .map(|(variant_name, message_type)| { quote! { #[allow(missing_docs)] - #variant_name(::irpc::WithChannels<#inner_type, #service_name>) + #variant_name(::irpc::Request<#message_type, #service_name>) } }) .collect::>(); @@ -349,7 +349,7 @@ pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream { // From implementations for the original enum #original_from_impls - // Type aliases for WithChannels + // Type aliases for Request #type_aliases // Extended enum and its implementations diff --git a/irpc-iroh/examples/auth.rs b/irpc-iroh/examples/auth.rs index 5e952d8..cc28d0b 100644 --- a/irpc-iroh/examples/auth.rs +++ b/irpc-iroh/examples/auth.rs @@ -73,7 +73,7 @@ mod storage { }; use irpc::{ channel::{mpsc, oneshot}, - Client, Service, WithChannels, + Client, Request, Service, }; // Import the macro use irpc_derive::rpc_requests; @@ -116,15 +116,15 @@ mod storage { #[rpc_requests(StorageService, message = StorageMessage)] #[derive(Serialize, Deserialize)] enum StorageProtocol { - #[rpc(tx=oneshot::Sender>)] + #[rpc(reply=oneshot::Sender>)] Auth(Auth), - #[rpc(tx=oneshot::Sender>)] + #[rpc(reply=oneshot::Sender>)] Get(Get), - #[rpc(tx=oneshot::Sender<()>)] + #[rpc(reply=oneshot::Sender<()>)] Set(Set), - #[rpc(tx=oneshot::Sender, rx=mpsc::Receiver<(String, String)>)] + #[rpc(reply=oneshot::Sender, updates=mpsc::Receiver<(String, String)>)] SetMany(SetMany), - #[rpc(tx=mpsc::Sender)] + #[rpc(reply=mpsc::Sender)] List(List), } @@ -139,20 +139,20 @@ mod storage { let this = self.clone(); Box::pin(async move { let mut authed = false; - while let Some((msg, rx, tx)) = read_request(&conn).await? { - let msg_with_channels = upcast_message(msg, rx, tx); + while let Some((msg, request, reply)) = read_request(&conn).await? { + let msg_with_channels = upcast_message(msg, request, reply); match msg_with_channels { StorageMessage::Auth(msg) => { - let WithChannels { inner, tx, .. } = msg; + let Request { message, reply, .. } = msg; if authed { conn.close(1u32.into(), b"invalid message"); break; - } else if inner.token != this.auth_token { + } else if message.token != this.auth_token { conn.close(1u32.into(), b"permission denied"); break; } else { authed = true; - tx.send(Ok(())).await.ok(); + reply.send(Ok(())).await.ok(); } } _ => { @@ -171,13 +171,17 @@ mod storage { } } - fn upcast_message(msg: StorageProtocol, rx: RecvStream, tx: SendStream) -> StorageMessage { + fn upcast_message( + msg: StorageProtocol, + request: RecvStream, + reply: SendStream, + ) -> StorageMessage { match msg { - StorageProtocol::Auth(msg) => WithChannels::from((msg, tx, rx)).into(), - StorageProtocol::Get(msg) => WithChannels::from((msg, tx, rx)).into(), - StorageProtocol::Set(msg) => WithChannels::from((msg, tx, rx)).into(), - StorageProtocol::SetMany(msg) => WithChannels::from((msg, tx, rx)).into(), - StorageProtocol::List(msg) => WithChannels::from((msg, tx, rx)).into(), + StorageProtocol::Auth(msg) => Request::from((msg, reply, request)).into(), + StorageProtocol::Get(msg) => Request::from((msg, reply, request)).into(), + StorageProtocol::Set(msg) => Request::from((msg, reply, request)).into(), + StorageProtocol::SetMany(msg) => Request::from((msg, reply, request)).into(), + StorageProtocol::List(msg) => Request::from((msg, reply, request)).into(), } } @@ -196,29 +200,34 @@ mod storage { StorageMessage::Auth(_) => unreachable!("handled in ProtocolHandler::accept"), StorageMessage::Get(get) => { info!("get {:?}", get); - let WithChannels { tx, inner, .. } = get; - let res = self.state.lock().unwrap().get(&inner.key).cloned(); - tx.send(res).await.ok(); + let Request { reply, message, .. } = get; + let res = self.state.lock().unwrap().get(&message.key).cloned(); + reply.send(res).await.ok(); } StorageMessage::Set(set) => { info!("set {:?}", set); - let WithChannels { tx, inner, .. } = set; - self.state.lock().unwrap().insert(inner.key, inner.value); - tx.send(()).await.ok(); + let Request { reply, message, .. } = set; + self.state + .lock() + .unwrap() + .insert(message.key, message.value); + reply.send(()).await.ok(); } StorageMessage::SetMany(list) => { - let WithChannels { tx, mut rx, .. } = list; + let Request { + reply, mut updates, .. + } = list; let mut i = 0; - while let Ok(Some((key, value))) = rx.recv().await { + while let Ok(Some((key, value))) = updates.recv().await { let mut state = self.state.lock().unwrap(); state.insert(key, value); i += 1; } - tx.send(i).await.ok(); + reply.send(i).await.ok(); } StorageMessage::List(list) => { info!("list {:?}", list); - let WithChannels { tx, .. } = list; + let Request { reply, .. } = list; let values = { let state = self.state.lock().unwrap(); // TODO: use async lock to not clone here. @@ -229,7 +238,7 @@ mod storage { values }; for value in values { - if tx.send(value).await.is_err() { + if reply.send(value).await.is_err() { break; } } diff --git a/irpc-iroh/examples/derive.rs b/irpc-iroh/examples/derive.rs index 65881e0..d1fcb8a 100644 --- a/irpc-iroh/examples/derive.rs +++ b/irpc-iroh/examples/derive.rs @@ -62,7 +62,7 @@ mod storage { use irpc::{ channel::{mpsc, oneshot}, rpc::Handler, - rpc_requests, Client, LocalSender, Service, WithChannels, + rpc_requests, Client, LocalSender, Request, Service, }; // Import the macro use irpc_iroh::{IrohProtocol, IrohRemoteConnection}; @@ -93,11 +93,11 @@ mod storage { #[rpc_requests(StorageService, message = StorageMessage)] #[derive(Serialize, Deserialize)] enum StorageProtocol { - #[rpc(tx=oneshot::Sender>)] + #[rpc(reply=oneshot::Sender>)] Get(Get), - #[rpc(tx=oneshot::Sender<()>)] + #[rpc(reply=oneshot::Sender<()>)] Set(Set), - #[rpc(tx=mpsc::Sender)] + #[rpc(reply=mpsc::Sender)] List(List), } @@ -108,13 +108,13 @@ mod storage { impl StorageActor { pub fn spawn() -> StorageApi { - let (tx, rx) = tokio::sync::mpsc::channel(1); + let (reply, request) = tokio::sync::mpsc::channel(1); let actor = Self { - recv: rx, + recv: request, state: BTreeMap::new(), }; n0_future::task::spawn(actor.run()); - let local = LocalSender::::from(tx); + let local = LocalSender::::from(reply); StorageApi { inner: local.into(), } @@ -130,20 +130,20 @@ mod storage { match msg { StorageMessage::Get(get) => { info!("get {:?}", get); - let WithChannels { tx, inner, .. } = get; - tx.send(self.state.get(&inner.key).cloned()).await.ok(); + let Request { reply, message, .. } = get; + reply.send(self.state.get(&message.key).cloned()).await.ok(); } StorageMessage::Set(set) => { info!("set {:?}", set); - let WithChannels { tx, inner, .. } = set; - self.state.insert(inner.key, inner.value); - tx.send(()).await.ok(); + let Request { reply, message, .. } = set; + self.state.insert(message.key, message.value); + reply.send(()).await.ok(); } StorageMessage::List(list) => { info!("list {:?}", list); - let WithChannels { tx, .. } = list; + let Request { reply, .. } = list; for (key, value) in &self.state { - if tx.send(format!("{key}={value}")).await.is_err() { + if reply.send(format!("{key}={value}")).await.is_err() { break; } } @@ -175,12 +175,12 @@ mod storage { .inner .local() .context("can not listen on remote service")?; - let handler: Handler = Arc::new(move |msg, _rx, tx| { + let handler: Handler = Arc::new(move |msg, _request, reply| { let local = local.clone(); Box::pin(match msg { - StorageProtocol::Get(msg) => local.send((msg, tx)), - StorageProtocol::Set(msg) => local.send((msg, tx)), - StorageProtocol::List(msg) => local.send((msg, tx)), + StorageProtocol::Get(msg) => local.send((msg, reply)), + StorageProtocol::Set(msg) => local.send((msg, reply)), + StorageProtocol::List(msg) => local.send((msg, reply)), }) }); Ok(IrohProtocol::new(handler)) diff --git a/irpc-iroh/src/lib.rs b/irpc-iroh/src/lib.rs index 2d8d065..6accbbe 100644 --- a/irpc-iroh/src/lib.rs +++ b/irpc-iroh/src/lib.rs @@ -128,10 +128,10 @@ pub async fn handle_connection( handler: Handler, ) -> io::Result<()> { loop { - let Some((msg, rx, tx)) = read_request(&connection).await? else { + let Some((msg, request, reply)) = read_request(&connection).await? else { return Ok(()); }; - handler(msg, rx, tx).await?; + handler(msg, request, reply).await?; } } @@ -166,9 +166,9 @@ pub async fn read_request( .map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?; let msg: R = postcard::from_bytes(&buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - let rx = recv; - let tx = send; - Ok(Some((msg, rx, tx))) + let request = recv; + let reply = send; + Ok(Some((msg, request, reply))) } /// Utility function to listen for incoming connections and handle them with the provided handler diff --git a/src/lib.rs b/src/lib.rs index 0372441..5478b7e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,17 +20,17 @@ //! //! ## Interaction patterns //! -//! For each request, there can be a response and update channel. Each channel +//! For each request, there can be a reply and update channel. Each channel //! can be either oneshot, carry multiple messages, or be disabled. This enables //! the typical interaction patterns known from libraries like grpc: //! -//! - rpc: 1 request, 1 response -//! - server streaming: 1 request, multiple responses -//! - client streaming: multiple requests, 1 response -//! - bidi streaming: multiple requests, multiple responses +//! - rpc: 1 request, 1 reply +//! - server streaming: 1 request, multiple replys +//! - client streaming: multiple requests, 1 reply +//! - bidi streaming: multiple requests, multiple replys //! //! as well as more complex patterns. It is however not possible to have multiple -//! differently typed tx channels for a single message type. +//! differently typed reply channels for a single message type. //! //! ## Transports //! @@ -111,7 +111,7 @@ impl RpcMessage for T where /// This is usually implemented by a zero-sized struct. /// It has various bounds to make derives easier. /// -/// A service acts as a scope for defining the tx and rx channels for each +/// A service acts as a scope for defining the reply and request channels for each /// message type, and provides some type safety when sending messages. pub trait Service: Send + Sync + Debug + Clone + 'static {} @@ -128,11 +128,11 @@ pub trait Receiver: Debug + Sealed {} /// Trait to specify channels for a message and service pub trait Channels { /// The sender type, can be either mpsc, oneshot or none - type Tx: Sender; + type Response: Sender; /// The receiver type, can be either mpsc, oneshot or none /// /// For many services, the receiver is not needed, so it can be set to [`NoReceiver`]. - type Rx: Receiver; + type Request: Receiver; } /// Channels that abstract over local or remote sending @@ -152,8 +152,8 @@ pub mod channel { /// /// This is currently using a tokio channel pair internally. pub fn channel() -> (Sender, Receiver) { - let (tx, rx) = tokio::sync::oneshot::channel(); - (tx.into(), rx.into()) + let (reply, request) = tokio::sync::oneshot::channel(); + (reply.into(), request.into()) } /// A generic boxed sender. @@ -206,8 +206,8 @@ pub mod channel { } impl From> for Sender { - fn from(tx: tokio::sync::oneshot::Sender) -> Self { - Self::Tokio(tx) + fn from(reply: tokio::sync::oneshot::Sender) -> Self { + Self::Tokio(reply) } } @@ -216,7 +216,7 @@ pub mod channel { fn try_from(value: Sender) -> Result { match value { - Sender::Tokio(tx) => Ok(tx), + Sender::Tokio(reply) => Ok(reply), Sender::Boxed(_) => Err(value), } } @@ -229,7 +229,9 @@ pub mod channel { /// Local senders will never yield, but can fail if the receiver has been closed. pub async fn send(self, value: T) -> std::result::Result<(), SendError> { match self { - Sender::Tokio(tx) => tx.send(value).map_err(|_| SendError::ReceiverClosed), + Sender::Tokio(reply) => { + reply.send(value).map_err(|_| SendError::ReceiverClosed) + } Sender::Boxed(f) => f(value).await.map_err(SendError::from), } } @@ -265,16 +267,18 @@ pub mod channel { fn poll(self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll { match self.get_mut() { - Self::Tokio(rx) => Pin::new(rx).poll(cx).map_err(|_| RecvError::SenderClosed), - Self::Boxed(rx) => Pin::new(rx).poll(cx).map_err(RecvError::Io), + Self::Tokio(request) => Pin::new(request) + .poll(cx) + .map_err(|_| RecvError::SenderClosed), + Self::Boxed(request) => Pin::new(request).poll(cx).map_err(RecvError::Io), } } } /// Convert a tokio oneshot receiver to a receiver for this crate impl From> for Receiver { - fn from(rx: tokio::sync::oneshot::Receiver) -> Self { - Self::Tokio(FusedOneshotReceiver(rx)) + fn from(request: tokio::sync::oneshot::Receiver) -> Self { + Self::Tokio(FusedOneshotReceiver(request)) } } @@ -283,7 +287,7 @@ pub mod channel { fn try_from(value: Receiver) -> Result { match value { - Receiver::Tokio(tx) => Ok(tx.0), + Receiver::Tokio(reply) => Ok(reply.0), Receiver::Boxed(_) => Err(value), } } @@ -326,8 +330,8 @@ pub mod channel { /// /// This is currently using a tokio channel pair internally. pub fn channel(buffer: usize) -> (Sender, Receiver) { - let (tx, rx) = tokio::sync::mpsc::channel(buffer); - (tx.into(), rx.into()) + let (reply, request) = tokio::sync::mpsc::channel(buffer); + (reply.into(), request.into()) } /// Single producer, single consumer sender. @@ -355,7 +359,7 @@ pub mod channel { T: RpcMessage, { match self { - Sender::Tokio(tx) => tx.closed().await, + Sender::Tokio(reply) => reply.closed().await, Sender::Boxed(sink) => sink.closed().await, } } @@ -373,8 +377,8 @@ pub mod channel { } impl From> for Sender { - fn from(tx: tokio::sync::mpsc::Sender) -> Self { - Self::Tokio(tx) + fn from(reply: tokio::sync::mpsc::Sender) -> Self { + Self::Tokio(reply) } } @@ -383,7 +387,7 @@ pub mod channel { fn try_from(value: Sender) -> Result { match value { - Sender::Tokio(tx) => Ok(tx), + Sender::Tokio(reply) => Ok(reply), Sender::Boxed(_) => Err(value), } } @@ -439,7 +443,7 @@ pub mod channel { .field("avail", &x.capacity()) .field("cap", &x.max_capacity()) .finish(), - Self::Boxed(inner) => f.debug_tuple("Boxed").field(&inner).finish(), + Self::Boxed(message) => f.debug_tuple("Boxed").field(&message).finish(), } } } @@ -455,9 +459,10 @@ pub mod channel { /// future until completion if you want to reuse the sender or any clone afterwards. pub async fn send(&self, value: T) -> std::result::Result<(), SendError> { match self { - Sender::Tokio(tx) => { - tx.send(value).await.map_err(|_| SendError::ReceiverClosed) - } + Sender::Tokio(reply) => reply + .send(value) + .await + .map_err(|_| SendError::ReceiverClosed), Sender::Boxed(sink) => sink.send(value).await.map_err(SendError::from), } } @@ -485,7 +490,7 @@ pub mod channel { /// future until completion if you want to reuse the sender or any clone afterwards. pub async fn try_send(&mut self, value: T) -> std::result::Result { match self { - Sender::Tokio(tx) => match tx.try_send(value) { + Sender::Tokio(reply) => match reply.try_send(value) { Ok(()) => Ok(true), Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => { Err(SendError::ReceiverClosed) @@ -514,8 +519,8 @@ pub mod channel { /// Returns an an io error if there was an error receiving the message. pub async fn recv(&mut self) -> std::result::Result, RecvError> { match self { - Self::Tokio(rx) => Ok(rx.recv().await), - Self::Boxed(rx) => Ok(rx.recv().await?), + Self::Tokio(request) => Ok(request.recv().await), + Self::Boxed(request) => Ok(request.recv().await?), } } @@ -531,8 +536,8 @@ pub mod channel { } impl From> for Receiver { - fn from(rx: tokio::sync::mpsc::Receiver) -> Self { - Self::Tokio(rx) + fn from(request: tokio::sync::mpsc::Receiver) -> Self { + Self::Tokio(request) } } @@ -541,7 +546,7 @@ pub mod channel { fn try_from(value: Receiver) -> Result { match value { - Receiver::Tokio(tx) => Ok(tx), + Receiver::Tokio(reply) => Ok(reply), Receiver::Boxed(_) => Err(value), } } @@ -550,12 +555,12 @@ pub mod channel { impl Debug for Receiver { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Tokio(inner) => f + Self::Tokio(message) => f .debug_struct("Tokio") - .field("avail", &inner.capacity()) - .field("cap", &inner.max_capacity()) + .field("avail", &message.capacity()) + .field("cap", &message.max_capacity()) .finish(), - Self::Boxed(inner) => f.debug_tuple("Boxed").field(&inner).finish(), + Self::Boxed(message) => f.debug_tuple("Boxed").field(&message).finish(), } } } @@ -639,35 +644,35 @@ pub mod channel { /// This expands the protocol message to a full message that includes the /// active and unserializable channels. /// -/// The channel kind for rx and tx is defined by implementing the `Channels` +/// The channel kind for request and reply is defined by implementing the `Channels` /// trait, either manually or using a macro. /// /// When the `message_spans` feature is enabled, this also includes a tracing /// span to carry the tracing context during message passing. -pub struct WithChannels, S: Service> { - /// The inner message. - pub inner: I, - /// The return channel to send the response to. Can be set to [`crate::channel::none::NoSender`] if not needed. - pub tx: >::Tx, +pub struct Request, S: Service> { + /// The request message. + pub message: I, + /// The return channel to send the reply to. Can be set to [`crate::channel::none::NoSender`] if not needed. + pub reply: >::Response, /// The request channel to receive the request from. Can be set to [`NoReceiver`] if not needed. - pub rx: >::Rx, + pub updates: >::Request, /// The current span where the full message was created. #[cfg(feature = "message_spans")] #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "message_spans")))] pub span: tracing::Span, } -impl + Debug, S: Service> Debug for WithChannels { +impl + Debug, S: Service> Debug for Request { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_tuple("") - .field(&self.inner) - .field(&self.tx) - .field(&self.rx) + .field(&self.message) + .field(&self.reply) + .field(&self.updates) .finish() } } -impl, S: Service> WithChannels { +impl, S: Service> Request { /// Get the parent span #[cfg(feature = "message_spans")] pub fn parent_span_opt(&self) -> Option<&tracing::Span> { @@ -675,21 +680,21 @@ impl, S: Service> WithChannels { } } -/// Tuple conversion from inner message and tx/rx channels to a WithChannels struct +/// Tuple conversion from message message and reply/request channels to a Request struct /// -/// For the case where you want both tx and rx channels. -impl, S: Service, Tx, Rx> From<(I, Tx, Rx)> for WithChannels +/// For the case where you want both reply and request channels. +impl, S: Service, Response, Updates> From<(I, Response, Updates)> for Request where I: Channels, - >::Tx: From, - >::Rx: From, + >::Response: From, + >::Request: From, { - fn from(inner: (I, Tx, Rx)) -> Self { - let (inner, tx, rx) = inner; + fn from(message: (I, Response, Updates)) -> Self { + let (message, reply, updates) = message; Self { - inner, - tx: tx.into(), - rx: rx.into(), + message, + reply: reply.into(), + updates: updates.into(), #[cfg(feature = "message_spans")] #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "message_spans")))] span: tracing::Span::current(), @@ -697,21 +702,21 @@ where } } -/// Tuple conversion from inner message and tx channel to a WithChannels struct +/// Tuple conversion from message message and reply channel to a Request struct /// -/// For the very common case where you just need a tx channel to send the response to. -impl From<(I, Tx)> for WithChannels +/// For the very common case where you just need a reply channel to send the reply to. +impl From<(I, Response)> for Request where - I: Channels, + I: Channels, S: Service, - >::Tx: From, + >::Response: From, { - fn from(inner: (I, Tx)) -> Self { - let (inner, tx) = inner; + fn from(message: (I, Response)) -> Self { + let (message, reply) = message; Self { - inner, - tx: tx.into(), - rx: NoReceiver, + message, + reply: reply.into(), + updates: NoReceiver, #[cfg(feature = "message_spans")] #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "message_spans")))] span: tracing::Span::current(), @@ -719,15 +724,15 @@ where } } -/// Deref so you can access the inner fields directly. +/// Deref so you can access the message fields directly. /// -/// If the inner message has fields named `tx`, `rx` or `span`, you need to use the -/// `inner` field to access them. -impl, S: Service> Deref for WithChannels { +/// If the message message has fields named `reply`, `request` or `span`, you need to use the +/// `message` field to access them. +impl, S: Service> Deref for Request { type Target = I; fn deref(&self) -> &Self::Target { - &self.inner + &self.message } } @@ -738,8 +743,8 @@ impl, S: Service> Deref for WithChannels { /// type. It can be thought of as the definition of the protocol. /// /// `M` is typically an enum with a case for each possible message type, where -/// each case is a `WithChannels` struct that extends the inner protocol message -/// with a local tx and rx channel as well as a tracing span to allow for +/// each case is a `Request` struct that extends the message protocol message +/// with a local reply and request channel as well as a tracing span to allow for /// keeping tracing context across async boundaries. /// /// In some cases, `M` and `R` can be enums for a subset of the protocol. E.g. @@ -757,14 +762,14 @@ impl Clone for Client { } impl From> for Client { - fn from(tx: LocalSender) -> Self { - Self(ClientInner::Local(tx.0), PhantomData) + fn from(reply: LocalSender) -> Self { + Self(ClientInner::Local(reply.0), PhantomData) } } impl From> for Client { - fn from(tx: tokio::sync::mpsc::Sender) -> Self { - LocalSender::from(tx).into() + fn from(reply: tokio::sync::mpsc::Sender) -> Self { + LocalSender::from(reply).into() } } @@ -788,7 +793,7 @@ impl Client { /// requests. pub fn local(&self) -> Option> { match &self.0 { - ClientInner::Local(tx) => Some(tx.clone().into()), + ClientInner::Local(reply) => Some(reply.clone().into()), ClientInner::Remote(..) => None, } } @@ -808,7 +813,10 @@ impl Client { pub fn request( &self, ) -> impl Future< - Output = result::Result, rpc::RemoteSender>, RequestError>, + Output = result::Result< + RequestSender, rpc::RemoteSender>, + RequestError, + >, > + 'static where S: Service, @@ -818,26 +826,26 @@ impl Client { #[cfg(feature = "rpc")] { let cloned = match &self.0 { - ClientInner::Local(tx) => Request::Local(tx.clone()), - ClientInner::Remote(connection) => Request::Remote(connection.clone_boxed()), + ClientInner::Local(reply) => RequestSender::Local(reply.clone()), + ClientInner::Remote(connection) => RequestSender::Remote(connection.clone_boxed()), }; async move { match cloned { - Request::Local(tx) => Ok(Request::Local(tx.into())), - Request::Remote(conn) => { + RequestSender::Local(reply) => Ok(RequestSender::Local(reply.into())), + RequestSender::Remote(conn) => { let (send, recv) = conn.open_bi().await?; - Ok(Request::Remote(rpc::RemoteSender::new(send, recv))) + Ok(RequestSender::Remote(rpc::RemoteSender::new(send, recv))) } } } } #[cfg(not(feature = "rpc"))] { - let ClientInner::Local(tx) = &self.0 else { + let ClientInner::Local(reply) = &self.0 else { unreachable!() }; - let tx = tx.clone().into(); - async move { Ok(Request::Local(tx)) } + let reply = reply.clone().into(); + async move { Ok(RequestSender::Local(reply)) } } } @@ -845,25 +853,27 @@ impl Client { pub fn rpc(&self, msg: Req) -> impl Future> + Send + 'static where S: Service, - M: From> + Send + Sync + Unpin + 'static, + M: From> + Send + Sync + Unpin + 'static, R: From + Serialize + Send + Sync + 'static, - Req: Channels, Rx = NoReceiver> + Send + 'static, + Req: Channels, Request = NoReceiver> + + Send + + 'static, Res: RpcMessage, { let request = self.request(); async move { let recv: channel::oneshot::Receiver = match request.await? { - Request::Local(request) => { - let (tx, rx) = channel::oneshot::channel(); - request.send((msg, tx)).await?; - rx + RequestSender::Local(tx) => { + let (reply, request) = channel::oneshot::channel(); + tx.send((msg, reply)).await?; + request } #[cfg(not(feature = "rpc"))] - Request::Remote(_request) => unreachable!(), + RequestSender::Remote(_request) => unreachable!(), #[cfg(feature = "rpc")] - Request::Remote(request) => { - let (_tx, rx) = request.write(msg).await?; - rx.into() + RequestSender::Remote(tx) => { + let (_reply, request) = tx.write(msg).await?; + request.into() } }; let res = recv.await?; @@ -875,29 +885,31 @@ impl Client { pub fn server_streaming( &self, msg: Req, - local_response_cap: usize, + local_reply_cap: usize, ) -> impl Future>> + Send + 'static where S: Service, - M: From> + Send + Sync + Unpin + 'static, + M: From> + Send + Sync + Unpin + 'static, R: From + Serialize + Send + Sync + 'static, - Req: Channels, Rx = NoReceiver> + Send + 'static, + Req: Channels, Request = NoReceiver> + + Send + + 'static, Res: RpcMessage, { let request = self.request(); async move { let recv: channel::mpsc::Receiver = match request.await? { - Request::Local(request) => { - let (tx, rx) = channel::mpsc::channel(local_response_cap); - request.send((msg, tx)).await?; - rx + RequestSender::Local(tx) => { + let (reply, request) = channel::mpsc::channel(local_reply_cap); + tx.send((msg, reply)).await?; + request } #[cfg(not(feature = "rpc"))] - Request::Remote(_request) => unreachable!(), + RequestSender::Remote(_request) => unreachable!(), #[cfg(feature = "rpc")] - Request::Remote(request) => { - let (_tx, rx) = request.write(msg).await?; - rx.into() + RequestSender::Remote(tx) => { + let (_reply, request) = tx.write(msg).await?; + request.into() } }; Ok(recv) @@ -917,33 +929,37 @@ impl Client { > where S: Service, - M: From> + Send + Sync + Unpin + 'static, + M: From> + Send + Sync + Unpin + 'static, R: From + Serialize + 'static, - Req: Channels, Rx = channel::mpsc::Receiver>, + Req: Channels< + S, + Response = channel::oneshot::Sender, + Request = channel::mpsc::Receiver, + >, Update: RpcMessage, Res: RpcMessage, { let request = self.request(); async move { - let (update_tx, res_rx): ( + let (update_reply, res_request): ( channel::mpsc::Sender, channel::oneshot::Receiver, ) = match request.await? { - Request::Local(request) => { - let (req_tx, req_rx) = channel::mpsc::channel(local_update_cap); - let (res_tx, res_rx) = channel::oneshot::channel(); - request.send((msg, res_tx, req_rx)).await?; - (req_tx, res_rx) + RequestSender::Local(request) => { + let (req_reply, req_request) = channel::mpsc::channel(local_update_cap); + let (res_reply, res_request) = channel::oneshot::channel(); + request.send((msg, res_reply, req_request)).await?; + (req_reply, res_request) } #[cfg(not(feature = "rpc"))] - Request::Remote(_request) => unreachable!(), + RequestSender::Remote(_request) => unreachable!(), #[cfg(feature = "rpc")] - Request::Remote(request) => { - let (tx, rx) = request.write(msg).await?; - (tx.into(), rx.into()) + RequestSender::Remote(request) => { + let (reply, request) = request.write(msg).await?; + (reply.into(), request.into()) } }; - Ok((update_tx, res_rx)) + Ok((update_reply, res_request)) } } @@ -952,39 +968,44 @@ impl Client { &self, msg: Req, local_update_cap: usize, - local_response_cap: usize, + local_reply_cap: usize, ) -> impl Future, channel::mpsc::Receiver)>> + Send + 'static where S: Service, - M: From> + Send + Sync + Unpin + 'static, + M: From> + Send + Sync + Unpin + 'static, R: From + Serialize + Send + 'static, - Req: Channels, Rx = channel::mpsc::Receiver> - + Send + Req: Channels< + S, + Response = channel::mpsc::Sender, + Request = channel::mpsc::Receiver, + > + Send + 'static, Update: RpcMessage, Res: RpcMessage, { let request = self.request(); async move { - let (update_tx, res_rx): (channel::mpsc::Sender, channel::mpsc::Receiver) = - match request.await? { - Request::Local(request) => { - let (update_tx, update_rx) = channel::mpsc::channel(local_update_cap); - let (res_tx, res_rx) = channel::mpsc::channel(local_response_cap); - request.send((msg, res_tx, update_rx)).await?; - (update_tx, res_rx) - } - #[cfg(not(feature = "rpc"))] - Request::Remote(_request) => unreachable!(), - #[cfg(feature = "rpc")] - Request::Remote(request) => { - let (tx, rx) = request.write(msg).await?; - (tx.into(), rx.into()) - } - }; - Ok((update_tx, res_rx)) + let (update_reply, res_request): ( + channel::mpsc::Sender, + channel::mpsc::Receiver, + ) = match request.await? { + RequestSender::Local(request) => { + let (update_reply, update_request) = channel::mpsc::channel(local_update_cap); + let (res_reply, res_request) = channel::mpsc::channel(local_reply_cap); + request.send((msg, res_reply, update_request)).await?; + (update_reply, res_request) + } + #[cfg(not(feature = "rpc"))] + RequestSender::Remote(_request) => unreachable!(), + #[cfg(feature = "rpc")] + RequestSender::Remote(request) => { + let (reply, request) = request.write(msg).await?; + (reply.into(), request.into()) + } + }; + Ok((update_reply, res_request)) } } } @@ -1004,7 +1025,7 @@ pub(crate) enum ClientInner { impl Clone for ClientInner { fn clone(&self) -> Self { match self { - Self::Local(tx) => Self::Local(tx.clone()), + Self::Local(reply) => Self::Local(reply.clone()), #[cfg(feature = "rpc")] Self::Remote(conn) => Self::Remote(conn.clone_boxed()), #[cfg(not(feature = "rpc"))] @@ -1080,7 +1101,7 @@ impl From for io::Error { /// /// This is a wrapper around an in-memory channel (currently [`tokio::sync::mpsc::Sender`]), /// that adds nice syntax for sending messages that can be converted into -/// [`WithChannels`]. +/// [`Request`]. #[derive(Debug)] #[repr(transparent)] pub struct LocalSender(tokio::sync::mpsc::Sender, std::marker::PhantomData); @@ -1092,8 +1113,8 @@ impl Clone for LocalSender { } impl From> for LocalSender { - fn from(tx: tokio::sync::mpsc::Sender) -> Self { - Self(tx, PhantomData) + fn from(reply: tokio::sync::mpsc::Sender) -> Self { + Self(reply, PhantomData) } } @@ -1538,9 +1559,9 @@ pub mod rpc { .map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?; let msg: R = postcard::from_bytes(&buf) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - let rx = recv; - let tx = send; - handler(msg, rx, tx).await?; + let request = recv; + let reply = send; + handler(msg, request, reply).await?; } }; let span = trace_span!("rpc", id = request_id); @@ -1552,7 +1573,7 @@ pub mod rpc { /// A request to a service. This can be either local or remote. #[derive(Debug)] -pub enum Request { +pub enum RequestSender { /// Local in memory request Local(L), /// Remote cross process request @@ -1561,10 +1582,10 @@ pub enum Request { impl LocalSender { /// Send a message to the service - pub fn send(&self, value: impl Into>) -> SendFut + pub fn send(&self, value: impl Into>) -> SendFut where T: Channels, - M: From>, + M: From>, { let value: M = value.into().into(); SendFut::new(self.0.clone(), value) diff --git a/src/util.rs b/src/util.rs index 5d04877..e89b044 100644 --- a/src/util.rs +++ b/src/util.rs @@ -118,7 +118,7 @@ mod quinn_setup_utils { _end_entity: &rustls::pki_types::CertificateDer<'_>, _intermediates: &[rustls::pki_types::CertificateDer<'_>], _server_name: &rustls::pki_types::ServerName<'_>, - _ocsp_response: &[u8], + _ocsp_reply: &[u8], _now: rustls::pki_types::UnixTime, ) -> Result { Ok(rustls::client::danger::ServerCertVerified::assertion()) diff --git a/tests/compile_fail/extra_attr_types.rs b/tests/compile_fail/extra_attr_types.rs index d6c288f..94f5e60 100644 --- a/tests/compile_fail/extra_attr_types.rs +++ b/tests/compile_fail/extra_attr_types.rs @@ -2,7 +2,7 @@ use irpc::rpc_requests; #[rpc_requests(Service, Msg)] enum Enum { - #[rpc(tx = NoSender, rx = NoReceiver, fnord = Foo)] + #[rpc(reply = NoSender, request = NoReceiver, fnord = Foo)] A(u8), } diff --git a/tests/compile_fail/extra_attr_types.stderr b/tests/compile_fail/extra_attr_types.stderr index c19048b..7f71b9a 100644 --- a/tests/compile_fail/extra_attr_types.stderr +++ b/tests/compile_fail/extra_attr_types.stderr @@ -1,5 +1,5 @@ error: Unknown arguments provided: ["fnord"] --> tests/compile_fail/extra_attr_types.rs:5:5 | -5 | #[rpc(tx = NoSender, rx = NoReceiver, fnord = Foo)] +5 | #[rpc(reply = NoSender, request = NoReceiver, fnord = Foo)] | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/compile_fail/wrong_attr_types.stderr b/tests/compile_fail/wrong_attr_types.stderr index f679dbd..fd1c062 100644 --- a/tests/compile_fail/wrong_attr_types.stderr +++ b/tests/compile_fail/wrong_attr_types.stderr @@ -1,4 +1,4 @@ -error: rpc requires a tx type +error: rpc requires a reply type --> tests/compile_fail/wrong_attr_types.rs:5:5 | 5 | #[rpc(fnord = Bla)] diff --git a/tests/derive.rs b/tests/derive.rs index 3e0122f..4fef41d 100644 --- a/tests/derive.rs +++ b/tests/derive.rs @@ -36,16 +36,16 @@ fn derive_simple() { #[derive(Debug, Serialize, Deserialize)] struct Response4; - #[rpc_requests(Service, message = RequestWithChannels)] + #[rpc_requests(Service, message = RequestRequest)] #[derive(Debug, Serialize, Deserialize)] enum Request { - #[rpc(tx=oneshot::Sender<()>)] + #[rpc(reply=oneshot::Sender<()>)] Rpc(RpcRequest), - #[rpc(tx=NoSender)] + #[rpc(reply=NoSender)] ServerStreaming(ServerStreamingRequest), - #[rpc(tx=NoSender)] + #[rpc(reply=NoSender)] BidiStreaming(BidiStreamingRequest), - #[rpc(tx=NoSender)] + #[rpc(reply=NoSender)] ClientStreaming(ClientStreamingRequest), } From e775b49493d336521bc5d232086811eb0f7a809b Mon Sep 17 00:00:00 2001 From: Frando Date: Thu, 19 Jun 2025 13:06:37 +0200 Subject: [PATCH 14/17] fixup --- examples/compute.rs | 50 ++++++++++++++++++++++----------------------- examples/derive.rs | 8 ++++---- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/examples/compute.rs b/examples/compute.rs index f920fd0..cb7037a 100644 --- a/examples/compute.rs +++ b/examples/compute.rs @@ -182,13 +182,13 @@ impl ComputeApi { let Some(local) = self.inner.local() else { bail!("cannot listen on a remote service"); }; - let handler: Handler = Arc::new(move |msg, request, reply| { + let handler: Handler = Arc::new(move |msg, updates, reply| { let local = local.clone(); Box::pin(match msg { ComputeProtocol::Sqr(msg) => local.send((msg, reply)), - ComputeProtocol::Sum(msg) => local.send((msg, reply, request)), + ComputeProtocol::Sum(msg) => local.send((msg, reply, updates)), ComputeProtocol::Fibonacci(msg) => local.send((msg, reply)), - ComputeProtocol::Multiply(msg) => local.send((msg, reply, request)), + ComputeProtocol::Multiply(msg) => local.send((msg, reply, updates)), }) }); Ok(AbortOnDropHandle::new(task::spawn(listen( @@ -200,13 +200,13 @@ impl ComputeApi { let msg = Sqr { num }; match self.inner.request().await? { RequestSender::Local(sender) => { - let (reply, request) = oneshot::channel(); - sender.send((msg, reply)).await?; - Ok(request) + let (tx, rx) = oneshot::channel(); + sender.send((msg, tx)).await?; + Ok(rx) } RequestSender::Remote(sender) => { - let (_reply, request) = sender.write(msg).await?; - Ok(request.into()) + let (_tx, rx) = sender.write(msg).await?; + Ok(rx.into()) } } } @@ -215,14 +215,14 @@ impl ComputeApi { let msg = Sum; match self.inner.request().await? { RequestSender::Local(sender) => { - let (num_reply, num_request) = mpsc::channel(10); - let (sum_reply, sum_request) = oneshot::channel(); - sender.send((msg, sum_reply, num_request)).await?; - Ok((num_reply, sum_request)) + let (num_tx, num_rx) = mpsc::channel(10); + let (sum_tx, sum_rx) = oneshot::channel(); + sender.send((msg, sum_tx, num_rx)).await?; + Ok((num_tx, sum_rx)) } RequestSender::Remote(sender) => { - let (reply, request) = sender.write(msg).await?; - Ok((reply.into(), request.into())) + let (tx, rx) = sender.write(msg).await?; + Ok((tx.into(), rx.into())) } } } @@ -231,13 +231,13 @@ impl ComputeApi { let msg = Fibonacci { max }; match self.inner.request().await? { RequestSender::Local(sender) => { - let (reply, request) = mpsc::channel(128); - sender.send((msg, reply)).await?; - Ok(request) + let (tx, rx) = mpsc::channel(128); + sender.send((msg, tx)).await?; + Ok(rx) } RequestSender::Remote(sender) => { - let (_reply, request) = sender.write(msg).await?; - Ok(request.into()) + let (_tx, rx) = sender.write(msg).await?; + Ok(rx.into()) } } } @@ -249,14 +249,14 @@ impl ComputeApi { let msg = Multiply { initial }; match self.inner.request().await? { RequestSender::Local(sender) => { - let (in_reply, in_request) = mpsc::channel(128); - let (out_reply, out_request) = mpsc::channel(128); - sender.send((msg, out_reply, in_request)).await?; - Ok((in_reply, out_request)) + let (in_tx, in_rx) = mpsc::channel(128); + let (out_tx, out_rx) = mpsc::channel(128); + sender.send((msg, out_tx, in_rx)).await?; + Ok((in_tx, out_rx)) } RequestSender::Remote(sender) => { - let (reply, request) = sender.write(msg).await?; - Ok((reply.into(), request.into())) + let (tx, rx) = sender.write(msg).await?; + Ok((tx.into(), rx.into())) } } } diff --git a/examples/derive.rs b/examples/derive.rs index a8927c5..e38980d 100644 --- a/examples/derive.rs +++ b/examples/derive.rs @@ -174,12 +174,12 @@ async fn client_demo(api: StorageApi) -> Result<()> { let value = api.get("hello".to_string()).await?; println!("get: hello = {:?}", value); - let (reply, request) = api.set_many().await?; + let (tx, rx) = api.set_many().await?; for i in 0..3 { - reply.send((format!("key{i}"), format!("value{i}"))).await?; + tx.send((format!("key{i}"), format!("value{i}"))).await?; } - drop(reply); - let count = request.await?; + drop(tx); + let count = rx.await?; println!("set-many: {count} values set"); let mut list = api.list().await?; From f9bed52497abceb732063e3f1a87b006bc3c28de Mon Sep 17 00:00:00 2001 From: Frando Date: Thu, 19 Jun 2025 13:08:52 +0200 Subject: [PATCH 15/17] fixup --- examples/storage.rs | 12 ++++++------ irpc-derive/src/lib.rs | 4 ++-- src/lib.rs | 38 ++++++++++++++++++-------------------- 3 files changed, 26 insertions(+), 28 deletions(-) diff --git a/examples/storage.rs b/examples/storage.rs index 33a8d8a..7120d5d 100644 --- a/examples/storage.rs +++ b/examples/storage.rs @@ -27,16 +27,16 @@ struct Get { } impl Channels for Get { - type Request = NoReceiver; - type Response = oneshot::Sender>; + type Updates = NoReceiver; + type Reply = oneshot::Sender>; } #[derive(Debug, Serialize, Deserialize)] struct List; impl Channels for List { - type Request = NoReceiver; - type Response = mpsc::Sender; + type Updates = NoReceiver; + type Reply = mpsc::Sender; } #[derive(Debug, Serialize, Deserialize)] @@ -46,8 +46,8 @@ struct Set { } impl Channels for Set { - type Request = NoReceiver; - type Response = oneshot::Sender<()>; + type Updates = NoReceiver; + type Reply = oneshot::Sender<()>; } #[derive(derive_more::From, Serialize, Deserialize)] diff --git a/irpc-derive/src/lib.rs b/irpc-derive/src/lib.rs index 0d5b52c..e9f9a14 100644 --- a/irpc-derive/src/lib.rs +++ b/irpc-derive/src/lib.rs @@ -55,8 +55,8 @@ fn generate_channels_impl( let res = quote! { impl ::irpc::Channels<#service_name> for #request_type { - type Response = #reply; - type Request = #request; + type Reply = #reply; + type Updates = #request; } }; diff --git a/src/lib.rs b/src/lib.rs index 5478b7e..09257aa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -128,11 +128,11 @@ pub trait Receiver: Debug + Sealed {} /// Trait to specify channels for a message and service pub trait Channels { /// The sender type, can be either mpsc, oneshot or none - type Response: Sender; + type Reply: Sender; /// The receiver type, can be either mpsc, oneshot or none /// /// For many services, the receiver is not needed, so it can be set to [`NoReceiver`]. - type Request: Receiver; + type Updates: Receiver; } /// Channels that abstract over local or remote sending @@ -653,9 +653,9 @@ pub struct Request, S: Service> { /// The request message. pub message: I, /// The return channel to send the reply to. Can be set to [`crate::channel::none::NoSender`] if not needed. - pub reply: >::Response, + pub reply: >::Reply, /// The request channel to receive the request from. Can be set to [`NoReceiver`] if not needed. - pub updates: >::Request, + pub updates: >::Updates, /// The current span where the full message was created. #[cfg(feature = "message_spans")] #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "message_spans")))] @@ -683,13 +683,13 @@ impl, S: Service> Request { /// Tuple conversion from message message and reply/request channels to a Request struct /// /// For the case where you want both reply and request channels. -impl, S: Service, Response, Updates> From<(I, Response, Updates)> for Request +impl, S: Service, Reply, Updates> From<(I, Reply, Updates)> for Request where I: Channels, - >::Response: From, - >::Request: From, + >::Reply: From, + >::Updates: From, { - fn from(message: (I, Response, Updates)) -> Self { + fn from(message: (I, Reply, Updates)) -> Self { let (message, reply, updates) = message; Self { message, @@ -705,13 +705,13 @@ where /// Tuple conversion from message message and reply channel to a Request struct /// /// For the very common case where you just need a reply channel to send the reply to. -impl From<(I, Response)> for Request +impl From<(I, Reply)> for Request where - I: Channels, + I: Channels, S: Service, - >::Response: From, + >::Reply: From, { - fn from(message: (I, Response)) -> Self { + fn from(message: (I, Reply)) -> Self { let (message, reply) = message; Self { message, @@ -855,7 +855,7 @@ impl Client { S: Service, M: From> + Send + Sync + Unpin + 'static, R: From + Serialize + Send + Sync + 'static, - Req: Channels, Request = NoReceiver> + Req: Channels, Updates = NoReceiver> + Send + 'static, Res: RpcMessage, @@ -891,9 +891,7 @@ impl Client { S: Service, M: From> + Send + Sync + Unpin + 'static, R: From + Serialize + Send + Sync + 'static, - Req: Channels, Request = NoReceiver> - + Send - + 'static, + Req: Channels, Updates = NoReceiver> + Send + 'static, Res: RpcMessage, { let request = self.request(); @@ -933,8 +931,8 @@ impl Client { R: From + Serialize + 'static, Req: Channels< S, - Response = channel::oneshot::Sender, - Request = channel::mpsc::Receiver, + Reply = channel::oneshot::Sender, + Updates = channel::mpsc::Receiver, >, Update: RpcMessage, Res: RpcMessage, @@ -978,8 +976,8 @@ impl Client { R: From + Serialize + Send + 'static, Req: Channels< S, - Response = channel::mpsc::Sender, - Request = channel::mpsc::Receiver, + Reply = channel::mpsc::Sender, + Updates = channel::mpsc::Receiver, > + Send + 'static, Update: RpcMessage, From 1c36364e3d23c1457205dc921885450c404620f1 Mon Sep 17 00:00:00 2001 From: Frando Date: Thu, 19 Jun 2025 13:15:19 +0200 Subject: [PATCH 16/17] fixup --- examples/derive.rs | 4 ++-- irpc-derive/src/lib.rs | 28 ++++++++++++++-------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/examples/derive.rs b/examples/derive.rs index e38980d..2743e07 100644 --- a/examples/derive.rs +++ b/examples/derive.rs @@ -137,12 +137,12 @@ impl StorageApi { pub fn listen(&self, endpoint: quinn::Endpoint) -> Result> { let local = self.inner.local().context("cannot listen on remote API")?; - let handler: Handler = Arc::new(move |msg, request, reply| { + let handler: Handler = Arc::new(move |msg, updates, reply| { let local = local.clone(); Box::pin(match msg { StorageProtocol::Get(msg) => local.send((msg, reply)), StorageProtocol::Set(msg) => local.send((msg, reply)), - StorageProtocol::SetMany(msg) => local.send((msg, reply, request)), + StorageProtocol::SetMany(msg) => local.send((msg, reply, updates)), StorageProtocol::List(msg) => local.send((msg, reply)), }) }); diff --git a/irpc-derive/src/lib.rs b/irpc-derive/src/lib.rs index e9f9a14..66334c8 100644 --- a/irpc-derive/src/lib.rs +++ b/irpc-derive/src/lib.rs @@ -31,7 +31,7 @@ fn generate_parent_span_impl(enum_name: &Ident, variant_names: &[&Ident]) -> Tok /// Get the parent span of the message pub fn parent_span(&self) -> tracing::Span { let span = match self { - #(#enum_name::#variant_names(message) => message.parent_span_opt()),* + #(#enum_name::#variant_names(inner) => inner.parent_span_opt()),* }; span.cloned().unwrap_or_else(|| ::tracing::Span::current()) } @@ -45,9 +45,9 @@ fn generate_channels_impl( request_type: &Type, attr_span: Span, ) -> syn::Result { - // Try to get request, default to NoReceiver if not present + // Try to get updates, default to NoReceiver if not present // Use unwrap_or_else for a cleaner default - let request = args.types.remove(RX_ATTR).unwrap_or_else(|| { + let updates = args.types.remove(RX_ATTR).unwrap_or_else(|| { // We can safely unwrap here because this is a known valid type syn::parse_str::(DEFAULT_RX_TYPE).expect("Failed to parse default request type") }); @@ -56,7 +56,7 @@ fn generate_channels_impl( let res = quote! { impl ::irpc::Channels<#service_name> for #request_type { type Reply = #reply; - type Updates = #request; + type Updates = #updates; } }; @@ -72,10 +72,10 @@ fn generate_case_from_impls( let mut impls = quote! {}; // Generate From implementations for each case that has an rpc attribute - for (variant_name, message_type) in variants_with_attr { + for (variant_name, inner_type) in variants_with_attr { let impl_tokens = quote! { - impl From<#message_type> for #enum_name { - fn from(value: #message_type) -> Self { + impl From<#inner_type> for #enum_name { + fn from(value: #inner_type) -> Self { #enum_name::#variant_name(value) } } @@ -99,10 +99,10 @@ fn generate_message_enum_from_impls( let mut impls = quote! {}; // Generate From> implementations for each case with an rpc attribute - for (variant_name, message_type) in variants_with_attr { + for (variant_name, inner_type) in variants_with_attr { let impl_tokens = quote! { - impl From<::irpc::Request<#message_type, #service_name>> for #message_enum_name { - fn from(value: ::irpc::Request<#message_type, #service_name>) -> Self { + impl From<::irpc::Request<#inner_type, #service_name>> for #message_enum_name { + fn from(value: ::irpc::Request<#inner_type, #service_name>) -> Self { #message_enum_name::#variant_name(value) } } @@ -125,7 +125,7 @@ fn generate_type_aliases( ) -> TokenStream2 { let mut aliases = quote! {}; - for (variant_name, message_type) in variants { + for (variant_name, inner_type) in variants { // Create a type name using the variant name + suffix // For example: Sum + "Msg" = SumMsg let type_name = format!("{}{}", variant_name, suffix); @@ -133,7 +133,7 @@ fn generate_type_aliases( let alias = quote! { /// Type alias for Request<#message_type, #service_name> - pub type #type_ident = ::irpc::Request<#message_type, #service_name>; + pub type #type_ident = ::irpc::Request<#inner_type, #service_name>; }; aliases = quote! { @@ -299,10 +299,10 @@ pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream { let extended_enum_code = if let Some(message_enum_name) = message_enum_name { let message_variants = all_variants .iter() - .map(|(variant_name, message_type)| { + .map(|(variant_name, inner_type)| { quote! { #[allow(missing_docs)] - #variant_name(::irpc::Request<#message_type, #service_name>) + #variant_name(::irpc::Request<#inner_type, #service_name>) } }) .collect::>(); From 475f40ecac59b79b1a0d1bc523e3ac222ebc545d Mon Sep 17 00:00:00 2001 From: Frando Date: Thu, 19 Jun 2025 13:19:19 +0200 Subject: [PATCH 17/17] fixup --- irpc-iroh/examples/auth.rs | 12 ++++++------ irpc-iroh/examples/derive.rs | 8 ++++---- irpc-iroh/src/lib.rs | 8 ++++---- src/util.rs | 2 +- tests/compile_fail/extra_attr_types.rs | 4 ++-- tests/compile_fail/extra_attr_types.stderr | 2 +- tests/derive.rs | 2 +- 7 files changed, 19 insertions(+), 19 deletions(-) diff --git a/irpc-iroh/examples/auth.rs b/irpc-iroh/examples/auth.rs index cc28d0b..39ea5a4 100644 --- a/irpc-iroh/examples/auth.rs +++ b/irpc-iroh/examples/auth.rs @@ -173,15 +173,15 @@ mod storage { fn upcast_message( msg: StorageProtocol, - request: RecvStream, + updates: RecvStream, reply: SendStream, ) -> StorageMessage { match msg { - StorageProtocol::Auth(msg) => Request::from((msg, reply, request)).into(), - StorageProtocol::Get(msg) => Request::from((msg, reply, request)).into(), - StorageProtocol::Set(msg) => Request::from((msg, reply, request)).into(), - StorageProtocol::SetMany(msg) => Request::from((msg, reply, request)).into(), - StorageProtocol::List(msg) => Request::from((msg, reply, request)).into(), + StorageProtocol::Auth(msg) => Request::from((msg, reply, updates)).into(), + StorageProtocol::Get(msg) => Request::from((msg, reply, updates)).into(), + StorageProtocol::Set(msg) => Request::from((msg, reply, updates)).into(), + StorageProtocol::SetMany(msg) => Request::from((msg, reply, updates)).into(), + StorageProtocol::List(msg) => Request::from((msg, reply, updates)).into(), } } diff --git a/irpc-iroh/examples/derive.rs b/irpc-iroh/examples/derive.rs index d1fcb8a..1db185f 100644 --- a/irpc-iroh/examples/derive.rs +++ b/irpc-iroh/examples/derive.rs @@ -108,13 +108,13 @@ mod storage { impl StorageActor { pub fn spawn() -> StorageApi { - let (reply, request) = tokio::sync::mpsc::channel(1); + let (tx, rx) = tokio::sync::mpsc::channel(1); let actor = Self { - recv: request, + recv: rx, state: BTreeMap::new(), }; n0_future::task::spawn(actor.run()); - let local = LocalSender::::from(reply); + let local = LocalSender::::from(tx); StorageApi { inner: local.into(), } @@ -175,7 +175,7 @@ mod storage { .inner .local() .context("can not listen on remote service")?; - let handler: Handler = Arc::new(move |msg, _request, reply| { + let handler: Handler = Arc::new(move |msg, _updates, reply| { let local = local.clone(); Box::pin(match msg { StorageProtocol::Get(msg) => local.send((msg, reply)), diff --git a/irpc-iroh/src/lib.rs b/irpc-iroh/src/lib.rs index 6accbbe..9c852db 100644 --- a/irpc-iroh/src/lib.rs +++ b/irpc-iroh/src/lib.rs @@ -128,10 +128,10 @@ pub async fn handle_connection( handler: Handler, ) -> io::Result<()> { loop { - let Some((msg, request, reply)) = read_request(&connection).await? else { + let Some((msg, updates, reply)) = read_request(&connection).await? else { return Ok(()); }; - handler(msg, request, reply).await?; + handler(msg, updates, reply).await?; } } @@ -166,9 +166,9 @@ pub async fn read_request( .map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?; let msg: R = postcard::from_bytes(&buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - let request = recv; + let updates = recv; let reply = send; - Ok(Some((msg, request, reply))) + Ok(Some((msg, updates, reply))) } /// Utility function to listen for incoming connections and handle them with the provided handler diff --git a/src/util.rs b/src/util.rs index e89b044..5d04877 100644 --- a/src/util.rs +++ b/src/util.rs @@ -118,7 +118,7 @@ mod quinn_setup_utils { _end_entity: &rustls::pki_types::CertificateDer<'_>, _intermediates: &[rustls::pki_types::CertificateDer<'_>], _server_name: &rustls::pki_types::ServerName<'_>, - _ocsp_reply: &[u8], + _ocsp_response: &[u8], _now: rustls::pki_types::UnixTime, ) -> Result { Ok(rustls::client::danger::ServerCertVerified::assertion()) diff --git a/tests/compile_fail/extra_attr_types.rs b/tests/compile_fail/extra_attr_types.rs index 94f5e60..d5eea11 100644 --- a/tests/compile_fail/extra_attr_types.rs +++ b/tests/compile_fail/extra_attr_types.rs @@ -2,8 +2,8 @@ use irpc::rpc_requests; #[rpc_requests(Service, Msg)] enum Enum { - #[rpc(reply = NoSender, request = NoReceiver, fnord = Foo)] + #[rpc(reply = NoSender, updates = NoReceiver, fnord = Foo)] A(u8), } -fn main() {} \ No newline at end of file +fn main() {} diff --git a/tests/compile_fail/extra_attr_types.stderr b/tests/compile_fail/extra_attr_types.stderr index 7f71b9a..b117330 100644 --- a/tests/compile_fail/extra_attr_types.stderr +++ b/tests/compile_fail/extra_attr_types.stderr @@ -1,5 +1,5 @@ error: Unknown arguments provided: ["fnord"] --> tests/compile_fail/extra_attr_types.rs:5:5 | -5 | #[rpc(reply = NoSender, request = NoReceiver, fnord = Foo)] +5 | #[rpc(reply = NoSender, updates = NoReceiver, fnord = Foo)] | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/derive.rs b/tests/derive.rs index 4fef41d..666cde8 100644 --- a/tests/derive.rs +++ b/tests/derive.rs @@ -36,7 +36,7 @@ fn derive_simple() { #[derive(Debug, Serialize, Deserialize)] struct Response4; - #[rpc_requests(Service, message = RequestRequest)] + #[rpc_requests(Service, message = RequestWithChannels)] #[derive(Debug, Serialize, Deserialize)] enum Request { #[rpc(reply=oneshot::Sender<()>)]