From 96000ed315aaaa48073f8b055022c633231212e5 Mon Sep 17 00:00:00 2001 From: Frando Date: Wed, 23 Jul 2025 13:59:46 +0200 Subject: [PATCH 1/2] feat: nested services --- examples/nested.rs | 96 ++++++++++++ examples/storage.rs | 93 +++++++++++- irpc-derive/src/lib.rs | 147 +++++++++++++------ src/lib.rs | 321 ++++++++++++++++++++++++++++++++++------- 4 files changed, 557 insertions(+), 100 deletions(-) create mode 100644 examples/nested.rs diff --git a/examples/nested.rs b/examples/nested.rs new file mode 100644 index 0000000..8429bcb --- /dev/null +++ b/examples/nested.rs @@ -0,0 +1,96 @@ +use std::collections::HashMap; + +use irpc::{channel::oneshot, rpc_requests, Client}; +use serde::{Deserialize, Serialize}; + +#[rpc_requests(message = TestMessage)] +#[derive(Debug, Serialize, Deserialize)] +enum TestProtocol { + #[rpc(tx = oneshot::Sender<()>)] + Put(PutRequest), + #[rpc(tx = oneshot::Sender>)] + Get(GetRequest), + #[rpc(nested = NestedMessage)] + Nested(NestedProtocol), +} + +#[derive(Debug, Serialize, Deserialize)] +struct PutRequest { + key: String, + value: String, +} + +#[derive(Debug, Serialize, Deserialize)] +struct GetRequest { + key: String, +} + +#[rpc_requests(message = NestedMessage)] +#[derive(Debug, Serialize, Deserialize)] +enum NestedProtocol { + #[rpc(tx = oneshot::Sender<()>)] + Put(PutRequest2), +} + +#[derive(Debug, Serialize, Deserialize)] +struct PutRequest2 { + key: String, + value: u32, +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let (tx, rx) = tokio::sync::mpsc::channel(10); + tokio::task::spawn(actor(rx)); + let client: Client = Client::from(tx); + client + .rpc(PutRequest { + key: "foo".to_string(), + value: "bar".to_string(), + }) + .await?; + let v = client + .rpc(GetRequest { + key: "foo".to_string(), + }) + .await?; + println!("{v:?}"); + assert_eq!(v.as_deref(), Some("bar")); + client + .map::() + .rpc(PutRequest2 { + key: "foo".to_string(), + value: 22, + }) + .await?; + let v = client + .rpc(GetRequest { + key: "foo".to_string(), + }) + .await?; + println!("{v:?}"); + assert_eq!(v.as_deref(), Some("22")); + Ok(()) +} + +async fn actor(mut rx: tokio::sync::mpsc::Receiver) { + let mut store = HashMap::new(); + while let Some(msg) = rx.recv().await { + match msg { + TestMessage::Put(msg) => { + store.insert(msg.inner.key, msg.inner.value); + msg.tx.send(()).await.ok(); + } + TestMessage::Get(msg) => { + let res = store.get(&msg.key).cloned(); + msg.tx.send(res).await.ok(); + } + TestMessage::Nested(inner) => match inner { + NestedMessage::Put(msg) => { + store.insert(msg.inner.key, msg.inner.value.to_string()); + msg.tx.send(()).await.ok(); + } + }, + } + } +} diff --git a/examples/storage.rs b/examples/storage.rs index 100a16a..5685927 100644 --- a/examples/storage.rs +++ b/examples/storage.rs @@ -1,3 +1,6 @@ +//! This example does not use the `rpc_requests` macro and instead implements +//! everything manually. + use std::{ collections::BTreeMap, net::{Ipv4Addr, SocketAddr, SocketAddrV4}, @@ -8,12 +11,14 @@ use irpc::{ channel::{mpsc, none::NoReceiver, oneshot}, rpc::{listen, RemoteService}, util::{make_client_endpoint, make_server_endpoint}, - Channels, Client, Request, Service, WithChannels, + Channels, Client, MappedClient, Request, Service, WithChannels, }; use n0_future::task::{self, AbortOnDropHandle}; use serde::{Deserialize, Serialize}; use tracing::info; +use self::shout_crate::*; + impl Service for StorageProtocol { type Message = StorageMessage; } @@ -52,6 +57,7 @@ enum StorageProtocol { Get(Get), Set(Set), List(List), + Shout(ShoutProtocol), } #[derive(derive_more::From)] @@ -59,6 +65,7 @@ enum StorageMessage { Get(WithChannels), Set(WithChannels), List(WithChannels), + Shout(ShoutMessage), } impl RemoteService for StorageProtocol { @@ -67,6 +74,64 @@ impl RemoteService for StorageProtocol { StorageProtocol::Get(msg) => WithChannels::from((msg, tx, rx)).into(), StorageProtocol::Set(msg) => WithChannels::from((msg, tx, rx)).into(), StorageProtocol::List(msg) => WithChannels::from((msg, tx, rx)).into(), + StorageProtocol::Shout(msg) => msg.with_remote_channels(rx, tx).into(), + } + } +} + +/// This is a protocol that could live in a different crate. +mod shout_crate { + use irpc::{ + channel::{none::NoReceiver, oneshot}, + rpc::RemoteService, + Channels, Service, WithChannels, + }; + use serde::{Deserialize, Serialize}; + use tracing::info; + + #[derive(derive_more::From, Serialize, Deserialize, Debug)] + pub enum ShoutProtocol { + Shout(Shout), + } + + impl Service for ShoutProtocol { + type Message = ShoutMessage; + } + + #[derive(Debug, Serialize, Deserialize)] + pub struct Shout { + pub key: String, + } + + impl Channels for Shout { + type Rx = NoReceiver; + type Tx = oneshot::Sender; + } + + #[derive(derive_more::From)] + pub enum ShoutMessage { + Shout(WithChannels), + } + + impl RemoteService for ShoutProtocol { + fn with_remote_channels( + self, + rx: quinn::RecvStream, + tx: quinn::SendStream, + ) -> Self::Message { + match self { + ShoutProtocol::Shout(msg) => WithChannels::from((msg, tx, rx)).into(), + } + } + } + + pub async fn handle_message(msg: ShoutMessage) { + match msg { + ShoutMessage::Shout(msg) => { + info!("shout.shout: {msg:?}"); + let WithChannels { tx, inner, .. } = msg; + tx.send(inner.key.to_uppercase()).await.ok(); + } } } } @@ -84,9 +149,9 @@ impl StorageActor { state: BTreeMap::new(), }; n0_future::task::spawn(actor.run()); - StorageApi { - inner: Client::local(tx), - } + let inner = Client::local(tx); + let shout = inner.map().to_owned(); + StorageApi { inner, shout } } async fn run(mut self) { @@ -117,18 +182,22 @@ impl StorageActor { } } } + // We delegate these messages to the handler in the other "crate". + StorageMessage::Shout(msg) => shout_crate::handle_message(msg).await, } } } + struct StorageApi { inner: Client, + shout: MappedClient<'static, StorageProtocol, ShoutProtocol>, } impl StorageApi { pub fn connect(endpoint: quinn::Endpoint, addr: SocketAddr) -> anyhow::Result { - Ok(StorageApi { - inner: Client::quinn(endpoint, addr), - }) + let inner = Client::quinn(endpoint, addr); + let shout = inner.map().to_owned(); + Ok(StorageApi { inner, shout }) } pub fn listen(&self, endpoint: quinn::Endpoint) -> anyhow::Result> { @@ -185,6 +254,12 @@ impl StorageApi { } } } + + pub async fn shout(&self, key: String) -> anyhow::Result { + let msg = Shout { key }; + let res = self.shout.rpc(msg).await?; + Ok(res) + } } async fn local() -> anyhow::Result<()> { @@ -198,6 +273,8 @@ async fn local() -> anyhow::Result<()> { println!("list value = {value:?}"); } println!("value = {value:?}"); + let res = api.shout("hello".to_string()).await?; + println!("shout.shout = {res:?}"); Ok(()) } @@ -222,6 +299,8 @@ async fn remote() -> anyhow::Result<()> { while let Some(value) = list.recv().await? { println!("list value = {value:?}"); } + let shout = api.shout("hello".to_string()).await?; + println!("shout.shout = {shout:?}"); drop(handle); Ok(()) } diff --git a/irpc-derive/src/lib.rs b/irpc-derive/src/lib.rs index a295048..78cf4bf 100644 --- a/irpc-derive/src/lib.rs +++ b/irpc-derive/src/lib.rs @@ -29,29 +29,28 @@ fn generate_parent_span_impl(enum_name: &Ident, variant_names: &[&Ident]) -> Tok quote! { impl #enum_name { /// Get the parent span of the message - pub fn parent_span(&self) -> ::tracing::Span { - let span = match self { - #(#enum_name::#variant_names(inner) => inner.parent_span_opt()),* - }; - span.cloned().unwrap_or_else(|| ::tracing::Span::current()) + pub fn parent_span(&self) -> tracing::Span { + match self { + #(#enum_name::#variant_names(inner) => inner.parent_span().clone()),* + } } } } } fn generate_channels_impl( - mut args: NamedTypeArgs, + mut types: NamedTypes, service_name: &Ident, request_type: &Type, attr_span: Span, ) -> syn::Result { // Try to get rx, default to NoReceiver if not present // Use unwrap_or_else for a cleaner default - let rx = args.types.remove(RX_ATTR).unwrap_or_else(|| { + let rx = types.0.remove(RX_ATTR).unwrap_or_else(|| { // We can safely unwrap here because this is a known valid type syn::parse_str::(DEFAULT_RX_TYPE).expect("Failed to parse default rx type") }); - let tx = args.get(TX_ATTR, attr_span)?; + let tx = types.get(TX_ATTR, attr_span)?; let res = quote! { impl ::irpc::Channels<#service_name> for #request_type { @@ -60,19 +59,19 @@ fn generate_channels_impl( } }; - args.check_empty(attr_span)?; + types.check_empty(attr_span)?; Ok(res) } /// Generates From implementations for protocol enum variants. fn generate_protocol_enum_from_impls( enum_name: &Ident, - variants_with_attr: &[(Ident, Type)], + variants_with_attr: &[(Ident, Type, VariantKind)], ) -> TokenStream2 { let mut impls = quote! {}; // Generate From implementations for each case that has an rpc attribute - for (variant_name, inner_type) in variants_with_attr { + for (variant_name, inner_type, _) in variants_with_attr { let impl_tokens = quote! { impl From<#inner_type> for #enum_name { fn from(value: #inner_type) -> Self { @@ -93,17 +92,30 @@ fn generate_protocol_enum_from_impls( /// Generates From implementations for message enum variants. fn generate_message_enum_from_impls( message_enum_name: &Ident, - variants_with_attr: &[(Ident, Type)], + variants_with_attr: &[(Ident, Type, VariantKind)], service_name: &Ident, ) -> TokenStream2 { let mut impls = quote! {}; // Generate From> implementations for each case with an rpc attribute - for (variant_name, inner_type) in variants_with_attr { - let impl_tokens = quote! { - impl From<::irpc::WithChannels<#inner_type, #service_name>> for #message_enum_name { - fn from(value: ::irpc::WithChannels<#inner_type, #service_name>) -> Self { - #message_enum_name::#variant_name(value) + for (variant_name, inner_type, kind) in variants_with_attr { + let impl_tokens = match kind { + VariantKind::Direct => { + quote! { + impl From<::irpc::WithChannels<#inner_type, #service_name>> for #message_enum_name { + fn from(value: ::irpc::WithChannels<#inner_type, #service_name>) -> Self { + #message_enum_name::#variant_name(value) + } + } + } + } + VariantKind::Nested(ident) => { + quote! { + impl From<#ident> for #message_enum_name { + fn from(value: #ident) -> Self { + #message_enum_name::#variant_name(value) + } + } } } }; @@ -121,14 +133,23 @@ fn generate_message_enum_from_impls( fn generate_remote_service_impl( message_enum_name: &Ident, proto_enum_name: &Ident, - variants_with_attr: &[(Ident, Type)], + variants_with_attr: &[(Ident, Type, VariantKind)], ) -> TokenStream2 { let variants = variants_with_attr .iter() - .map(|(variant_name, _inner_type)| { - quote! { - #proto_enum_name::#variant_name(msg) => { - #message_enum_name::from(::irpc::WithChannels::from((msg, tx, rx))) + .map(|(variant_name, _inner_type, kind)| match kind { + VariantKind::Direct => { + quote! { + #proto_enum_name::#variant_name(msg) => { + #message_enum_name::from(::irpc::WithChannels::from((msg, tx, rx))) + } + } + } + VariantKind::Nested(_) => { + quote! { + #proto_enum_name::#variant_name(msg) => { + #message_enum_name::from(msg.with_remote_channels(rx, tx)) + } } } }); @@ -176,6 +197,11 @@ fn generate_type_aliases( aliases } +enum VariantKind { + Direct, + Nested(Ident), +} + #[proc_macro_attribute] pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream { let mut input = parse_macro_input!(item as DeriveInput); @@ -253,16 +279,31 @@ pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream { // Process variants with rpc attributes if let Some(attr) = rpc_attr { - variants_with_attr.push((variant.ident.clone(), request_type.clone())); - - let args = match attr.parse_args::() { + let args = match attr.parse_args::() { Ok(info) => info, Err(e) => return e.to_compile_error().into(), }; - match generate_channels_impl(args, enum_name, request_type, attr.span()) { - Ok(impls) => channel_impls.push(impls), - Err(e) => return e.to_compile_error().into(), + match args { + VariantArgs::NamedTypes(types) => { + variants_with_attr.push(( + variant.ident.clone(), + request_type.clone(), + VariantKind::Direct, + )); + + match generate_channels_impl(types, enum_name, request_type, attr.span()) { + Ok(impls) => channel_impls.push(impls), + Err(e) => return e.to_compile_error().into(), + } + } + VariantArgs::Nested(ident) => { + variants_with_attr.push(( + variant.ident.clone(), + request_type.clone(), + VariantKind::Nested(ident), + )); + } } } } @@ -281,12 +322,20 @@ pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream { // Generate the extended message enum if requested let extended_enum_code = if let Some(message_enum_name) = args.message_enum_name.as_ref() { - let message_variants = all_variants + let message_variants = variants_with_attr .iter() - .map(|(variant_name, inner_type)| { - quote! { - #[allow(missing_docs)] - #variant_name(::irpc::WithChannels<#inner_type, #enum_name>) + .map(|(variant_name, inner_type, kind)| match kind { + VariantKind::Direct => { + quote! { + #[allow(missing_docs)] + #variant_name(::irpc::WithChannels<#inner_type, #enum_name>) + } + } + VariantKind::Nested(ident) => { + quote! { + #[allow(missing_docs)] + #variant_name(#ident) + } } }) .collect::>(); @@ -443,28 +492,31 @@ impl Parse for MacroArgs { } } -struct NamedTypeArgs { - types: BTreeMap, +enum VariantArgs { + NamedTypes(NamedTypes), + Nested(Ident), } -impl NamedTypeArgs { +struct NamedTypes(BTreeMap); + +impl NamedTypes { /// Get and remove a type from the map, failing if it doesn't exist fn get(&mut self, key: &str, span: Span) -> syn::Result { - self.types + self.0 .remove(key) .ok_or_else(|| syn::Error::new(span, format!("rpc requires a {key} type"))) } /// Fail if there are any unknown arguments remaining fn check_empty(&self, span: Span) -> syn::Result<()> { - if self.types.is_empty() { + if self.0.is_empty() { Ok(()) } else { Err(syn::Error::new( span, format!( "Unknown arguments provided: {:?}", - self.types.keys().collect::>() + self.0.keys().collect::>() ), )) } @@ -472,7 +524,7 @@ impl NamedTypeArgs { } /// Parse the rpc args as a comma separated list of name=type pairs -impl Parse for NamedTypeArgs { +impl Parse for VariantArgs { fn parse(input: ParseStream) -> syn::Result { let mut types = BTreeMap::new(); @@ -483,6 +535,19 @@ impl Parse for NamedTypeArgs { let key: Ident = input.parse()?; let _: Token![=] = input.parse()?; + + if key == "nested" { + return if types.is_empty() { + let value: Ident = input.parse()?; + Ok(VariantArgs::Nested(value)) + } else { + Err(syn::Error::new( + input.span(), + format!("nested may not be combined with other arguments"), + )) + }; + } + let value: Type = input.parse()?; types.insert(key.to_string(), value); @@ -493,6 +558,6 @@ impl Parse for NamedTypeArgs { let _: Token![,] = input.parse()?; } - Ok(NamedTypeArgs { types }) + Ok(VariantArgs::NamedTypes(NamedTypes(types))) } } diff --git a/src/lib.rs b/src/lib.rs index fbb89f9..f439a81 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -76,9 +76,10 @@ //! quic-rpc, this crate does not abstract over the stream type and is focused //! on [iroh](https://docs.rs/iroh/latest/iroh/index.html) and our [iroh quinn fork](https://docs.rs/iroh-quinn/latest/iroh-quinn/index.html). #![cfg_attr(quicrpc_docsrs, feature(doc_cfg))] -use std::{fmt::Debug, future::Future, io, marker::PhantomData, ops::Deref, result}; +use std::{borrow::Cow, fmt::Debug, future::Future, io, marker::PhantomData, ops::Deref, result}; use channel::{mpsc, oneshot}; + /// Processes an RPC request enum and generates trait implementations for use with `irpc`. /// /// This attribute macro may be applied to an enum where each variant represents @@ -208,7 +209,7 @@ pub trait Service: Serialize + DeserializeOwned + Send + Sync + Debug + 'static /// This is expected to be an enum with identical variant names than the /// protocol enum, but its single unit field is the [`WithChannels`] struct /// that contains the inner request plus the `tx` and `rx` channels. - type Message: Send + Unpin + 'static; + type Message: Send + Sync + Unpin + 'static; } mod sealed { @@ -222,7 +223,7 @@ pub trait Sender: Debug + Sealed {} pub trait Receiver: Debug + Sealed {} /// Trait to specify channels for a message and service -pub trait Channels: Send + 'static { +pub trait Channels: Send + Sync + 'static { /// The sender type, can be either mpsc, oneshot or none type Tx: Sender; /// The receiver type, can be either mpsc, oneshot or none @@ -297,8 +298,8 @@ pub mod channel { impl Debug for Sender { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Tokio(_) => f.debug_tuple("Tokio").finish(), - Self::Boxed(_) => f.debug_tuple("Boxed").finish(), + Self::Tokio(_) => f.debug_tuple("TokioOneshotSender").finish(), + Self::Boxed(_) => f.debug_tuple("BoxedOneshotSender").finish(), } } } @@ -401,8 +402,8 @@ pub mod channel { impl Debug for Receiver { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Tokio(_) => f.debug_tuple("Tokio").finish(), - Self::Boxed(_) => f.debug_tuple("Boxed").finish(), + Self::Tokio(_) => f.debug_tuple("TokioOneshotReceiver").finish(), + Self::Boxed(_) => f.debug_tuple("BoxedOneshotReceiver").finish(), } } } @@ -780,8 +781,8 @@ impl + Debug, S: Service> Debug for WithChannels { impl, S: Service> WithChannels { /// Get the parent span #[cfg(feature = "spans")] - pub fn parent_span_opt(&self) -> Option<&tracing::Span> { - Some(&self.span) + pub fn parent_span(&self) -> &tracing::Span { + &self.span } } @@ -894,6 +895,16 @@ impl Client { Self(ClientInner::Remote(Box::new(remote)), PhantomData) } + /// Returns a client that maps from a nested service to this client's service. + pub fn map(&self) -> MappedClient + where + SInner: Service, + S::Message: From, + S: From, + { + MappedClient(Cow::Borrowed(self), PhantomData) + } + /// Creates a new client from a `tokio::sync::mpsc::Sender`. pub fn local(tx: tokio::sync::mpsc::Sender) -> Self { tx.into() @@ -958,19 +969,37 @@ impl Client { S::Message: From>, Req: Channels, Rx = NoReceiver>, Res: RpcMessage, + { + self.rpc_inner( + msg, + |msg, tx| S::Message::from(WithChannels::from((msg, tx))), + |msg| S::from(msg), + ) + } + + fn rpc_inner( + &self, + msg: Req, + local: impl Fn(Req, oneshot::Sender) -> S::Message + Send + 'static, + remote: impl Fn(Req) -> S + Send + 'static, + ) -> impl Future> + Send + 'static + where + Req: Send + Sync + 'static, + Res: RpcMessage, { let request = self.request(); async move { let recv: oneshot::Receiver = match request.await? { Request::Local(request) => { let (tx, rx) = oneshot::channel(); - request.send((msg, tx)).await?; + request.send_raw(local(msg, tx)).await?; rx } #[cfg(not(feature = "rpc"))] Request::Remote(_request) => unreachable!(), #[cfg(feature = "rpc")] Request::Remote(request) => { + let msg = remote(msg); let (_tx, rx) = request.write(msg).await?; rx.into() } @@ -979,7 +1008,6 @@ impl Client { Ok(res) } } - /// Performs a request for which the server returns a mpsc receiver. pub fn server_streaming( &self, @@ -991,62 +1019,97 @@ impl Client { S::Message: From>, Req: Channels, Rx = NoReceiver>, Res: RpcMessage, + { + self.server_streaming_inner( + msg, + local_response_cap, + |msg, tx| S::Message::from(WithChannels::from((msg, tx))), + |msg| S::from(msg), + ) + } + + fn server_streaming_inner( + &self, + msg: Req, + local_response_cap: usize, + local: impl Fn(Req, mpsc::Sender) -> S::Message + Send + 'static, + remote: impl Fn(Req) -> S + Send + 'static, + ) -> impl Future>> + Send + 'static + where + Req: Send + Sync + 'static, + Res: RpcMessage, { let request = self.request(); async move { let recv: mpsc::Receiver = match request.await? { Request::Local(request) => { let (tx, rx) = mpsc::channel(local_response_cap); - request.send((msg, tx)).await?; + request.send_raw(local(msg, tx)).await?; rx } #[cfg(not(feature = "rpc"))] Request::Remote(_request) => unreachable!(), #[cfg(feature = "rpc")] Request::Remote(request) => { - let (_tx, rx) = request.write(msg).await?; + let (_tx, rx) = request.write(remote(msg)).await?; rx.into() } }; Ok(recv) } } - /// Performs a request for which the client can send updates. pub fn client_streaming( &self, msg: Req, local_update_cap: usize, - ) -> impl Future, oneshot::Receiver)>> + ) -> impl Future, oneshot::Receiver)>> + Send where - S: From, + S: Service + From, S::Message: From>, Req: Channels, Rx = mpsc::Receiver>, Update: RpcMessage, Res: RpcMessage, + { + self.client_streaming_inner( + msg, + local_update_cap, + |msg, tx, rx| S::Message::from(WithChannels::from((msg, tx, rx))), + |msg| S::from(msg), + ) + } + + fn client_streaming_inner( + &self, + msg: Req, + local_update_cap: usize, + local: impl Fn(Req, oneshot::Sender, mpsc::Receiver) -> S::Message + Send + 'static, + remote: impl Fn(Req) -> S + Send + 'static, + ) -> impl Future, oneshot::Receiver)>> + Send + where + Req: Send + Sync + 'static, + Update: RpcMessage, + Res: RpcMessage, { let request = self.request(); async move { - let (update_tx, res_rx): (mpsc::Sender, oneshot::Receiver) = - match request.await? { - Request::Local(request) => { - let (req_tx, req_rx) = mpsc::channel(local_update_cap); - let (res_tx, res_rx) = oneshot::channel(); - request.send((msg, res_tx, req_rx)).await?; - (req_tx, res_rx) - } - #[cfg(not(feature = "rpc"))] - Request::Remote(_request) => unreachable!(), - #[cfg(feature = "rpc")] - Request::Remote(request) => { - let (tx, rx) = request.write(msg).await?; - (tx.into(), rx.into()) - } - }; - Ok((update_tx, res_rx)) + match request.await? { + Request::Local(request) => { + let (update_tx, update_rx) = mpsc::channel(local_update_cap); + let (res_tx, res_rx) = oneshot::channel(); + request.send_raw(local(msg, res_tx, update_rx)).await?; + Ok((update_tx, res_rx)) + } + #[cfg(not(feature = "rpc"))] + Request::Remote(_) => unreachable!(), + #[cfg(feature = "rpc")] + Request::Remote(request) => { + let (tx, rx) = request.write(remote(msg)).await?; + Ok((tx.into(), rx.into())) + } + } } } - /// Performs a request for which the client can send updates, and the server returns a mpsc receiver. pub fn bidi_streaming( &self, @@ -1055,35 +1118,154 @@ impl Client { local_response_cap: usize, ) -> impl Future, mpsc::Receiver)>> + Send + 'static where - S: From, + S: Service + From, S::Message: From>, Req: Channels, Rx = mpsc::Receiver>, Update: RpcMessage, Res: RpcMessage, + { + self.bidi_streaming_inner( + msg, + local_update_cap, + local_response_cap, + |msg, tx, rx| S::Message::from(WithChannels::from((msg, tx, rx))), + |msg| S::from(msg), + ) + } + + fn bidi_streaming_inner( + &self, + msg: Req, + local_update_cap: usize, + local_response_cap: usize, + local: impl Fn(Req, mpsc::Sender, mpsc::Receiver) -> S::Message + Send + 'static, + remote: impl Fn(Req) -> S + Send + 'static, + ) -> impl Future, mpsc::Receiver)>> + Send + 'static + where + Req: Send + Sync + 'static, + Update: RpcMessage, + Res: RpcMessage, { let request = self.request(); async move { - let (update_tx, res_rx): (mpsc::Sender, mpsc::Receiver) = - match request.await? { - Request::Local(request) => { - let (update_tx, update_rx) = mpsc::channel(local_update_cap); - let (res_tx, res_rx) = mpsc::channel(local_response_cap); - request.send((msg, res_tx, update_rx)).await?; - (update_tx, res_rx) - } - #[cfg(not(feature = "rpc"))] - Request::Remote(_request) => unreachable!(), - #[cfg(feature = "rpc")] - Request::Remote(request) => { - let (tx, rx) = request.write(msg).await?; - (tx.into(), rx.into()) - } - }; - Ok((update_tx, res_rx)) + match request.await? { + Request::Local(request) => { + let (update_tx, update_rx) = mpsc::channel(local_update_cap); + let (res_tx, res_rx) = mpsc::channel(local_response_cap); + request.send_raw(local(msg, res_tx, update_rx)).await?; + Ok((update_tx, res_rx)) + } + #[cfg(not(feature = "rpc"))] + Request::Remote(_) => unreachable!(), + #[cfg(feature = "rpc")] + Request::Remote(request) => { + let (tx, rx) = request.write(remote(msg)).await?; + Ok((tx.into(), rx.into())) + } + } } } } +/// A client that maps to a nested service. +/// +/// See [`Client::map`]. +pub struct MappedClient<'a, S: Service, SInner: Service>(Cow<'a, Client>, PhantomData); + +impl<'a, S, SInner> MappedClient<'a, S, SInner> +where + S: Service + From, + SInner: Service, + S::Message: From, +{ + pub fn to_owned(self) -> MappedClient<'static, S, SInner> { + MappedClient(Cow::Owned(self.0.into_owned()), PhantomData) + } + + /// Performs a request for which the server returns a oneshot receiver. + pub fn rpc(&self, msg: Req) -> impl Future> + Send + 'static + where + SInner: From, + SInner::Message: From>, + Req: Channels, Rx = NoReceiver>, + Res: RpcMessage, + { + self.0.rpc_inner( + msg, + |msg, tx| S::Message::from(SInner::Message::from(WithChannels::from((msg, tx)))), + |msg| S::from(SInner::from(msg)), + ) + } + + /// Performs a request for which the server returns a mpsc receiver. + pub fn server_streaming( + &self, + msg: Req, + local_response_cap: usize, + ) -> impl Future>> + Send + 'static + where + SInner: From, + SInner::Message: From>, + Req: Channels, Rx = NoReceiver>, + Res: RpcMessage, + { + self.0.server_streaming_inner( + msg, + local_response_cap, + |msg, tx| S::Message::from(SInner::Message::from(WithChannels::from((msg, tx)))), + |msg| S::from(SInner::from(msg)), + ) + } + + /// Performs a request for which the client can send updates. + pub fn client_streaming( + &self, + msg: Req, + local_update_cap: usize, + ) -> impl Future, oneshot::Receiver)>> + Send + where + SInner: From, + SInner::Message: From>, + Req: Channels, Rx = mpsc::Receiver>, + Update: RpcMessage, + Res: RpcMessage, + { + self.0.client_streaming_inner( + msg, + local_update_cap, + |msg, tx, rx| { + S::Message::from(SInner::Message::from(WithChannels::from((msg, tx, rx)))) + }, + |msg| S::from(SInner::from(msg)), + ) + } + + /// Performs a request for which the client can send updates, and the server returns a mpsc receiver. + pub fn bidi_streaming( + &self, + msg: Req, + local_update_cap: usize, + local_response_cap: usize, + ) -> impl Future, mpsc::Receiver)>> + Send + 'static + where + SInner: From, + SInner::Message: From>, + Req: Channels, Rx = mpsc::Receiver>, + Update: RpcMessage, + Res: RpcMessage, + { + self.0.bidi_streaming_inner( + msg, + local_update_cap, + local_response_cap, + |msg, tx, rx| { + S::Message::from(SInner::Message::from(WithChannels::from((msg, tx, rx)))) + }, + |msg| S::from(SInner::from(msg)), + ) + } +} + #[derive(Debug)] pub(crate) enum ClientInner { Local(tokio::sync::mpsc::Sender), @@ -1400,6 +1582,25 @@ pub mod rpc { send.write_all(&buf).await?; Ok((send, recv)) } + + // pub async fn write_mapped( + // self, + // msg: M, + // ) -> std::result::Result<(quinn::SendStream, quinn::RecvStream), WriteError> + // where + // M: Into, + // S2: Into, + // { + // let RemoteSender(mut send, recv, _) = self; + // let msg = msg.into().into(); + // if postcard::experimental::serialized_size(&msg)? as u64 > MAX_MESSAGE_SIZE { + // return Err(WriteError::MaxMessageSizeExceeded); + // } + // let mut buf = SmallVec::<[u8; 128]>::new(); + // buf.write_length_prefixed(msg)?; + // send.write_all(&buf).await?; + // Ok((send, recv)) + // } } impl From for oneshot::Receiver { @@ -1826,6 +2027,22 @@ impl LocalSender { SendFut::new(self.0.clone(), value) } + // pub async fn send_mapped( + // &self, + // value: impl Into>, + // ) -> std::result::Result<(), SendError> + // where + // S2: Service, + // S2::Message: From>, + // T: Channels, + // S::Message: From, + // { + // let value: S2::Message = value.into().into(); + // let value: S::Message = value.into(); + // let fut = SendFut::::new(self.0.clone(), value); + // fut.await + // } + /// Send a message to the service without the type conversion magic pub fn send_raw(&self, value: S::Message) -> SendFut { SendFut::new(self.0.clone(), value) From 26c4d8026f29a5de8f0922ab43384f82f3f90db6 Mon Sep 17 00:00:00 2001 From: Frando Date: Fri, 25 Jul 2025 13:50:30 +0200 Subject: [PATCH 2/2] fixup --- Cargo.toml | 4 ++++ src/lib.rs | 36 ++++++++++++++++++++---------------- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d781192..61c7e84 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -86,6 +86,10 @@ required-features = ["derive"] name = "storage" required-features = ["rpc", "quinn_endpoint_setup"] +[[example]] +name = "nested" +required-features = ["rpc", "derive", "quinn_endpoint_setup"] + [workspace] members = ["irpc-derive", "irpc-iroh"] diff --git a/src/lib.rs b/src/lib.rs index f439a81..bc643c5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -980,8 +980,8 @@ impl Client { fn rpc_inner( &self, msg: Req, - local: impl Fn(Req, oneshot::Sender) -> S::Message + Send + 'static, - remote: impl Fn(Req) -> S + Send + 'static, + map_local: impl Fn(Req, oneshot::Sender) -> S::Message + Send + 'static, + _map_remote: impl Fn(Req) -> S + Send + 'static, ) -> impl Future> + Send + 'static where Req: Send + Sync + 'static, @@ -992,14 +992,14 @@ impl Client { let recv: oneshot::Receiver = match request.await? { Request::Local(request) => { let (tx, rx) = oneshot::channel(); - request.send_raw(local(msg, tx)).await?; + request.send_raw(map_local(msg, tx)).await?; rx } #[cfg(not(feature = "rpc"))] Request::Remote(_request) => unreachable!(), #[cfg(feature = "rpc")] Request::Remote(request) => { - let msg = remote(msg); + let msg = _map_remote(msg); let (_tx, rx) = request.write(msg).await?; rx.into() } @@ -1032,8 +1032,8 @@ impl Client { &self, msg: Req, local_response_cap: usize, - local: impl Fn(Req, mpsc::Sender) -> S::Message + Send + 'static, - remote: impl Fn(Req) -> S + Send + 'static, + map_local: impl Fn(Req, mpsc::Sender) -> S::Message + Send + 'static, + _map_remote: impl Fn(Req) -> S + Send + 'static, ) -> impl Future>> + Send + 'static where Req: Send + Sync + 'static, @@ -1044,14 +1044,14 @@ impl Client { let recv: mpsc::Receiver = match request.await? { Request::Local(request) => { let (tx, rx) = mpsc::channel(local_response_cap); - request.send_raw(local(msg, tx)).await?; + request.send_raw(map_local(msg, tx)).await?; rx } #[cfg(not(feature = "rpc"))] Request::Remote(_request) => unreachable!(), #[cfg(feature = "rpc")] Request::Remote(request) => { - let (_tx, rx) = request.write(remote(msg)).await?; + let (_tx, rx) = request.write(_map_remote(msg)).await?; rx.into() } }; @@ -1083,8 +1083,10 @@ impl Client { &self, msg: Req, local_update_cap: usize, - local: impl Fn(Req, oneshot::Sender, mpsc::Receiver) -> S::Message + Send + 'static, - remote: impl Fn(Req) -> S + Send + 'static, + map_local: impl Fn(Req, oneshot::Sender, mpsc::Receiver) -> S::Message + + Send + + 'static, + _map_remote: impl Fn(Req) -> S + Send + 'static, ) -> impl Future, oneshot::Receiver)>> + Send where Req: Send + Sync + 'static, @@ -1097,14 +1099,14 @@ impl Client { Request::Local(request) => { let (update_tx, update_rx) = mpsc::channel(local_update_cap); let (res_tx, res_rx) = oneshot::channel(); - request.send_raw(local(msg, res_tx, update_rx)).await?; + request.send_raw(map_local(msg, res_tx, update_rx)).await?; Ok((update_tx, res_rx)) } #[cfg(not(feature = "rpc"))] Request::Remote(_) => unreachable!(), #[cfg(feature = "rpc")] Request::Remote(request) => { - let (tx, rx) = request.write(remote(msg)).await?; + let (tx, rx) = request.write(_map_remote(msg)).await?; Ok((tx.into(), rx.into())) } } @@ -1138,8 +1140,10 @@ impl Client { msg: Req, local_update_cap: usize, local_response_cap: usize, - local: impl Fn(Req, mpsc::Sender, mpsc::Receiver) -> S::Message + Send + 'static, - remote: impl Fn(Req) -> S + Send + 'static, + map_local: impl Fn(Req, mpsc::Sender, mpsc::Receiver) -> S::Message + + Send + + 'static, + _map_remote: impl Fn(Req) -> S + Send + 'static, ) -> impl Future, mpsc::Receiver)>> + Send + 'static where Req: Send + Sync + 'static, @@ -1152,14 +1156,14 @@ impl Client { Request::Local(request) => { let (update_tx, update_rx) = mpsc::channel(local_update_cap); let (res_tx, res_rx) = mpsc::channel(local_response_cap); - request.send_raw(local(msg, res_tx, update_rx)).await?; + request.send_raw(map_local(msg, res_tx, update_rx)).await?; Ok((update_tx, res_rx)) } #[cfg(not(feature = "rpc"))] Request::Remote(_) => unreachable!(), #[cfg(feature = "rpc")] Request::Remote(request) => { - let (tx, rx) = request.write(remote(msg)).await?; + let (tx, rx) = request.write(_map_remote(msg)).await?; Ok((tx.into(), rx.into())) } }