From eca8177c7ee7e395a3ae27071f1ad839adf0ea1d Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Thu, 7 Aug 2025 13:58:00 +0300 Subject: [PATCH 1/3] Add StreamItem and associated ext traits and macros. --- irpc-derive/src/lib.rs | 92 ++++++++++++++++++++- src/lib.rs | 35 +++++--- src/util.rs | 182 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 295 insertions(+), 14 deletions(-) diff --git a/irpc-derive/src/lib.rs b/irpc-derive/src/lib.rs index 1bd98af..2ffacbd 100644 --- a/irpc-derive/src/lib.rs +++ b/irpc-derive/src/lib.rs @@ -8,7 +8,7 @@ use syn::{ parse_macro_input, punctuated::Punctuated, spanned::Spanned, - Data, DeriveInput, Fields, Ident, LitStr, Token, Type, + Data, DeriveInput, Fields, Ident, LitStr, Token, Type, Variant, }; // Helper function for error reporting @@ -610,3 +610,93 @@ fn vis_pub() -> syn::Visibility { }, }) } + +#[proc_macro_derive(StreamItem)] +pub fn derive_irpc_stream_item(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + let span = input.span(); + let name = input.ident; + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + + let data = if let Data::Enum(data) = input.data { + data + } else { + return error_tokens(span, "IrpcStreamItem can only be derived for enums"); + }; + + let mut item_variant: Option = None; + let mut error_variant: Option = None; + let mut done_variant: Option = None; + + for variant in data.variants { + let vname = variant.ident.to_string(); + match vname.as_str() { + "Item" => item_variant = Some(variant), + "Error" => error_variant = Some(variant), + "Done" => done_variant = Some(variant), + _ => return error_tokens(span, &format!("Unknown variant: {}", vname)), + } + } + + let Some(item_var) = item_variant else { + return error_tokens(span, "Missing Item variant"); + }; + let Some(error_var) = error_variant else { + return error_tokens(span, "Missing Error variant"); + }; + let Some(done_var) = done_variant else { + return error_tokens(span, "Missing Done variant"); + }; + + let item_field_ty = if let Fields::Unnamed(fields) = item_var.fields { + if fields.unnamed.len() == 1 { + fields.unnamed.into_iter().next().unwrap().ty + } else { + return error_tokens(span, "Item variant must have exactly one unnamed field"); + } + } else { + return error_tokens(span, "Item variant must have unnamed fields"); + }; + + let error_field_ty = if let Fields::Unnamed(fields) = error_var.fields { + if fields.unnamed.len() == 1 { + fields.unnamed.into_iter().next().unwrap().ty + } else { + return error_tokens(span, "Error variant must have exactly one unnamed field"); + } + } else { + return error_tokens(span, "Error variant must have unnamed fields"); + }; + + if !done_var.fields.is_empty() { + return error_tokens(span, "Done variant must be a unit variant with no fields"); + } + + let expanded = quote! { + impl #impl_generics StreamItem for #name #ty_generics #where_clause { + type Error = #error_field_ty; + type Item = #item_field_ty; + + fn into_result_opt(self) -> Option::Item, ::Error>> { + match self { + Self::Item(item) => Some(Ok(item)), + Self::Error(err) => Some(Err(err)), + Self::Done => None, + } + } + + fn from_result(res: std::result::Result<::Item, ::Error>) -> Self { + match res { + Ok(item) => Self::Item(item), + Err(err) => Self::Error(err), + } + } + + fn done() -> Self { + Self::Done + } + } + }; + + TokenStream::from(expanded) +} diff --git a/src/lib.rs b/src/lib.rs index d192d76..a0c01f7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -87,8 +87,11 @@ use channel::{mpsc, none::NoSender, oneshot}; /// /// Basic usage example: /// ``` -/// use serde::{Serialize, Deserialize}; -/// use irpc::{rpc_requests, channel::{oneshot, mpsc}}; +/// use irpc::{ +/// channel::{mpsc, oneshot}, +/// rpc_requests, +/// }; +/// use serde::{Deserialize, Serialize}; /// /// #[rpc_requests(message = ComputeMessage)] /// #[derive(Debug, Serialize, Deserialize)] @@ -155,8 +158,11 @@ use channel::{mpsc, none::NoSender, oneshot}; /// /// With `wrap`: /// ``` -/// use serde::{Serialize, Deserialize}; -/// use irpc::{rpc_requests, channel::{oneshot, mpsc}, Client}; +/// use irpc::{ +/// channel::{mpsc, oneshot}, +/// rpc_requests, Client, +/// }; +/// use serde::{Deserialize, Serialize}; /// /// #[rpc_requests(message = StoreMessage)] /// #[derive(Debug, Serialize, Deserialize)] @@ -164,11 +170,16 @@ use channel::{mpsc, none::NoSender, oneshot}; /// #[rpc(wrap=GetRequest, tx=oneshot::Sender)] /// Get(String), /// #[rpc(wrap=SetRequest, tx=oneshot::Sender<()>)] -/// Set { key: String, value: String } +/// Set { key: String, value: String }, /// } /// /// async fn client_usage(client: Client) -> anyhow::Result<()> { -/// client.rpc(SetRequest { key: "foo".to_string(), value: "bar".to_string() }).await?; +/// client +/// .rpc(SetRequest { +/// key: "foo".to_string(), +/// value: "bar".to_string(), +/// }) +/// .await?; /// let value = client.rpc(GetRequest("foo".to_string())).await?; /// Ok(()) /// } @@ -192,7 +203,6 @@ use channel::{mpsc, none::NoSender, oneshot}; #[cfg(feature = "derive")] #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "derive")))] pub use irpc_derive::rpc_requests; - use sealed::Sealed; use serde::{de::DeserializeOwned, Serialize}; @@ -1274,6 +1284,11 @@ pub mod rpc { }; use n0_future::{future::Boxed as BoxFuture, task::JoinSet}; + /// This is used by irpc-derive to refer to quinn types (SendStream and RecvStream) + /// to make generated code work for users without having to depend on quinn directly + /// (i.e. when using iroh). + #[doc(hidden)] + pub use quinn; use quinn::ConnectionError; use serde::de::DeserializeOwned; use smallvec::SmallVec; @@ -1289,12 +1304,6 @@ pub mod rpc { LocalSender, RequestError, RpcMessage, Service, }; - /// This is used by irpc-derive to refer to quinn types (SendStream and RecvStream) - /// to make generated code work for users without having to depend on quinn directly - /// (i.e. when using iroh). - #[doc(hidden)] - pub use quinn; - /// Default max message size (16 MiB). pub const MAX_MESSAGE_SIZE: u64 = 1024 * 1024 * 16; diff --git a/src/util.rs b/src/util.rs index 5d04877..2658b03 100644 --- a/src/util.rs +++ b/src/util.rs @@ -448,3 +448,185 @@ mod now_or_never { } #[cfg(feature = "rpc")] pub(crate) use now_or_never::now_or_never; + +mod stream_item { + use std::{future::Future, io}; + + use n0_future::{stream, Stream, StreamExt}; + + use crate::channel::{mpsc, RecvError, SendError}; + + /// Trait for an enum that has three variants, item, error, and done. + /// + /// This is very common for irpc stream items if you want to provide an explicit + /// end of stream marker to make sure unsuccessful termination is not mistaken + /// for successful end of stream. + pub trait StreamItem: crate::RpcMessage { + /// The error case of the item enum. + type Error; + /// The item case of the item enum. + type Item; + /// Converts the stream item into either None for end of stream, or a Result + /// containing the item or an error. Error is assumed as a termination, so + /// if you get error you won't get an additional end of stream marker. + fn into_result_opt(self) -> Option>; + /// Converts a result into the item enum. + fn from_result(item: std::result::Result) -> Self; + /// Produces a done marker for the item enum. + fn done() -> Self; + } + + pub trait MpscSenderExt: Sized { + /// Forward a stream of items to the sender. + /// + /// This will convert items and errors into the item enum type, and add + /// a done marker if the stream ends without an error. + #[allow(dead_code)] + fn forward_stream( + self, + stream: impl Stream>, + ) -> impl Future>; + + /// Forward an iterator of items to the sender. + /// + /// This will convert items and errors into the item enum type, and add + /// a done marker if the iterator ends without an error. + fn forward_iter( + self, + iter: impl Iterator>, + ) -> impl Future>; + } + + impl MpscSenderExt for mpsc::Sender { + async fn forward_stream( + self, + stream: impl Stream>, + ) -> std::result::Result<(), SendError> { + tokio::pin!(stream); + while let Some(item) = stream.next().await { + let done = item.is_err(); + self.send(T::from_result(item)).await?; + if done { + return Ok(()); + }; + } + self.send(T::done()).await + } + + async fn forward_iter( + self, + iter: impl Iterator>, + ) -> std::result::Result<(), SendError> { + for item in iter { + let done = item.is_err(); + self.send(T::from_result(item)).await?; + if done { + return Ok(()); + }; + } + self.send(T::done()).await + } + } + + pub trait IrpcReceiverFutExt { + /// Collects the receiver returned by this future into a collection, + /// provided that we get a receiver and draining the receiver does not + /// produce any error items. + /// + /// The collection must implement Default and Extend. + /// Note that using this with a very large stream might use a lot of memory. + fn try_collect(self) -> impl Future> + where + C: Default + Extend, + E: From, + E: From, + E: From; + + /// Converts the receiver returned by this future into a stream of items, + /// where each item is either a successful item or an error. + /// + /// There will be at most one error item, which will terminate the stream. + /// If the future returns an error, the stream will yield that error as the + /// first item and then terminate. + fn into_stream(self) -> impl Stream> + where + E: From, + E: From, + E: From; + } + + impl IrpcReceiverFutExt for F + where + T: StreamItem, + F: Future, crate::Error>>, + { + async fn try_collect(self) -> std::result::Result + where + C: Default + Extend, + E: From, + E: From, + E: From, + { + let mut items = C::default(); + let mut stream = self.into_stream::(); + while let Some(item) = stream.next().await { + match item { + Ok(i) => items.extend(Some(i)), + Err(e) => return Err(e), + } + } + Ok(items) + } + + fn into_stream(self) -> impl Stream> + where + E: From, + E: From, + E: From, + { + enum State { + Init(S), + Receiving(mpsc::Receiver), + Done, + } + fn eof() -> RecvError { + io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected end of stream").into() + } + async fn process_recv( + mut rx: mpsc::Receiver, + ) -> Option<(std::result::Result, State)> + where + T: StreamItem, + E: From, + E: From, + E: From, + { + match rx.recv().await { + Ok(Some(item)) => match item.into_result_opt()? { + Ok(i) => Some((Ok(i), State::Receiving(rx))), + Err(e) => Some((Err(E::from(e)), State::Done)), + }, + Ok(None) => Some((Err(E::from(eof())), State::Done)), + Err(e) => Some((Err(E::from(e)), State::Done)), + } + } + Box::pin(stream::unfold(State::Init(self), |state| async move { + match state { + State::Init(fut) => match fut.await { + Ok(rx) => process_recv(rx).await, + Err(e) => Some((Err(E::from(e)), State::Done)), + }, + State::Receiving(rx) => process_recv(rx).await, + State::Done => None, + } + })) + } + } +} + +#[cfg(all(feature = "derive", feature = "stream"))] +#[cfg_attr(quicrpc_docsrs, doc(cfg(all(feature = "derive", feature = "stream"))))] +pub use irpc_derive::StreamItem; +#[cfg(feature = "stream")] +#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "stream")))] +pub use stream_item::{IrpcReceiverFutExt, MpscSenderExt, StreamItem}; From 92f15fb2fb2cb4bb43d3567b191db28aa7ef78c7 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Thu, 7 Aug 2025 18:28:15 +0300 Subject: [PATCH 2/3] Add example for StreamItem util --- Cargo.toml | 4 + examples/stream.rs | 229 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 233 insertions(+) create mode 100644 examples/stream.rs diff --git a/Cargo.toml b/Cargo.toml index 40c2133..698588e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -86,6 +86,10 @@ required-features = ["derive"] name = "storage" required-features = ["rpc", "quinn_endpoint_setup"] +[[example]] +name = "stream" +required-features = ["rpc", "derive", "quinn_endpoint_setup"] + [workspace] members = ["irpc-derive", "irpc-iroh"] diff --git a/examples/stream.rs b/examples/stream.rs new file mode 100644 index 0000000..f8cba09 --- /dev/null +++ b/examples/stream.rs @@ -0,0 +1,229 @@ +use std::{ + collections::BTreeMap, + future::{Future, IntoFuture}, + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, +}; + +use anyhow::{Context, Result}; +use futures_util::future::BoxFuture; +use irpc::{ + channel::{mpsc, oneshot}, + rpc::RemoteService, + rpc_requests, + util::{ + make_client_endpoint, make_server_endpoint, IrpcReceiverFutExt, MpscSenderExt, StreamItem, + }, + Client, WithChannels, +}; +// Import the macro +use n0_future::{ + task::{self, AbortOnDropHandle}, + Stream, StreamExt, +}; +use serde::{Deserialize, Serialize}; +use tracing::info; + +#[derive(Debug, Serialize, Deserialize)] +struct Error { + message: String, +} + +impl std::error::Error for Error {} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.message) + } +} + +#[derive(Debug, Serialize, Deserialize, StreamItem)] +enum GetItem { + Item(String), + Error(Error), + Done, +} + +#[derive(Debug, Serialize, Deserialize)] +struct Set { + key: String, + value: String, +} + +#[derive(Debug, Serialize, Deserialize)] +struct Get { + key: String, +} + +// Use the macro to generate both the StorageProtocol and StorageMessage enums +// plus implement Channels for each type +#[rpc_requests(message = StorageMessage)] +#[derive(Serialize, Deserialize, Debug)] +enum StorageProtocol { + #[rpc(tx=oneshot::Sender<()>)] + Set(Set), + #[rpc(tx=mpsc::Sender)] + Get(Get), +} + +struct StorageActor { + recv: tokio::sync::mpsc::Receiver, + state: BTreeMap, +} + +struct GetProgress { + fut: BoxFuture<'static, irpc::Result>>, +} + +impl GetProgress { + pub fn new( + fut: impl Future>> + Send + 'static, + ) -> Self { + Self { fut: Box::pin(fut) } + } + + pub fn stream(self) -> impl Stream> { + self.fut.into_stream() + } +} + +impl IntoFuture for GetProgress { + type Output = anyhow::Result; + type IntoFuture = BoxFuture<'static, Self::Output>; + + fn into_future(self) -> Self::IntoFuture { + Box::pin(self.fut.try_collect()) + } +} + +impl StorageActor { + pub fn spawn() -> StorageApi { + let (tx, rx) = tokio::sync::mpsc::channel(1); + let actor = Self { + recv: rx, + state: BTreeMap::new(), + }; + n0_future::task::spawn(actor.run()); + StorageApi { + inner: Client::local(tx), + } + } + + async fn run(mut self) { + while let Some(msg) = self.recv.recv().await { + self.handle(msg).await; + } + } + + async fn handle(&mut self, msg: StorageMessage) { + match msg { + StorageMessage::Get(get) => { + info!("get {:?}", get); + let WithChannels { + tx, + inner: Get { key }, + .. + } = get; + let value = self.state.get(&key).cloned().unwrap_or_default(); + let parts = value.split_inclusive(" "); + tx.forward_iter(parts.map(|x| Ok(x.to_string()))).await.ok(); + } + StorageMessage::Set(set) => { + info!("set {:?}", set); + let WithChannels { + tx, + inner: Set { key, value }, + .. + } = set; + self.state.insert(key, value); + tx.send(()).await.ok(); + } + } + } +} + +struct StorageApi { + inner: Client, +} + +impl StorageApi { + pub fn connect(endpoint: quinn::Endpoint, addr: SocketAddr) -> Result { + Ok(StorageApi { + inner: Client::quinn(endpoint, addr), + }) + } + + pub fn listen(&self, endpoint: quinn::Endpoint) -> Result> { + let local = self + .inner + .as_local() + .context("cannot listen on remote API")?; + let join_handle = task::spawn(irpc::rpc::listen( + endpoint, + StorageProtocol::remote_handler(local), + )); + Ok(AbortOnDropHandle::new(join_handle)) + } + + pub fn get(&self, key: String) -> GetProgress { + GetProgress::new(self.inner.server_streaming(Get { key }, 16)) + } + + pub async fn set(&self, key: String, value: String) -> irpc::Result<()> { + self.inner.rpc(Set { key, value }).await + } +} + +async fn client_demo(api: StorageApi) -> Result<()> { + api.set("hello".to_string(), "world".to_string()).await?; + let value = api.get("hello".to_string()).await?; + println!("get: hello = {value:?}"); + + api.set("loremipsum".to_string(), "dolor sit amet".to_string()) + .await?; + + let mut parts = api.get("loremipsum".to_string()).stream(); + while let Some(part) = parts.next().await { + match part { + Ok(item) => println!("Received item: {item}"), + Err(e) => println!("Error receiving item: {e}"), + } + } + + Ok(()) +} + +async fn local() -> Result<()> { + let api = StorageActor::spawn(); + client_demo(api).await?; + Ok(()) +} + +async fn remote() -> Result<()> { + let port = 10113; + let addr: SocketAddr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, port).into(); + + let (server_handle, cert) = { + let (endpoint, cert) = make_server_endpoint(addr)?; + let api = StorageActor::spawn(); + let handle = api.listen(endpoint)?; + (handle, cert) + }; + + let endpoint = + make_client_endpoint(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into(), &[&cert])?; + let api = StorageApi::connect(endpoint, addr)?; + client_demo(api).await?; + + drop(server_handle); + Ok(()) +} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::fmt::init(); + println!("Local use"); + local().await?; + println!("Remote use"); + remote().await.unwrap(); + Ok(()) +} From 0a3461782eb060d737e979fd29dfc6de8e4338e6 Mon Sep 17 00:00:00 2001 From: Franz Heinzmann Date: Fri, 8 Aug 2025 15:21:39 +0200 Subject: [PATCH 3/3] refactor: try out a version of stream helpers without a macro (#59) --- examples/stream.rs | 71 +++++------------ irpc-derive/src/lib.rs | 1 + src/util.rs | 177 +++++++++++++++++++++++++++++++++++------ 3 files changed, 171 insertions(+), 78 deletions(-) diff --git a/examples/stream.rs b/examples/stream.rs index f8cba09..69ef482 100644 --- a/examples/stream.rs +++ b/examples/stream.rs @@ -1,48 +1,30 @@ use std::{ collections::BTreeMap, - future::{Future, IntoFuture}, net::{Ipv4Addr, SocketAddr, SocketAddrV4}, }; use anyhow::{Context, Result}; -use futures_util::future::BoxFuture; use irpc::{ - channel::{mpsc, oneshot}, + channel::oneshot, rpc::RemoteService, rpc_requests, - util::{ - make_client_endpoint, make_server_endpoint, IrpcReceiverFutExt, MpscSenderExt, StreamItem, - }, + util::{make_client_endpoint, make_server_endpoint, MpscSenderExt, Progress, StreamSender}, Client, WithChannels, }; // Import the macro use n0_future::{ task::{self, AbortOnDropHandle}, - Stream, StreamExt, + StreamExt, }; use serde::{Deserialize, Serialize}; use tracing::info; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, thiserror::Error)] +#[error("{message}")] struct Error { message: String, } -impl std::error::Error for Error {} - -impl std::fmt::Display for Error { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.message) - } -} - -#[derive(Debug, Serialize, Deserialize, StreamItem)] -enum GetItem { - Item(String), - Error(Error), - Done, -} - #[derive(Debug, Serialize, Deserialize)] struct Set { key: String, @@ -61,7 +43,7 @@ struct Get { enum StorageProtocol { #[rpc(tx=oneshot::Sender<()>)] Set(Set), - #[rpc(tx=mpsc::Sender)] + #[rpc(tx=StreamSender)] Get(Get), } @@ -70,31 +52,6 @@ struct StorageActor { state: BTreeMap, } -struct GetProgress { - fut: BoxFuture<'static, irpc::Result>>, -} - -impl GetProgress { - pub fn new( - fut: impl Future>> + Send + 'static, - ) -> Self { - Self { fut: Box::pin(fut) } - } - - pub fn stream(self) -> impl Stream> { - self.fut.into_stream() - } -} - -impl IntoFuture for GetProgress { - type Output = anyhow::Result; - type IntoFuture = BoxFuture<'static, Self::Output>; - - fn into_future(self) -> Self::IntoFuture { - Box::pin(self.fut.try_collect()) - } -} - impl StorageActor { pub fn spawn() -> StorageApi { let (tx, rx) = tokio::sync::mpsc::channel(1); @@ -164,8 +121,12 @@ impl StorageApi { Ok(AbortOnDropHandle::new(join_handle)) } - pub fn get(&self, key: String) -> GetProgress { - GetProgress::new(self.inner.server_streaming(Get { key }, 16)) + pub fn get(&self, key: String) -> Progress { + Progress::new(self.inner.server_streaming(Get { key }, 16)) + } + + pub fn get_vec(&self, key: String) -> Progress> { + Progress::new(self.inner.server_streaming(Get { key }, 16)) } pub async fn set(&self, key: String, value: String) -> irpc::Result<()> { @@ -174,9 +135,13 @@ impl StorageApi { } async fn client_demo(api: StorageApi) -> Result<()> { - api.set("hello".to_string(), "world".to_string()).await?; + api.set("hello".to_string(), "world and all".to_string()) + .await?; let value = api.get("hello".to_string()).await?; - println!("get: hello = {value:?}"); + println!("get (string): hello = {value:?}"); + + let value = api.get_vec("hello".to_string()).await?; + println!("get (vec): hello = {value:?}"); api.set("loremipsum".to_string(), "dolor sit amet".to_string()) .await?; diff --git a/irpc-derive/src/lib.rs b/irpc-derive/src/lib.rs index 2ffacbd..7277b76 100644 --- a/irpc-derive/src/lib.rs +++ b/irpc-derive/src/lib.rs @@ -611,6 +611,7 @@ fn vis_pub() -> syn::Visibility { }) } +// TODO(Frando): Remove if the generics approach works out fine? #[proc_macro_derive(StreamItem)] pub fn derive_irpc_stream_item(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); diff --git a/src/util.rs b/src/util.rs index 2658b03..01bbd1f 100644 --- a/src/util.rs +++ b/src/util.rs @@ -450,11 +450,147 @@ mod now_or_never { pub(crate) use now_or_never::now_or_never; mod stream_item { - use std::{future::Future, io}; + use std::{ + future::{Future, IntoFuture}, + io, + marker::PhantomData, + }; + use futures_util::future::BoxFuture; use n0_future::{stream, Stream, StreamExt}; + use serde::{Deserialize, Serialize}; + + use crate::{ + channel::{mpsc, RecvError, SendError}, + RpcMessage, + }; + + /// Type alias for a [`mpsc::Sender`] of [`Item`]s. + pub type StreamSender = mpsc::Sender>; + + /// Type alias for a [`mpsc::Receiver`] of [`Item`]s. + pub type StreamReceiver = mpsc::Receiver>; + + /// The error type returned from fallible irpc stream extension methods. + /// + /// This is an enum with two variants: one for transport errors, and one + /// for errors returned from the remote service. + #[derive(thiserror::Error, Debug)] + pub enum StreamError { + /// Transport error. + #[error(transparent)] + Transport(#[from] crate::Error), + /// Error returned from the remote service. + #[error(transparent)] + Remote(E), + } + + impl From for StreamError { + fn from(value: crate::channel::RecvError) -> Self { + Self::Transport(value.into()) + } + } - use crate::channel::{mpsc, RecvError, SendError}; + pub type StreamResult = std::result::Result>; + + /// Wrapper for server-streaming RPC calls that return a stream of items. + /// + /// This struct wraps the future returned from [`crate::Client::server_streaming`] + /// if the response channel is a stream of [`Item`]. + /// + /// The [`Progress`] implements [`IntoFuture`], so it can be `await`ed directly + /// without further chaining. It will then `try_collect` all items into a collection, + /// as specified by the `C` generic on this struct. + /// + /// The [`Progress`] can also be turned into a stream of individual items by calling + /// [`Progress::into_stream`]. + pub struct Progress< + T: RpcMessage, + E: RpcMessage + std::error::Error, + C: Extend + Default = T, + > { + fut: BoxFuture<'static, crate::Result>>>, + _collection_type: PhantomData, + } + + impl Progress + where + T: RpcMessage, + E: RpcMessage + std::error::Error, + C: Extend + Default + Send, + { + pub fn new( + fut: impl Future>>> + Send + 'static, + ) -> Self { + Self { + fut: Box::pin(fut), + _collection_type: PhantomData, + } + } + + pub fn stream(self) -> impl Stream> { + self.fut.into_stream() + } + } + + impl IntoFuture for Progress + where + T: RpcMessage, + E: RpcMessage + std::error::Error, + C: Default + Extend + Send + 'static, + { + type Output = StreamResult; + type IntoFuture = BoxFuture<'static, Self::Output>; + + fn into_future(self) -> Self::IntoFuture { + Box::pin(self.fut.try_collect()) + } + } + + /// A fallible stream item. + /// + /// This is an enum with three variants, `Ok`, `Err`, and `Done`. + /// + /// It can be used as the item type for `mpsc` channels to force an explicit end of stream marker. + /// It implements [`StreamItem`] and throught that the use of [`MpscSenderExt`], [`IrpcReceiverFutExt`], + /// and [`Progress`]. + #[derive(Debug, Serialize, Deserialize, Clone)] + pub enum Item { + /// The stream item. + Ok(T), + /// The error case. + /// + /// No futher messages may be sent afterwards. + Err(E), + /// The end-of-stream marker. + /// + /// Send this as the last message to gracefully terminate a stream. + /// No further messages may be sent afterwards. + Done, + } + + impl StreamItem for Item { + type Item = T; + type Error = E; + fn into_result_opt(self) -> Option> { + match self { + Item::Ok(item) => Some(Ok(item)), + Item::Err(error) => Some(Err(error)), + Item::Done => None, + } + } + + fn from_result(item: std::result::Result) -> Self { + match item { + Ok(item) => Self::Ok(item), + Err(err) => Self::Err(err), + } + } + + fn done() -> Self { + Self::Done + } + } /// Trait for an enum that has three variants, item, error, and done. /// @@ -463,9 +599,9 @@ mod stream_item { /// for successful end of stream. pub trait StreamItem: crate::RpcMessage { /// The error case of the item enum. - type Error; + type Error: crate::RpcMessage + std::error::Error; /// The item case of the item enum. - type Item; + type Item: crate::RpcMessage; /// Converts the stream item into either None for end of stream, or a Result /// containing the item or an error. Error is assumed as a termination, so /// if you get error you won't get an additional end of stream marker. @@ -481,7 +617,6 @@ mod stream_item { /// /// This will convert items and errors into the item enum type, and add /// a done marker if the stream ends without an error. - #[allow(dead_code)] fn forward_stream( self, stream: impl Stream>, @@ -538,9 +673,7 @@ mod stream_item { fn try_collect(self) -> impl Future> where C: Default + Extend, - E: From, - E: From, - E: From; + E: From>; /// Converts the receiver returned by this future into a stream of items, /// where each item is either a successful item or an error. @@ -550,9 +683,7 @@ mod stream_item { /// first item and then terminate. fn into_stream(self) -> impl Stream> where - E: From, - E: From, - E: From; + E: From>; } impl IrpcReceiverFutExt for F @@ -563,9 +694,7 @@ mod stream_item { async fn try_collect(self) -> std::result::Result where C: Default + Extend, - E: From, - E: From, - E: From, + E: From>, { let mut items = C::default(); let mut stream = self.into_stream::(); @@ -580,9 +709,7 @@ mod stream_item { fn into_stream(self) -> impl Stream> where - E: From, - E: From, - E: From, + E: From>, { enum State { Init(S), @@ -597,24 +724,22 @@ mod stream_item { ) -> Option<(std::result::Result, State)> where T: StreamItem, - E: From, - E: From, - E: From, + E: From>, { match rx.recv().await { Ok(Some(item)) => match item.into_result_opt()? { Ok(i) => Some((Ok(i), State::Receiving(rx))), - Err(e) => Some((Err(E::from(e)), State::Done)), + Err(e) => Some((Err(E::from(StreamError::Remote(e))), State::Done)), }, - Ok(None) => Some((Err(E::from(eof())), State::Done)), - Err(e) => Some((Err(E::from(e)), State::Done)), + Ok(None) => Some((Err(E::from(StreamError::from(eof()))), State::Done)), + Err(e) => Some((Err(E::from(StreamError::from(e))), State::Done)), } } Box::pin(stream::unfold(State::Init(self), |state| async move { match state { State::Init(fut) => match fut.await { Ok(rx) => process_recv(rx).await, - Err(e) => Some((Err(E::from(e)), State::Done)), + Err(e) => Some((Err(E::from(StreamError::from(e))), State::Done)), }, State::Receiving(rx) => process_recv(rx).await, State::Done => None, @@ -629,4 +754,6 @@ mod stream_item { pub use irpc_derive::StreamItem; #[cfg(feature = "stream")] #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "stream")))] -pub use stream_item::{IrpcReceiverFutExt, MpscSenderExt, StreamItem}; +pub use stream_item::{ + IrpcReceiverFutExt, Item, MpscSenderExt, Progress, StreamItem, StreamReceiver, StreamSender, +};