Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ required-features = ["derive"]
name = "storage"
required-features = ["rpc", "quinn_endpoint_setup"]

[[example]]
name = "stream"
required-features = ["rpc", "derive", "quinn_endpoint_setup"]

[workspace]
members = ["irpc-derive", "irpc-iroh"]

Expand Down
194 changes: 194 additions & 0 deletions examples/stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
use std::{
collections::BTreeMap,
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
};

use anyhow::{Context, Result};
use irpc::{
channel::oneshot,
rpc::RemoteService,
rpc_requests,
util::{make_client_endpoint, make_server_endpoint, MpscSenderExt, Progress, StreamSender},
Client, WithChannels,
};
// Import the macro
use n0_future::{
task::{self, AbortOnDropHandle},
StreamExt,
};
use serde::{Deserialize, Serialize};
use tracing::info;

#[derive(Debug, Serialize, Deserialize, thiserror::Error)]
#[error("{message}")]
struct Error {
message: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct Set {
key: String,
value: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct Get {
key: String,
}

// Use the macro to generate both the StorageProtocol and StorageMessage enums
// plus implement Channels for each type
#[rpc_requests(message = StorageMessage)]
#[derive(Serialize, Deserialize, Debug)]
enum StorageProtocol {
#[rpc(tx=oneshot::Sender<()>)]
Set(Set),
#[rpc(tx=StreamSender<String, Error>)]
Get(Get),
}

struct StorageActor {
recv: tokio::sync::mpsc::Receiver<StorageMessage>,
state: BTreeMap<String, String>,
}

impl StorageActor {
pub fn spawn() -> StorageApi {
let (tx, rx) = tokio::sync::mpsc::channel(1);
let actor = Self {
recv: rx,
state: BTreeMap::new(),
};
n0_future::task::spawn(actor.run());
StorageApi {
inner: Client::local(tx),
}
}

async fn run(mut self) {
while let Some(msg) = self.recv.recv().await {
self.handle(msg).await;
}
}

async fn handle(&mut self, msg: StorageMessage) {
match msg {
StorageMessage::Get(get) => {
info!("get {:?}", get);
let WithChannels {
tx,
inner: Get { key },
..
} = get;
let value = self.state.get(&key).cloned().unwrap_or_default();
let parts = value.split_inclusive(" ");
tx.forward_iter(parts.map(|x| Ok(x.to_string()))).await.ok();
}
StorageMessage::Set(set) => {
info!("set {:?}", set);
let WithChannels {
tx,
inner: Set { key, value },
..
} = set;
self.state.insert(key, value);
tx.send(()).await.ok();
}
}
}
}

struct StorageApi {
inner: Client<StorageProtocol>,
}

impl StorageApi {
pub fn connect(endpoint: quinn::Endpoint, addr: SocketAddr) -> Result<StorageApi> {
Ok(StorageApi {
inner: Client::quinn(endpoint, addr),
})
}

pub fn listen(&self, endpoint: quinn::Endpoint) -> Result<AbortOnDropHandle<()>> {
let local = self
.inner
.as_local()
.context("cannot listen on remote API")?;
let join_handle = task::spawn(irpc::rpc::listen(
endpoint,
StorageProtocol::remote_handler(local),
));
Ok(AbortOnDropHandle::new(join_handle))
}

pub fn get(&self, key: String) -> Progress<String, Error> {
Progress::new(self.inner.server_streaming(Get { key }, 16))
}

pub fn get_vec(&self, key: String) -> Progress<String, Error, Vec<String>> {
Progress::new(self.inner.server_streaming(Get { key }, 16))
}

pub async fn set(&self, key: String, value: String) -> irpc::Result<()> {
self.inner.rpc(Set { key, value }).await
}
}

async fn client_demo(api: StorageApi) -> Result<()> {
api.set("hello".to_string(), "world and all".to_string())
.await?;
let value = api.get("hello".to_string()).await?;
println!("get (string): hello = {value:?}");

let value = api.get_vec("hello".to_string()).await?;
println!("get (vec): hello = {value:?}");

api.set("loremipsum".to_string(), "dolor sit amet".to_string())
.await?;

let mut parts = api.get("loremipsum".to_string()).stream();
while let Some(part) = parts.next().await {
match part {
Ok(item) => println!("Received item: {item}"),
Err(e) => println!("Error receiving item: {e}"),
}
}

Ok(())
}

async fn local() -> Result<()> {
let api = StorageActor::spawn();
client_demo(api).await?;
Ok(())
}

async fn remote() -> Result<()> {
let port = 10113;
let addr: SocketAddr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, port).into();

let (server_handle, cert) = {
let (endpoint, cert) = make_server_endpoint(addr)?;
let api = StorageActor::spawn();
let handle = api.listen(endpoint)?;
(handle, cert)
};

let endpoint =
make_client_endpoint(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into(), &[&cert])?;
let api = StorageApi::connect(endpoint, addr)?;
client_demo(api).await?;

drop(server_handle);
Ok(())
}

#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::fmt::init();
println!("Local use");
local().await?;
println!("Remote use");
remote().await.unwrap();
Ok(())
}
93 changes: 92 additions & 1 deletion irpc-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use syn::{
parse_macro_input,
punctuated::Punctuated,
spanned::Spanned,
Data, DeriveInput, Fields, Ident, LitStr, Token, Type,
Data, DeriveInput, Fields, Ident, LitStr, Token, Type, Variant,
};

// Helper function for error reporting
Expand Down Expand Up @@ -610,3 +610,94 @@ fn vis_pub() -> syn::Visibility {
},
})
}

// TODO(Frando): Remove if the generics approach works out fine?
#[proc_macro_derive(StreamItem)]
pub fn derive_irpc_stream_item(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let span = input.span();
let name = input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

let data = if let Data::Enum(data) = input.data {
data
} else {
return error_tokens(span, "IrpcStreamItem can only be derived for enums");
};

let mut item_variant: Option<Variant> = None;
let mut error_variant: Option<Variant> = None;
let mut done_variant: Option<Variant> = None;

for variant in data.variants {
let vname = variant.ident.to_string();
match vname.as_str() {
"Item" => item_variant = Some(variant),
"Error" => error_variant = Some(variant),
"Done" => done_variant = Some(variant),
_ => return error_tokens(span, &format!("Unknown variant: {}", vname)),
}
}

let Some(item_var) = item_variant else {
return error_tokens(span, "Missing Item variant");
};
let Some(error_var) = error_variant else {
return error_tokens(span, "Missing Error variant");
};
let Some(done_var) = done_variant else {
return error_tokens(span, "Missing Done variant");
};

let item_field_ty = if let Fields::Unnamed(fields) = item_var.fields {
if fields.unnamed.len() == 1 {
fields.unnamed.into_iter().next().unwrap().ty
} else {
return error_tokens(span, "Item variant must have exactly one unnamed field");
}
} else {
return error_tokens(span, "Item variant must have unnamed fields");
};

let error_field_ty = if let Fields::Unnamed(fields) = error_var.fields {
if fields.unnamed.len() == 1 {
fields.unnamed.into_iter().next().unwrap().ty
} else {
return error_tokens(span, "Error variant must have exactly one unnamed field");
}
} else {
return error_tokens(span, "Error variant must have unnamed fields");
};

if !done_var.fields.is_empty() {
return error_tokens(span, "Done variant must be a unit variant with no fields");
}

let expanded = quote! {
impl #impl_generics StreamItem for #name #ty_generics #where_clause {
type Error = #error_field_ty;
type Item = #item_field_ty;

fn into_result_opt(self) -> Option<std::result::Result<<Self as StreamItem>::Item, <Self as StreamItem>::Error>> {
match self {
Self::Item(item) => Some(Ok(item)),
Self::Error(err) => Some(Err(err)),
Self::Done => None,
}
}

fn from_result(res: std::result::Result<<Self as StreamItem>::Item, <Self as StreamItem>::Error>) -> Self {
match res {
Ok(item) => Self::Item(item),
Err(err) => Self::Error(err),
}
}

fn done() -> Self {
Self::Done
}
}
};

TokenStream::from(expanded)
}
35 changes: 22 additions & 13 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,11 @@ use channel::{mpsc, none::NoSender, oneshot};
///
/// Basic usage example:
/// ```
/// use serde::{Serialize, Deserialize};
/// use irpc::{rpc_requests, channel::{oneshot, mpsc}};
/// use irpc::{
/// channel::{mpsc, oneshot},
/// rpc_requests,
/// };
/// use serde::{Deserialize, Serialize};
///
/// #[rpc_requests(message = ComputeMessage)]
/// #[derive(Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -155,20 +158,28 @@ use channel::{mpsc, none::NoSender, oneshot};
///
/// With `wrap`:
/// ```
/// use serde::{Serialize, Deserialize};
/// use irpc::{rpc_requests, channel::{oneshot, mpsc}, Client};
/// use irpc::{
/// channel::{mpsc, oneshot},
/// rpc_requests, Client,
/// };
/// use serde::{Deserialize, Serialize};
///
/// #[rpc_requests(message = StoreMessage)]
/// #[derive(Debug, Serialize, Deserialize)]
/// enum StoreProtocol {
/// #[rpc(wrap=GetRequest, tx=oneshot::Sender<String>)]
/// Get(String),
/// #[rpc(wrap=SetRequest, tx=oneshot::Sender<()>)]
/// Set { key: String, value: String }
/// Set { key: String, value: String },
/// }
///
/// async fn client_usage(client: Client<StoreProtocol>) -> anyhow::Result<()> {
/// client.rpc(SetRequest { key: "foo".to_string(), value: "bar".to_string() }).await?;
/// client
/// .rpc(SetRequest {
/// key: "foo".to_string(),
/// value: "bar".to_string(),
/// })
/// .await?;
/// let value = client.rpc(GetRequest("foo".to_string())).await?;
/// Ok(())
/// }
Expand All @@ -192,7 +203,6 @@ use channel::{mpsc, none::NoSender, oneshot};
#[cfg(feature = "derive")]
#[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "derive")))]
pub use irpc_derive::rpc_requests;

use sealed::Sealed;
use serde::{de::DeserializeOwned, Serialize};

Expand Down Expand Up @@ -1274,6 +1284,11 @@ pub mod rpc {
};

use n0_future::{future::Boxed as BoxFuture, task::JoinSet};
/// This is used by irpc-derive to refer to quinn types (SendStream and RecvStream)
/// to make generated code work for users without having to depend on quinn directly
/// (i.e. when using iroh).
#[doc(hidden)]
pub use quinn;
use quinn::ConnectionError;
use serde::de::DeserializeOwned;
use smallvec::SmallVec;
Expand All @@ -1289,12 +1304,6 @@ pub mod rpc {
LocalSender, RequestError, RpcMessage, Service,
};

/// This is used by irpc-derive to refer to quinn types (SendStream and RecvStream)
/// to make generated code work for users without having to depend on quinn directly
/// (i.e. when using iroh).
#[doc(hidden)]
pub use quinn;

/// Default max message size (16 MiB).
pub const MAX_MESSAGE_SIZE: u64 = 1024 * 1024 * 16;

Expand Down
Loading
Loading